Blog post written by: Simon Bohnen

Based on: Liu, M., Lee, C.-W., Sun, X., Yu, X., Qiao, Y., & Wang, Y. (2025). Learning causal alignment for reliable disease diagnosis. In Proceedings of the Thirteenth International Conference on Learning Representations.

What is Causal Alignment?

To understand what Causal Alignment is, I want to tell you a little story:

You are taking a walk through the forest with your four-year-old son Paul. Paul loves animals, so you point out all the birds you see in the forest. There are sparrows, ravens, magpies, and many, many other birds, and the two of you take a good look at all of them. You get to a clearing and stop before a tall tree. A squirrel is sitting on one of its branches and when Paul spots it, he says: "Look dad, a bird!"

How could that have happened? When you were looking at the magpies and ravens, Paul knew that they were all birds. But maybe the two of you had different reasons for thinking so: you looked at their feathers and their beaks, but Paul just looked for the tail. And when he saw the squirrel's tail, he concluded that the squirrel must be a bird too. This is an example of Causal Misalignment: both of you are able to recognize birds, but for different reasons. And when you see a new animal, you might classify it differently than Paul does.

All of these have tails, but not all of these are birds. (Images from Unsplash)

What we want is Causal Alignment: when you and Paul are recognizing birds, you should both pay attention to beaks and feathers, not just the tail. That way, you can be confident that Paul won't mistake a squirrel for a bird later on.

Causal Alignment is also important when making medical diagnoses. When a medical student learns to recognize cancerous tissue on a CT scan, it's crucial that the student pays attention to the same details of the image as the teacher. Otherwise the student will make wrong diagnoses later on. Similarly, medical machine learning models should pay attention to the same details of an image which a doctor is paying attention to. But how can we enforce causal alignment for medical ML models?

What data are we working with?

Before we dive deeper into enforcing causal alignment, we have to get to know the data we're working with. Our images are slices of lung CT scans which might or might not contain malignant lung nodules.

The components of our data:

  • x: The image
  • y: The image's label. y=1 means there is a malignant nodule, y=0 means there is no malignant nodule.
  • m: The doctor-annotated area where the nodule is. Annotated as a red box in the image.

We are trying to construct a function f which looks at our image x and determines the correct label y. To ensure causal alignment, we also want to make sure that f is "paying attention" to the abnormal area m, and not some other part of the image.
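As a rough illustration of how one such data point could be represented in code, here is a minimal Python sketch. The names `Sample` and `box_to_mask`, the 64x64 image size, and the choice to store m as a boolean mask are assumptions I made for this example; they are not taken from the paper or the dataset.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Sample:
    """One training example (hypothetical container, not from the paper)."""
    x: np.ndarray  # CT slice as a 2D array of pixel intensities
    y: int         # 1 = malignant nodule present, 0 = no malignant nodule
    m: np.ndarray  # boolean mask, True inside the doctor-annotated box


def box_to_mask(shape, top, left, height, width):
    """Turn a rectangular annotation into a boolean mask of the given shape."""
    mask = np.zeros(shape, dtype=bool)
    mask[top:top + height, left:left + width] = True
    return mask


# A blank 64x64 "image" with an 8x10 annotated box, purely for illustration.
sample = Sample(
    x=np.zeros((64, 64)),
    y=1,
    m=box_to_mask((64, 64), top=20, left=30, height=8, width=10),
)
print(int(sample.m.sum()))  # number of annotated pixels
```

Storing m as a mask with the same shape as x makes it easy to ask later whether a given pixel lies inside or outside the doctor-annotated area.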

How can we enforce Causal Alignment?

To understand how we can enforce Causal Alignment, I want to take you back to our example with Paul. Paul is fascinated by numbers, so you want to teach him how to read basic numbers:

You: What is this number Paul?

Paul: I think that's a nine right?

You: Exactly! How would you change it to make a different number?

Paul: I would remove the circle at the top. Then it's a three!

You: But that's not even a number anymore! If you want to change it to a three, you only have to remove the line on the top left.

You: Let's try another number. Which one is this?

Paul: That's a four!

You: Exactly! And how can you change it to a nine?

Paul: I can add a line at the top to make it a nine!

You: That's correct, you got it!

Paul is now not only able to read the numbers correctly, but he also understands how each line contributes to a given digit! This makes us confident that Paul will be able to read all kinds of numbers in the future.

Let's translate the dialog to our medical example:

| Training our model | Dialog with Paul |
| --- | --- |
| Get the prediction for the image x by computing f(x). | You: What is this number Paul? |
| The model f outputs y=0.8. The model "thinks" that x contains a malignant nodule, but is not entirely sure. | Paul: I think that's a nine right? |
| We adapt f and make it even more confident that x contains a malignant nodule. | You: Exactly! |
| We modify our image x such that we get a counterfactual image x^* which f classifies as not having a malignant nodule. | You: How would you change it to make a different number? Paul: I would remove the circle at the top. Then it's a three! |

Both Paul and the model f have made nonsensical modifications to the original image. They changed way too much! As you have lots of experience with numbers, you tell Paul that to change the nine to a three, he only has to remove the line on the top left. Similarly, we can use the doctor's experience to teach the model where the malignant nodule is in our image! To change the image x into one that does not contain a malignant lung nodule, only the doctor-annotated area m has to be changed. Thus we modify f such that x^* and x only differ within m:
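To make the idea of a counterfactual image more concrete, here is a toy sketch in Python. It assumes a very simple stand-in for f (logistic regression on the raw pixels) and creates x^* by repeatedly nudging the pixels against the model's weights until the prediction drops below 0.5. The model, the weights, the step size and the function names are all made up for illustration; this is not necessarily how Liu et al. generate their counterfactuals.

```python
import numpy as np


def predict(x, w, b):
    """Toy classifier f: logistic regression on the raw pixels."""
    logit = float(np.sum(w * x)) + b
    return 1.0 / (1.0 + np.exp(-logit))


def make_counterfactual(x, w, b, step=0.1, max_steps=1000):
    """Nudge the pixels until f no longer predicts 'malignant' (f < 0.5)."""
    x_cf = x.copy()
    for _ in range(max_steps):
        if predict(x_cf, w, b) < 0.5:
            break
        # For this model the gradient of the logit w.r.t. x is w,
        # so stepping against w lowers the prediction.
        x_cf -= step * w
    return x_cf


# Annotated area m: the central 2x2 block of a 4x4 image.
m = np.zeros((4, 4), dtype=bool)
m[1:3, 1:3] = True

# Hypothetical weights: the model looks at m, but also leans on a
# shortcut pixel in the top-left corner.
w = np.zeros((4, 4))
w[1:3, 1:3] = 1.0
w[0, 0] = 2.0
b = -1.0

x = np.zeros((4, 4))
x[1:3, 1:3] = 1.0
x[0, 0] = 1.0

x_cf = make_counterfactual(x, w, b)
change = np.abs(x_cf - x)
share_inside_m = change[m].sum() / change.sum()
print(predict(x, w, b) > 0.5, predict(x_cf, w, b) < 0.5)
print(round(float(share_inside_m), 2))
```

In this toy setup only two thirds of the change lands inside m; the rest goes into the shortcut pixel. That is the kind of "nonsensical modification" that the training procedure is supposed to discourage.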

As we can see in the rest of the dialog, training our model in this way helps it to pay attention to the correct area of each image:

| Training our model f | Dialog with Paul |
| --- | --- |
| We input another image x into f: | You: Let's try another number. Which one is this? |
| The model outputs y=1, which is correct. | Paul: That's a four! |
| We generate x^*, which only requires slight modification within the area m (see the bottom right of the image). This indicates that f "understands" where in the image the malignant nodule is present. | You: Exactly! And how can you change it to a nine? Paul: I can add a line at the top to make it a nine! You: That's correct, you got it! |

As you can see, we are doing two things during our training process:

  • We are forcing the output of f to be closer to the correct label y.
  • We are modifying f such that the differences between an image x (which f classifies as y=1) and its counterfactual x^* (which f classifies as y=0) lie within the doctor-annotated area m.

This not only ensures that the model classifies the images correctly, but also that it pays attention to the right part of the image. We can see this here:

Thus we've managed to enforce causal alignment in our classifier f.
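The two training goals described above can be written down as a single number to minimize. The following Python sketch is my own simplified reading, not the exact objective from the paper: it adds a standard binary cross-entropy term to a penalty for every change between x and x^* that falls outside m. The function names and the `weight` parameter are hypothetical.

```python
import numpy as np


def classification_loss(p, y, eps=1e-7):
    """Binary cross-entropy between the prediction p = f(x) and the label y."""
    p = float(np.clip(p, eps, 1.0 - eps))
    return float(-(y * np.log(p) + (1 - y) * np.log(1.0 - p)))


def misalignment_penalty(x, x_cf, m):
    """Total absolute change between x and x^* that falls outside the mask m."""
    return float(np.abs(x_cf - x)[~m].sum())


def training_loss(p, y, x, x_cf, m, weight=1.0):
    """Sum of both goals; `weight` is a hypothetical trade-off parameter."""
    return classification_loss(p, y) + weight * misalignment_penalty(x, x_cf, m)


m = np.zeros((4, 4), dtype=bool)
m[1:3, 1:3] = True
x = np.ones((4, 4))

x_cf_aligned = x.copy()
x_cf_aligned[1:3, 1:3] = 0.0   # changes only inside m

x_cf_misaligned = x_cf_aligned.copy()
x_cf_misaligned[0, 0] = 0.0    # additionally changes a pixel outside m

print(training_loss(0.8, 1, x, x_cf_aligned, m))
print(training_loss(0.8, 1, x, x_cf_misaligned, m))
```

The second loss is larger than the first because the misaligned counterfactual touches a pixel outside m. Note that this sketch only evaluates the loss for given counterfactuals; in actual training, x^* depends on f, which is what allows the penalty to change the model.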

Other approaches to ensure alignment

Before I show you how this approach performs in comparison to other methods, I want to explain the other methods briefly.

Ross et al.: Right for the Right Reasons: Training Differentiable Models by Constraining their Explanations [1]

Ross et al. take a similar approach to Liu et al., with slight mathematical differences. Let's consider a classifier which can detect if a dog is present in an image:

We can see that the classifier is mainly looking at the snout of the dog. But suppose we don't want the classifier to look at the dog's snout! Then we directly prevent it from making changes to its output based on the snout:

This will force the classifier to find some other feature which dogs have in common, e.g. their tail.
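A small Python sketch may help to see what "preventing the classifier from reacting to the snout" can look like. It is loosely modeled on the idea of penalizing input gradients inside a forbidden region, and it assumes a toy logistic model so that the gradient can be written by hand. The model, the 3x3 "image" and all names are assumptions for illustration, not the authors' implementation.

```python
import numpy as np


def log_prob_input_gradient(x, w, b):
    """Gradient of log p + log(1 - p) w.r.t. x, where p = sigmoid(w . x + b)."""
    p = 1.0 / (1.0 + np.exp(-(float(np.sum(w * x)) + b)))
    # d/dz log(sigmoid(z)) = 1 - p and d/dz log(1 - sigmoid(z)) = -p, with dz/dx = w
    return (1.0 - 2.0 * p) * w


def forbidden_region_penalty(x, w, b, forbidden):
    """Squared input gradients, summed over the region the model should ignore."""
    gradient = log_prob_input_gradient(x, w, b)
    return float(np.sum((gradient * forbidden) ** 2))


x = np.ones((3, 3))
forbidden = np.zeros((3, 3), dtype=bool)
forbidden[1, 1] = True           # the "snout" pixel

w_snout = np.zeros((3, 3))
w_snout[1, 1] = 2.0              # this model relies on the snout ...
w_snout[2, 2] = 1.0              # ... and a little on the "tail"

w_tail = np.zeros((3, 3))
w_tail[2, 2] = 1.0               # this model relies on the tail only

print(forbidden_region_penalty(x, w_snout, 0.0, forbidden))
print(forbidden_region_penalty(x, w_tail, 0.0, forbidden))
```

The model that uses the snout is penalized, the one that uses only the tail is not. Adding such a penalty to the usual classification loss pushes training towards the second kind of model.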

Zhang et al.: Interpretable Convolutional Neural Networks [2]

Zhang et al. take a different approach. Our classifier f consists of different layers l_1, l_2, ... and their approach forces each of these layers to pay attention to a continuous, semantically meaningful area:

How do the different methods compare?

The Dataset

Liu et al. compare their method to the other methods on two datasets. We will focus on the results based on the LIDC dataset, which contains annotated CT scans of lung nodules. To test whether their method truly enforces causal alignment, they add a pseudo-feature to the images:

All the images which contain a malignant nodule get a "+" in the top-left corner. All the images which do not contain a malignant nodule get a "-". A simple model could just look at the top left corner and get the correct answer every single time! But that would have two problems:

  • Low alignment score: The simple model would pay attention to the wrong area (the nodule is almost never in the top-left corner)
  • Low classification accuracy: The simple model would not be able to classify any images which do not contain the +/- label

Liu et al.'s method should be able to avoid these two problems:

  • High alignment score: The model should still pay attention to the nodule, not the +/- label
  • High classification accuracy: The model should be able to classify images without the +/- label. Thus, we only provide the +/- label in the training set, not in the test set.

As these two criteria are both important for the model's overall usefulness, the authors report on both. We will look at the second metric first: the classification accuracy.
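The pseudo-feature described above is easy to reproduce in a few lines. Here is a minimal Python sketch that stamps a "+" or a "-" into the top-left corner of an image depending on its label. The symbol size, the pixel value and the function name `add_pseudo_feature` are my own assumptions; the paper's exact rendering may differ.

```python
import numpy as np


def add_pseudo_feature(image, label, size=7, value=1.0):
    """Return a copy of the image with "+" (label 1) or "-" (label 0) in the top-left corner."""
    stamped = image.copy()
    mid = size // 2
    stamped[mid, 0:size] = value       # horizontal bar, drawn for both symbols
    if label == 1:
        stamped[0:size, mid] = value   # vertical bar turns the "-" into a "+"
    return stamped


image = np.zeros((32, 32))
with_plus = add_pseudo_feature(image, label=1)
with_minus = add_pseudo_feature(image, label=0)

print(with_plus[:7, :7])
print(with_minus[:7, :7])
```

Applying this function to the training images only, and leaving the test images untouched, gives the setup in which a model that relies on the corner symbol fails at test time.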

Classification Accuracy: How well can the different methods classify the lung nodule images?

Liu et al.'s method performs best: it classifies more than 70% of the images correctly. The other approaches have varying accuracies: Ross et al. come close with 65%, while the others range from 35% to 50%. An accuracy below 50% is worse than random guessing, which indicates that the +/- label "confuses" the other methods during the training process.

Alignment Score: Are the methods able to pay attention to the correct part of the image?

To measure the degree of alignment of the different methods, the researchers measure the fraction of the doctor-annotated area the model pays attention to.

In red: The doctor-annotated area.
In blue: The model's area of attention.

The researchers measure the black area (where red and blue overlap) divided by the blue area. In this example, the alignment score would be 25%. When we look at the results, we can see that the alignment scores of the different methods vary greatly:

Liu et al.'s method performs very well: On average, their model's attention area has a 75% overlap with the doctor-annotated area. No other method makes it past 10%! We can see that the +/- label massively distracts the other methods from paying attention to the correct areas of the image. We can see this in the following visualization:

To compare the performance in a real-world scenario, it would have been interesting to see the models' alignment scores when trained on images without the +/- label. Unfortunately, the authors do not report this experiment.
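The alignment score from the figure description above can be computed directly from two boolean masks. The following Python sketch assumes that both the doctor's annotation and the model's area of attention are given as masks of the same shape, and it divides the overlap by the attention area as described for the figure. The function name and the 8x8 toy masks are made up for this example.

```python
import numpy as np


def alignment_score(attention, annotated):
    """Share of the model's attention area that lies inside the annotated area."""
    attention_area = attention.sum()
    if attention_area == 0:
        return 0.0  # no attention at all: define the score as zero
    return float((attention & annotated).sum() / attention_area)


annotated = np.zeros((8, 8), dtype=bool)
annotated[0:4, 0:4] = True      # doctor-annotated area (red), 16 pixels

attention = np.zeros((8, 8), dtype=bool)
attention[2:6, 2:6] = True      # model's area of attention (blue), 16 pixels

print(alignment_score(attention, annotated))  # 4 overlapping pixels / 16 = 0.25
```

In this toy example both areas have the same size, so dividing the overlap by the annotated area instead would give the same 25%.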

What can we learn from this?

We saw that there is great value in teaching a model or a child not only what it sees, but also which parts of an image to pay attention to. When a doctor and an ML model pay attention to the same area of an image to classify it, we have causal alignment. We can enforce causal alignment by generating a counterfactual image x^* which is classified differently by the model f, and then modifying f such that x and x^* differ only within the doctor-annotated area m.

Even if there is some pseudo-feature in the image, Liu et al.'s method mostly pays attention to the doctor-annotated area. Other methods struggle to stick to the doctor-annotated area and classify only between 35% and 65% of the images correctly. Unfortunately, it remains unclear how Liu et al.'s method compares to other methods when there is no artificial pseudo-feature present.

Their method is easily applicable to other imaging modalities, e.g. mammography scans. However, their method only works when the label y is binary. Using their method to predict more complex target variables would require additional modifications.

Glossary

Machine Learning (ML): A group of techniques to construct a function f which takes in some data x (e.g. an image) and predicts a target variable y (e.g. whether the image contains a malignant lung nodule). The resulting function f is called a model.

Computed Tomography (CT): A technique to create a 3D scan of the inner parts of an organism using X-rays. Used to diagnose e.g. certain types of cancer.

Causal Alignment: Causal alignment is present when the reasons which two decision procedures use to arrive at their decisions are similar.

Lung Nodule: A tumor within the lung that might or might not be malignant.

Classification Accuracy: The percentage of images that a model f is able to classify correctly.

Alignment Score: The proportion of the ground truth area which the model pays attention to when classifying.

References

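[1] Ross, A. S., Hughes, M. C., & Doshi-Velez, F. (2017). Right for the right reasons: Training differentiable models by constraining their explanations. In Proceedings of the Twenty-Sixth International Joint Conference on Artificial Intelligence (IJCAI).

[2] Zhang, Q., Wu, Y. N., & Zhu, S.-C. (2018). Interpretable convolutional neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).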
