Blogpost written by: Philipp Pindl
Based on: De Sousa Ribeiro, F., Xia, T., Monteiro, M., Pawlowski, N. & Glocker, B. (2023). High Fidelity Image Counterfactuals with Probabilistic Causal Models.
In Proceedings of Machine Learning Research, PMLR 202:7390-7425, 2023. https://doi.org/10.48550/arXiv.2306.15764
1. What are Counterfactuals, and why should we care?
Picture the following situation: You have an important exam on Friday, but your friends convince you to party instead of studying. After taking and failing the exam on Friday, you wonder: "Would I have passed had I stayed home and studied, or would I have failed anyway?" This question is an example of a counterfactual. A counterfactual explores a hypothetical scenario, asking what would have happened if a specific part of a situation had been different. In this case, the factual scenario (partying and failing the exam) did happen, but we are interested in the outcome of a hypothetical scenario, the counterfactual. An important concept is that everything is assumed to remain exactly the same except for the change we introduce, the intervention. Here, the intervention is staying home instead of partying.
Beyond everyday decisions, counterfactual reasoning has important applications in fields like medicine. The paper demonstrates this on MRI scans of the brain. Here, the MRI scans, together with extra information such as age, sex and ventricle volume, serve as factual observations. It is then possible to change one or more of these variables while keeping the identity of the scan intact.
Below is an example of increasing the ventricle volume in an MRI scan. Notice how the changes are localized to the ventricles in the center of the brain, while the rest of the image remains mostly unaffected. Such counterfactuals could, for example, help doctors understand which factors affect which parts of a scan and therefore take more effective measures. An interactive demo is also available in which you can generate these counterfactuals yourself.
Figure: observation, counterfactual, and causal effect when increasing the ventricle volume.
Constructing these counterfactuals is challenging because it’s impossible to recreate the exact same situation with only one specific change. However, with the help of a model that captures the causal dependencies of the environment, it becomes possible to generate counterfactuals. In the next section, we introduce a framework that enables this process.
2. The Framework for Counterfactuals: Structural Causal Models
2.1. What are Structural Causal Models
A Structural Causal Model (SCM) is a framework used to represent and understand cause-and-effect relationships between variables. It allows us to model not only the direct effects between variables, but also the unobserved factors - often referred to as "noise" - that can influence outcomes. For instance, studying on Thursday does have a causal effect on whether you pass the exam on Friday. However, there are also unobserved factors, such as your level of anxiety or distractions, which can also affect the outcome, even though they aren't directly observable in the data.
In order to build an accurate model of the environment, an SCM must also account for these factors outside the system, called exogenous causes, which influence the outcome.
Imagine you are enrolled in a seminar course that requires you to complete a project. The process is as follows:
- You are assigned a Topic
- You put in some amount of Effort to complete the project
- You are then Graded on your work
So the variables we observe are:
X = \{\text{Topic}, \text{Effort}, \text{Grade}\}
The accompanying diagram represents this SCM graphically. Each arrow denotes a causal relationship between two variables. Additionally, each endogenous variable x has an associated noise variable U_x (in dotted circles), which models unobserved effects. In this model, these are:
U= \{U_T,U_E,U_G\}
For example, U_E could capture your mood while working on the project, which in turn influences how motivated you are to put in effort.
The final component of the SCM is a set of so-called structural equations, which quantify the causal effect between variables. In general, they have the form:
x := f(pa_x, U_x)
where pa_x denotes the parent variables of x and U_x is the noise variable corresponding to x. In our model, we assume these equations to be:
\begin{align*} &\text{Grade} := U_G + \text{Topic} + \text{Effort}\\ &\text{Effort} := U_E\ \text{AND}\ \text{Topic}\\ &\text{Topic}:= U_T \end{align*}
Note that for the sake of simplicity all variables in the model are binary except for Grade.
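To make the example concrete, here is a minimal Python sketch of this toy SCM. The function names are ours, and so is the assumption that every noise variable is an independent fair coin flip (probability 0.5); the post does not specify noise distributions. Samples are drawn ancestrally: first the noise, then each equation after its parents.

```python
import random

# Structural equations of the toy SCM (noise terms are 0/1 integers).
def f_topic(u_t):
    return u_t                        # Topic := U_T

def f_effort(topic, u_e):
    return u_e & topic                # Effort := U_E AND Topic

def f_grade(topic, effort, u_g):
    return u_g + topic + effort       # Grade := U_G + Topic + Effort

def sample(rng, p=0.5):
    """Ancestral sampling; p = P(U = 1) is an assumed, not given, noise prior."""
    u = {k: int(rng.random() < p) for k in ("T", "E", "G")}
    topic = f_topic(u["T"])
    effort = f_effort(topic, u["E"])
    grade = f_grade(topic, effort, u["G"])
    return u, {"Topic": topic, "Effort": effort, "Grade": grade}

rng = random.Random(0)
for _ in range(3):
    print(sample(rng))
```

Because Effort is an AND with Topic, no sample can have Effort = 1 while Topic = 0, which is exactly the constraint encoded by the arrow Topic → Effort.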
2.2. How to generate Counterfactuals
Now that we understand how SCMs are constructed, we can explore how they help us generate counterfactuals. To do this, we first need an observation from the real world. In this example, we assume we received a good topic (Topic = 1), did not put in a great deal of effort (Effort = 0), and received a Grade of 2. The question we now want to answer is:
How would our grade have changed, had we put in more effort?
To achieve this, we need to follow three main steps [1]:
Abduction
The first step is to infer the unobserved factors (noise variables) U given the observation x. More formally, we need to compute the posterior p(U|X=x), which, for mutually independent noise variables, factorizes as p(U|X=x)=\prod_i p(u_i|x_i, pa_i).
This can be done using Bayes' rule, or with other techniques in more complex cases. For our example, however, we can do this deterministically.
For instance, since we know that
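0 = \text{Effort} = U_E\ \text{AND}\ \text{Topic} = U_E\ \text{AND}\ 1
we can easily see that U_E=0. The other noise variables are inferred likewise: \text{Topic} = U_T gives U_T = 1, and \text{Grade} = U_G + \text{Topic} + \text{Effort} gives U_G = 2 - 1 - 0 = 1.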
This step is important to ensure that the external factors acting on the model stay the same when generating the counterfactual. In other words, we want to isolate the effect of the change we are making (in this case, increasing effort) without altering other factors that could affect the outcome.
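The deterministic abduction above can be sketched in a few lines of Python. The `abduct` helper is a hypothetical name of ours; it simply solves each structural equation for its noise term. It also flags the one case where that is impossible: if Topic = 0, then Effort = U_E AND 0 = 0 regardless of U_E, so U_E cannot be recovered from the observation.

```python
def abduct(topic, effort, grade):
    """Solve each structural equation of the toy SCM for its noise term."""
    u_t = topic                       # Topic := U_T
    if topic == 1:
        u_e = effort                  # Effort := U_E AND 1 = U_E
    else:
        u_e = None                    # U_E AND 0 = 0 for any U_E: not identifiable
    u_g = grade - topic - effort      # Grade := U_G + Topic + Effort
    return {"T": u_t, "E": u_e, "G": u_g}

print(abduct(topic=1, effort=0, grade=2))  # {'T': 1, 'E': 0, 'G': 1}
```

When a noise term is not identifiable, abduction has to fall back on a posterior distribution instead of a single value.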
Intervention
The second step is the intervention or action on the SCM. Here, we just need to change our variable(s) of interest to the desired value. In this case, we want to see how the grade would change if we had put in more effort. Therefore, we set:
\widetilde{\text{Effort}} := 1
The tilde (~) above a variable indicates that it belongs to the counterfactual scenario rather than the original observation. By setting Effort to 1 in the counterfactual world, we also remove the incoming arrows into Effort (i.e., we ignore its usual dependence on its parents). This is often written using the do operator as:
do(\text{Effort}=1)
The do operator ensures that we manipulate Effort directly, without considering how its usual causes (such as the Topic) would influence it.
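One way to picture the do operator is as replacing a variable's structural equation with a constant, which cuts its incoming arrows. The sketch below is our own illustration (the `do` and `evaluate` helpers are hypothetical names, not from the paper). With Topic = 0, Effort could never be 1 under the original equations, but under do(Effort = 1) it is, because Topic no longer influences it.

```python
# Mechanisms of the toy SCM, listed in topological order.
MECHANISMS = {
    "Topic":  lambda v, u: u["T"],
    "Effort": lambda v, u: u["E"] & v["Topic"],
    "Grade":  lambda v, u: u["G"] + v["Topic"] + v["Effort"],
}

def do(mechanisms, **interventions):
    """Return a modified copy where each intervened variable ignores its parents."""
    new = dict(mechanisms)
    for name, value in interventions.items():
        new[name] = lambda v, u, value=value: value  # constant: incoming arrows removed
    return new

def evaluate(mechanisms, u):
    """Evaluate the equations in order; dicts preserve insertion order."""
    values = {}
    for name, f in mechanisms.items():
        values[name] = f(values, u)
    return values

u = {"T": 0, "E": 0, "G": 0}
print(evaluate(MECHANISMS, u))                # {'Topic': 0, 'Effort': 0, 'Grade': 0}
print(evaluate(do(MECHANISMS, Effort=1), u))  # {'Topic': 0, 'Effort': 1, 'Grade': 1}
```

`do` returns a new dictionary, so the original model stays untouched and can still be used for the factual world.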
Prediction
The final step is to make a prediction, or in this case, to calculate the counterfactual grade. To do this, we re-evaluate all the structural equations in the model using the counterfactual values of the variables.
In this step, we still use the noise variables U that we inferred in the abduction step, ensuring that nothing else changes in the system apart from the intervention we made. In other words, we ensure that the identity of the observation remains the same. Thus, the counterfactual assignments take the form:
\tilde{x} := f(\widetilde{pa}_x, u_x)
In general, abduction only yields a distribution over the values of u, so we obtain a distribution over \tilde{x}. In our example, we have exact noise values and can evaluate the equations deterministically. Since Topic is not affected by the intervention, we see that the new Grade becomes:
\widetilde{\text{Grade}} = U_G + \text{Topic} + \widetilde{\text{Effort}} = 1 + 1 + 1 = 3
This means that had we put in more effort, we would have, unsurprisingly, gotten a better grade.
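All three steps can be combined in a small Python sketch that also covers the probabilistic case. Because the toy model has only 2^3 noise settings, abduction can be done exactly by enumeration: keep every noise setting consistent with the observation, weighted by its prior. The noise prior (each U is 1 with probability 0.5) is our assumption, as are the helper names. The second call shows a case where U_E is not identifiable, so the counterfactual Grade becomes a distribution.

```python
from itertools import product

P_ONE = 0.5  # assumed prior P(U = 1) for every noise variable

def mechanisms(u, do=None):
    """Evaluate Topic -> Effort -> Grade; `do` maps variable names to fixed values."""
    do = do or {}
    v = {}
    v["Topic"] = do.get("Topic", u["T"])
    v["Effort"] = do.get("Effort", u["E"] & v["Topic"])
    v["Grade"] = do.get("Grade", u["G"] + v["Topic"] + v["Effort"])
    return v

def counterfactual(observed, do):
    """Abduction, action and prediction; returns {counterfactual Grade: probability}."""
    posterior = {}
    for t, e, g in product([0, 1], repeat=3):
        u = {"T": t, "E": e, "G": g}
        if mechanisms(u) == observed:          # abduction: keep consistent noise only
            prior = 1.0
            for bit in (t, e, g):
                prior *= P_ONE if bit else 1 - P_ONE
            posterior[(t, e, g)] = prior
    total = sum(posterior.values())
    result = {}
    for (t, e, g), weight in posterior.items():
        grade = mechanisms({"T": t, "E": e, "G": g}, do)["Grade"]  # action + prediction
        result[grade] = result.get(grade, 0.0) + weight / total
    return result

# The example from this section: the noise is fully identified.
print(counterfactual({"Topic": 1, "Effort": 0, "Grade": 2}, do={"Effort": 1}))  # {3: 1.0}
# Topic = 0 hides U_E, so intervening on Topic yields a distribution.
print(counterfactual({"Topic": 0, "Effort": 0, "Grade": 1}, do={"Topic": 1}))   # {2: 0.5, 3: 0.5}
```

Enumeration only works for tiny discrete models; the next section is about replacing it with learned, (approximately) invertible models.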
3. Machine Learning Components in SCMs
Up until now, we have ignored a major issue with our approach: we simply assumed that we know the structural equations governing our SCM. In practice, we do not have this information; we only have observations in the form of data points. Luckily, this is exactly the problem machine learning solves: fitting a model or an equation to large amounts of data. However, we cannot use an arbitrary model; it must fulfill special properties. To explain this, we first review what our model needs to achieve:
- Firstly, we need to learn relationships of the form x=f(pa_x, u_x). Crucially, the noise values u_x are unobserved. Usually, we simply assume that they follow a standard normal distribution.
- Secondly, we need to abduct the noise, i.e., infer p(u_x | x, pa_x). For this, our model needs to be invertible in some sense.
From the first point, we can see that the model needs to transform a distribution over u into a distribution over x while being conditioned on the parents. This is also the setting assumed by many generative models: they take a sample from a simple (usually normal) distribution and transform it into a more complex distribution in the data space (such as images of cats). Some of these models also naturally exhibit a notion of invertibility. In the following, we present the two models used in the paper and discuss how they can be inverted for abduction [2].
3.1 Conditional Normalizing Flows
In some sense, normalizing flows (NFs) are the simplest form of generative model. They are a series (or flow) of invertible transformations f_i that are learned from the data. Conceptually, they take a simple base distribution and iteratively turn it into a more complex one. This can be nicely seen in the figure below. Another upside of this approach is that it provides an invertible mapping x=f(u), so the noise can be abducted deterministically as u=f^{-1}(x).
However, these desirable properties come with downsides. Normalizing flows scale poorly to high-dimensional data such as images: because every transformation must be invertible, each f_i maps \mathbb{R}^d \rightarrow \mathbb{R}^d, so for data x\in \mathbb{R}^d every layer operates at the full data dimension. This makes them expensive to compute for high-resolution images.
For this reason, normalizing flows are particularly suited to non-image dependencies. For images, a different architecture should be used.
Normalizing flows can also be conditioned on the parent values, which adds the required dependence on pa_x.
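The invertibility argument can be illustrated with a single conditional affine layer in NumPy. This is a hypothetical toy (random linear maps stand in for learned networks), not the flow architecture used in the paper; real flows stack many more expressive invertible layers. It shows the two properties that matter for SCMs: noise is mapped to data conditioned on the parents, and the noise can be recovered exactly by inverting the layer.

```python
import numpy as np

rng = np.random.default_rng(0)

class ConditionalAffineFlow:
    """Toy one-layer conditional flow: x = mu(pa) + exp(s(pa)) * u."""

    def __init__(self, n_parents, dim):
        # Hypothetical stand-ins for learned conditioner networks.
        self.W_mu = rng.normal(size=(n_parents, dim))
        self.W_s = 0.1 * rng.normal(size=(n_parents, dim))

    def forward(self, u, pa):
        """Generation: noise -> data, conditioned on the parents."""
        return pa @ self.W_mu + np.exp(pa @ self.W_s) * u

    def inverse(self, x, pa):
        """Abduction: data -> noise, exact because the map is invertible."""
        return (x - pa @ self.W_mu) * np.exp(-(pa @ self.W_s))

flow = ConditionalAffineFlow(n_parents=2, dim=3)
pa = np.array([[1.0, 0.5]])
u = rng.standard_normal((1, 3))          # u ~ N(0, I)
x = flow.forward(u, pa)
u_hat = flow.inverse(x, pa)
print(np.allclose(u, u_hat))             # True

# Counterfactual: keep the abducted noise, change a parent value.
pa_cf = np.array([[0.0, 0.5]])
x_cf = flow.forward(u_hat, pa_cf)
```

The exponential keeps the scale strictly positive, which is what guarantees the layer can always be inverted.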
3.2 Variational Autoencoders
An approach that is much better suited for image data is the variational autoencoder (VAE). From this section onward, x always refers to an image (like an MRI scan), which is typically the variable of interest in the model. A VAE is another generative model that can be used in SCMs. It is computationally tractable for image spaces because it works in a lower-dimensional latent space internally. However, fitting this model into the SCM framework is a little more complicated. The VAE contains two separate sources of noise during image generation, which need to be addressed:
- Firstly, the z noise inside the latent space
- Secondly, \epsilon used to sample the final image from the decoder
To make noise abduction tractable in this setting, we assume the noise factorizes as p(u_x) = p(z) \cdot p(\epsilon). This allows us to abduct the noise sequentially. First, the trained variational distribution (the encoder) q(z|x, pa_x) \approx p(z|x, pa_x) is used to approximately abduct z. We then pass this z through the decoder to obtain its \mu and \sigma values, invert the reparametrization x = h(\epsilon, \mu, \sigma), and recover the noise that was used: \epsilon = h^{-1}(x, \mu, \sigma). The graphic below shows this reparametrization step in more detail.
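The two-stage abduction can be sketched in NumPy under simplifying assumptions of ours: random linear maps stand in for a trained encoder and decoder, the encoder mean is used as a point estimate for z, and the reparametrization is taken to be Gaussian, h(\epsilon, \mu, \sigma) = \mu + \sigma \epsilon. All names and sizes are hypothetical; the paper's hierarchical VAE is far more elaborate.

```python
import numpy as np

rng = np.random.default_rng(0)
D_X, D_Z, D_PA = 8, 2, 1   # toy sizes: "image" pixels, latent dims, parents

# Hypothetical stand-ins for trained networks (random linear maps).
W_enc = rng.normal(size=(D_X + D_PA, D_Z))
W_mu = rng.normal(size=(D_Z + D_PA, D_X))
W_logsig = 0.1 * rng.normal(size=(D_Z + D_PA, D_X))

def encoder_mean(x, pa):
    """Mean of q(z | x, pa), used here as a point estimate when abducting z."""
    return np.concatenate([x, pa]) @ W_enc

def decoder(z, pa):
    """Per-pixel mu and sigma of p(x | z, pa)."""
    h = np.concatenate([z, pa])
    return h @ W_mu, np.exp(h @ W_logsig)

def abduct(x, pa):
    z = encoder_mean(x, pa)           # step 1: approximate abduction of z
    mu, sigma = decoder(z, pa)
    eps = (x - mu) / sigma            # step 2: invert x = mu + sigma * eps
    return z, eps

def counterfactual(x, pa, pa_cf):
    z, eps = abduct(x, pa)
    mu_cf, sigma_cf = decoder(z, pa_cf)   # same noise (z, eps), new parents
    return mu_cf + sigma_cf * eps

x = rng.normal(size=D_X)
pa = np.array([0.0])
x_cf = counterfactual(x, pa, np.array([1.0]))
# Sanity check: intervening with the observed parents reproduces x.
print(np.allclose(counterfactual(x, pa, pa), x))  # True
```

Because \epsilon is recovered exactly, the reconstruction under a null intervention is exact even though the abduction of z is only approximate.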
4. Image Generation Architectures
The previous section described the general approach by which VAEs can be incorporated into an SCM. The authors of the paper discuss two specific instantiations of this model, which we briefly cover in this section. These approaches mostly differ in how the z variable is interpreted in the context of the SCM. Both use a hierarchical variational autoencoder, with a hierarchy of latent variables z, to generate sharper images than a regular VAE. As a reminder, in the following graphics x still denotes the final image variable of our model.
4.1. Exogenous prior
The first approach is very similar to the general approach discussed previously. Here, z takes the role of exogenous noise, which is unobserved and not part of the model. Because of this, it cannot be directly affected by the parents of the image x. Unfortunately, this means that the true parameters of the underlying SCM may not be learned correctly by our VAE, which is undesirable.
4.2. Latent Mediator
In contrast to the previous model, z now becomes an endogenous variable that is part of our model. This means that it can depend directly on pa_x, which makes it a mediator between pa_x and x. However, the values of z still remain unobserved. In this model, the abduction step is a little more complicated, but in return identifiability can be guaranteed. In practice, the researchers found only a small performance difference between the two models, despite the theoretical drawback of the exogenous prior.
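As a simplified, single-latent sketch (the paper's hierarchical VAE uses several latent layers, and we assume the decoder receives pa_x in both variants), the two variants differ only in the prior over z:

```latex
\begin{align*}
\text{Exogenous prior:}\quad & p(x \mid pa_x) = \int p(x \mid z, pa_x)\, p(z)\, dz \\
\text{Latent mediator:}\quad & p(x \mid pa_x) = \int p(x \mid z, pa_x)\, p(z \mid pa_x)\, dz
\end{align*}
```

In the second line, an intervention on pa_x can change the image both directly and through z, which is what makes z a mediator.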
5. Experiments
To validate their approach, the researchers applied their models to multiple datasets and generated counterfactuals for each. In each SCM, only the image mechanism is modeled by a VAE; non-image dependencies are always modeled by conditional normalizing flows. For each experiment, an SCM needs to be defined which encodes expert knowledge of the respective domain.
5.1 Brain MRI scans
The first dataset used for generating counterfactuals consists of brain MRIs from the UK Biobank, as already shown previously. The variables we can adjust here are:
MRI Sequence, Age, Sex, Brain volume and Ventricle volume
The SCM for this data can be seen in the graphic below. It uses the exogenous prior variant from Section 4.1, because the z variable does not directly depend on the parents of the image and is part of the exogenous noise. The age variable is of special interest here, as it does not directly affect the image variable x.
It is therefore interesting that intervening on age still has an effect on the image; this effect must propagate indirectly through age's descendants in the graph. Another interesting intervention is changing the MRI sequence. This can be useful for generating a different type of MRI scan without having to perform a second scan. This may not be possible with a conventional supervised ML model, because a dataset containing the same scan in multiple sequences may not be available. Counterfactuals for each intervention category can be found in the graphic below [3].
5.2 Chest X-rays
The second dataset consists of chest X-ray images from the MIMIC-CXR dataset. The adjustable variables here are:
Age, Disease, Race and Sex
The disease variable is of special interest here, since it can be used to simulate the presence or absence of a disease. The only disease that can be simulated is pleural effusion. Again, the SCM shown uses the exogenous prior variant.
Counterfactuals from the trained model are summarized in the graphic below, with one example for each attribute [3]. We can see that each change has a distinct effect that is local to specific parts of the X-ray.
6. Summary
The paper showed that counterfactual generation with the SCM framework can be done in complex domains such as medical imaging. To achieve this, many tweaks were made to the image generation architecture in order to produce high-resolution images; some of these were not discussed in detail here for clarity. A future research direction could be the inclusion of continuous normalizing flows in the SCM, which can be trained with the recently prominent "flow matching" objective [4]. These have shown promising results recently and may become a successor to diffusion models. Since standard normalizing flows already fit straightforwardly into the SCM framework, including continuous flows should also be conceptually simple.
Regarding the practical implications of this paper, it is my personal opinion that there needs to be a robust measure to evaluate the "truthfulness" of these counterfactuals before using them in practice. As of right now, the main way to validate performance is to measure the difference between the predictions of an attribute predictor (like an age predictor) on the counterfactual and the intervened value. To me, this seems like a poor choice of metric, since it is not very reproducible and highly dependent on the choice of predictor.
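To make the criticized metric concrete, here is a minimal sketch of such an attribute-predictor check. The `effectiveness` helper is hypothetical and the exact metrics used in the paper may differ. The predictor itself is deliberately left out: its outputs are passed in as an array, which illustrates why the score depends entirely on which predictor is chosen.

```python
import numpy as np

def effectiveness(predicted, intervened, continuous=True):
    """Compare predictor outputs on counterfactuals with the intervened values.

    continuous=True:  mean absolute error (e.g. age in years; lower is better)
    continuous=False: accuracy for categorical attributes (higher is better)
    """
    predicted, intervened = np.asarray(predicted), np.asarray(intervened)
    if continuous:
        return float(np.mean(np.abs(predicted - intervened)))
    return float(np.mean(predicted == intervened))

# Illustrative numbers only, not results from the paper.
print(effectiveness([62.0, 70.5, 55.0], [60.0, 72.0, 55.0]))        # about 1.17
print(effectiveness([1, 0, 1, 1], [1, 0, 0, 1], continuous=False))  # 0.75
```

Swapping in a different age predictor would change the first number even for identical counterfactual images.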
Apart from this, the paper shows a very promising direction in causal machine learning and highlights the potential to generate high-dimensional counterfactuals that could aid doctors in decision-making in the future.
References
[1] Neal, B. Introduction to Causal Inference (course). https://www.bradyneal.com/causal-inference-course
[2] Pawlowski, N., Coelho de Castro, D. & Glocker, B. (2020). Deep structural causal models for tractable counterfactual inference. Advances in Neural Information Processing Systems 33, 857-869.
[3] De Sousa Ribeiro, F., et al. (2023). High fidelity image counterfactuals with probabilistic causal models. arXiv preprint arXiv:2306.15764.
[4] Lipman, Y., et al. (2022). Flow matching for generative modeling. arXiv preprint arXiv:2210.02747.