Author: Cameron Simoleit 

Supervisors: Azade Farshad, Yeganeh, Y. M. 

1. Introduction

According to the Cambridge Dictionary, causality is "the principle that there is a cause for everything that happens" (Causality, 2024).

To better understand what we hope to gain from using causal generative models, we will look at some real-world examples before diving into the causal models themselves.

1.1. Prompting a current generative model to answer questions about an image

Consider the questions at the top of each example in Image 1 (in yellow and blue), which we want to pose to a modern Multi-modal Language Model (MLM). Questions such as "How many cats are in the image?" or "Does this animal have claws?" can be answered from an understanding of what is in the image, and most current models perform reasonably well on them (Chen et al., 2023). But what happens when we ask questions that require more reasoning? Examples are the bottom questions for each image. Asking how many cats would be in the image if the TV was off, or whether the animals would have claws if they were cats, not only changes the correct answer but also requires a deeper understanding of the image content and of the relationships between the objects it contains (Zhang et al., 2024).

Putting this into numbers, Image 2 shows that even models which most would consider among the "strongest" publicly available language models suffer a significant performance drop when prompted with these types of questions. "Numerical Direct", "Numerical Indirect" and "Boolean" describe how the original question was changed to create the corresponding counterfactual question:

  • Numerical Direct questions change the correct answer directly, for example "How many X would there be if two X were added?".
  • Numerical Indirect questions change the correct answer indirectly, for example "How many cats would there be if the TV was off?".
  • Boolean questions often reverse the correct answer as well as the question, creating counterfactual presuppositions: "Would the cat be asleep if it was woken up?".

1.1.1. Three-layer causal hierarchy

More abstractly, we can place the reasoning complexity of different questions in the three-layer causal hierarchy introduced by J. Pearl, where each level builds on the one before and gets closer to what we would consider human-level reasoning (Pearl, 2019; Bareinboim et al., 2022). The goal of introducing causality into our models in this case is to reduce the performance drop we just saw when prompting them with these higher-level questions.

Level | Typical Activity | Typical Questions | Examples
--- | --- | --- | ---
1. Association | Seeing | What is? | What does a symptom tell me about a disease?
2. Intervention | Doing | What if? | What if I take aspirin, will my headache be cured?
3. Counterfactuals | Imagining | Why? | Was it the aspirin that stopped my headache?
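To make the three rungs concrete, here is a minimal Python sketch on a toy linear structural causal model (SCM, defined formally in Section 2.3). The variables, coefficients and the function names `sample` and `counterfactual_y` are hypothetical and chosen only for illustration. The point is that conditioning (rung 1), intervening (rung 2) and reasoning about a single observed unit (rung 3) give different answers once a confounder U affects both X and Y.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical linear SCM (all coefficients are illustrative assumptions):
#   U ~ N(0, 1)              unobserved confounder (e.g. headache severity)
#   X := U + N_x             treatment (e.g. aspirin dose)
#   Y := 2*X - 3*U + N_y     outcome (e.g. headache relief)
def sample(n, do_x=None):
    u = rng.normal(size=n)
    # An intervention do(X = x) replaces X's structural equation by a constant
    x = u + rng.normal(size=n) if do_x is None else np.full(n, float(do_x))
    y = 2 * x - 3 * u + rng.normal(size=n)
    return x, y

n = 200_000

# Rung 1, association ("seeing"): regression slope of Y on X in observational data
x_obs, y_obs = sample(n)
assoc_slope = np.cov(x_obs, y_obs)[0, 1] / np.var(x_obs, ddof=1)  # about 0.5

# Rung 2, intervention ("doing"): E[Y | do(X=1)] - E[Y | do(X=0)]
_, y_do1 = sample(n, do_x=1)
_, y_do0 = sample(n, do_x=0)
causal_effect = y_do1.mean() - y_do0.mean()  # about 2.0

# Rung 3, counterfactual ("imagining"): for one observed unit (x, y), what would Y
# have been had X been x_new? Abduction recovers this unit's exogenous term
# e = -3*U + N_y, action sets X = x_new, prediction reuses e.
def counterfactual_y(x_seen, y_seen, x_new):
    e = y_seen - 2 * x_seen
    return 2 * x_new + e
```

In this toy model, seeing gives a slope of about 0.5 while doing gives an effect of about 2, because U pushes X up and Y down at the same time. For example, `counterfactual_y(1.0, 0.5, 0.0)` returns `-1.5`. The counterfactual is exact here only because Y is additive in its exogenous terms; in general, abduction yields a posterior over the noise rather than a single value.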

1.2. Reusing our models when the domain changes

Next, consider a motion forecasting problem in which we want to predict the future trajectories of people in different environments. The difficulty stems from humans behaving differently in different environments, e.g. on a busy street compared to during an evacuation. This sets the problem apart from other motion forecasting problems, such as those concerning vehicles, which usually abide by traffic laws. Even current models (Gupta et al., 2018; Kosaraju et al., 2019; Mangalam et al., 2020) struggle to adapt to such differences between environments, or domains. This is because they rely on statistical inference and therefore lack explainability and are prone to learning spurious correlations (Bagi et al., 2023).

We can separate the variables in such a system into invariant and variant variables. Invariant variables, such as physical laws or the maximum speed at which humans can move, do not change between environments and can therefore be learned in one environment and applied to all others. Variant variables, on the other hand, are specific to each environment. With a causal approach, we hope to separate the two in order to improve model performance when switching between domains (Bagi et al., 2023).

[Photo: people walking on a street during daytime (Unsplash, 2021)]

1.3. Problems with current generative models

Recapping current problems in generative models that we hope to address with causal approaches:

  • Models lack a deeper understanding of the data they are trained with
  • Performance degrades when the domain in the data shifts 
  • The data that is used for training oftentimes contains selection biases and confounding factors, which are then inherited by the models during training

→ The models currently learn statistical correlations in the data without understanding causal relationships

"[...] we must go beyond learning mere statistical correlations toward causal models that capture the causal-effect relationship between influential variables in a system." (Komanduri et al., 2024)

2. Basic concepts

Before we get into causal generative models, we will briefly review some basic concepts so that everyone is on the same page when we go into the specifics. Basic machine learning knowledge is assumed, though.

2.1. Variational Autoencoders (VAEs)

A quick recap of Variational Autoencoders (VAEs), because two of the methods we will look at build on this concept:

VAEs (visualized in Image 3) are similar to regular autoencoders and consist of an encoder and a decoder. They are trained in a self-supervised fashion to minimize the reconstruction error on the input data, together with a regularization term that keeps the encoded distribution close to a prior, typically a standard normal distribution. Because the input is encoded as a distribution over the latent space, we can later sample from this space and use the decoder to generate new data (Kingma & Welling, 2022).
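The two pieces that distinguish a VAE from a plain autoencoder, sampling from the encoded distribution and the regularization term, can be sketched in a few lines of NumPy. This is a minimal illustration, assuming a diagonal Gaussian encoder, a standard normal prior and a Gaussian decoder with fixed unit variance (so the reconstruction term becomes a squared error). The function names are hypothetical, and the encoder and decoder networks themselves are left out.

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    # z = mu + sigma * eps with eps ~ N(0, I): the randomness sits in eps,
    # so z stays differentiable with respect to mu and log_var
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def kl_to_standard_normal(mu, log_var):
    # Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)

def negative_elbo(x, x_recon, mu, log_var):
    # Unit-variance Gaussian likelihood -> squared-error reconstruction term (up to a constant)
    recon = 0.5 * np.sum((x - x_recon) ** 2, axis=-1)
    # Average over the batch of (reconstruction + regularization)
    return np.mean(recon + kl_to_standard_normal(mu, log_var))
```

Minimizing `negative_elbo` trades off reconstruction quality against keeping the encoded distribution close to the prior, which is what later allows drawing new latents from N(0, I) and decoding them.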

These and other architectures, such as Generative Adversarial Networks (GANs) and Diffusion Models, have been at the forefront of machine learning for approximating complex data distributions. They perform well at generating new data from the distribution they were trained on, but to understand how specific underlying factors influence the generative process, we must additionally capture the causal relationships between the variables (Komanduri et al., 2024).


2.2. Causal Generative Models

Recent works can be roughly categorized into the fields shown in Image 6. "Causal representation learning (CRL) is concerned with learning semantically meaningful causally related latent variables and their causal structure from high dimensional data" (Komanduri et al., 2024). Its subclasses describe which type of data the model has access to during training, i.e. observational, interventional or counterfactual. These correspond to the levels of the causal hierarchy introduced earlier. Controllable counterfactual generation (CCG), on the other hand, focuses on modeling known causal variables and mapping them to the observed data. This means we assume access to the causal variables beforehand, which makes the mapping convenient and efficient. CCG approaches are divided into subclasses according to their model architecture (Komanduri et al., 2024). Note that these still use the same general architectures described in the previous section.

 

2.3. Structural Causal Models

Recent research on causal models focuses on Pearl's Structural Causal Model (SCM), which formally describes the causal relationships between a set of variables (Komanduri et al., 2024; Pearl, 2009). An SCM is typically represented by a Directed Acyclic Graph (DAG) whose nodes are the causal variables and whose edges encode their interactions with each other. Each variable is determined by a structural equation, a function of its direct causes (its parents in the DAG) and an exogenous noise term. If X is a direct cause of Y, there is a directed edge from X to Y in the DAG. This cause-effect relation means that changing X can affect the value of Y, while the reverse does not hold. The diagram below shows what this could look like with three variables.
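An SCM like this can also be written down directly as code. The sketch below is a hypothetical three-variable example (X → Y, X → Z, Y → Z) with made-up linear structural equations and Gaussian noise. It samples the variables in topological order of the DAG and implements an intervention do(·) by replacing a variable's structural equation with a constant. It uses only the Python standard library (`graphlib` requires Python 3.9+).

```python
import random
from graphlib import TopologicalSorter

# Hypothetical DAG: each node lists its parents (direct causes)
PARENTS = {"X": [], "Y": ["X"], "Z": ["X", "Y"]}

# Hypothetical structural equations: value = f(parent values, own exogenous noise)
EQUATIONS = {
    "X": lambda v, noise: noise,
    "Y": lambda v, noise: 2.0 * v["X"] + noise,
    "Z": lambda v, noise: v["X"] - v["Y"] + noise,
}

def sample(do=None, rng=random):
    """Draw one joint sample; `do` maps variable names to intervened values."""
    do = do or {}
    values = {}
    # static_order() yields parents before children (raises CycleError if not a DAG)
    for node in TopologicalSorter(PARENTS).static_order():
        if node in do:
            values[node] = do[node]  # do(node = value): ignore its structural equation
        else:
            values[node] = EQUATIONS[node](values, rng.gauss(0.0, 1.0))
    return values
```

Intervening on Y changes its descendant Z but leaves the distribution of X untouched, which is exactly the asymmetry of the cause-effect relation: under `do={"Y": 5.0}` the mean of X stays near 0, whereas under `do={"X": 1.0}` the mean of Y moves to about 2.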

3. Dealing with Counterfactuals: Controllable Counterfactual Generation (CCG)

Even though research has focused on causal generative modeling for only a short time, many different approaches already exist. To stay within the scope of this post, we will review two specific methods that address the problems stated earlier. They offer a more detailed look at specific motivations, methodologies and results. For additional information, take a closer look at the references listed below. Let's address counterfactual data first.

The idea behind controllable counterfactual generation is to train models to generate counterfactual images, such as taking an MRI image of a young, healthy brain and turning it into an image of an old brain, changing the sex, or switching from T1-weighted to T2-weighted. To generate these images, the model has to learn what constitutes the specific differences we are looking for. The generation process also makes it easy to visualize the model's current understanding of the attributes we are applying (Ribeiro et al., 2023).

3.1. Further Motivation for Controllable Counterfactual Generation

Here is some further motivation for controllable counterfactual generation specific to the medical domain:

  • Data is scarce in medical imaging, specifically for subgroups and rare pathologies
    → Controllable counterfactual generation can be used for data augmentation and for teaching medical professionals
  • Better model explainability and fairness through the visualization of a model's conception of attributes

3.2. Application: High Fidelity Image Counterfactuals with Probabilistic Causal Models

This section will focus on the work of Ribeiro et al., 2023.

3.2.1. Methodology

As its model, the paper uses a Hierarchical Latent Variable Model (HLVM): a generative model with a prior over L layers of latent variables, which creates a hierarchical latent space. Combining this approach with a Variational Autoencoder (VAE) yields a Hierarchical Variational Autoencoder (HVAE) (Kingma et al., 2017; Sønderby et al., 2016; Burda et al., 2016).
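To illustrate what "a prior over L layers of latent variables" means, here is a minimal NumPy sketch of top-down ancestral sampling from a hierarchical prior p(z_L) ∏ p(z_l | z_{l+1}). It is not the architecture of Ribeiro et al. (2023): the linear maps in `weights`, the fixed Gaussian noise and the function name are hypothetical simplifications, whereas an actual HVAE parameterizes each conditional with neural networks and learns them jointly with the encoder.

```python
import numpy as np

def sample_hierarchical_prior(weights, dim, rng, noise_scale=1.0):
    """Top-down ancestral sampling from p(z_L) * prod_l p(z_l | z_{l+1}).

    weights: list of L-1 (dim x dim) matrices, hypothetical linear maps from
             z_{l+1} to the mean of z_l.
    Returns [z_1, ..., z_L], with z_1 the layer closest to the data.
    """
    z = rng.standard_normal(dim)  # top layer: z_L ~ N(0, I)
    layers = [z]
    for w in weights:  # walk down the hierarchy, one conditional at a time
        z = w @ z + noise_scale * rng.standard_normal(dim)
        layers.append(z)
    return layers[::-1]

rng = np.random.default_rng(0)
weights = [0.5 * np.eye(4) for _ in range(3)]  # L = 4 layers of 4-dim latents
z_layers = sample_hierarchical_prior(weights, 4, rng)
```

In a full HVAE, the decoder would condition on all of these layers, so coarse structure can be carried by the top layers and finer detail by the lower ones.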

For the generator network (the decoder in VAEs), two different approaches are considered:

Conditional HVAE with an exogenous prior

  • Uses labels in the generator for conditional image generation
  • Decouples the prior from the labels

Hierarchical Latent Mediator Model

  • Latent code shifts from being exogenous noise to a mediator role
  • If counterfactual data is available, the mediator can be inferred; otherwise, it must be approximated

Because we are in the realm of Controllable Counterfactual Generation (CCG), we assume a Structural Causal Model (SCM) for these approaches, which can be seen in Image 7.

Here, a, b, m, s and v represent age, brain volume, MRI sequence, sex and ventricle volume, respectively. U_x is the exogenous noise for x, z is the latent space, and \epsilon is a Dirac delta distribution with no learned parameters of its own.

3.2.2. Results

Image 8 shows counterfactual generation examples based on the observational image, which can be seen on the top right. The image also visualizes the direct effect and uncertainty of each generated image. The second moment of the counterfactual distribution is used as a measure of counterfactual uncertainty.
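As a rough illustration of the uncertainty measure, the sketch below takes an assumed stack of counterfactual samples for one input image and computes the per-pixel mean and per-pixel variance, i.e. the centred second moment. The `samples` array is a hypothetical stand-in for repeated draws from the counterfactual distribution, and reading "second moment" as the variance is an assumption of this sketch.

```python
import numpy as np

def counterfactual_uncertainty(samples):
    """samples: array of shape (num_samples, H, W), i.e. several counterfactual images
    drawn for the same observed image and intervention (hypothetical input).
    Returns the per-pixel mean and the per-pixel variance (centred second moment)."""
    samples = np.asarray(samples, dtype=float)
    return samples.mean(axis=0), samples.var(axis=0)

# Usage with stand-in data: 16 noisy 4x4 "counterfactuals"
rng = np.random.default_rng(0)
samples = 0.5 + 0.1 * rng.standard_normal((16, 4, 4))
mean_img, uncertainty_map = counterfactual_uncertainty(samples)
```

Pixels where the sampled counterfactuals disagree get a high value in `uncertainty_map`, which is what an uncertainty visualization of this kind highlights.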

To estimate the effectiveness of the generated counterfactual images, a model trained on the original observational dataset is compared to a version fine-tuned with the generated images. The absolute difference between the errors of these two models serves as a measure of how useful the additional generated data is. As seen in Image 9, this approach improves performance for almost all labels considered in the dataset. The metrics used are Area Under the Curve (AUC) and Mean Absolute Error (MAE).
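The evaluation logic above can be sketched with the two metrics it mentions. This is a minimal NumPy version, assuming binary labels for AUC (computed here as ROC AUC, the probability that a random positive is scored above a random negative) and scalar regression targets for MAE. The function names, including `improvement`, are hypothetical.

```python
import numpy as np

def mae(y_true, y_pred):
    """Mean Absolute Error for regression targets."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(np.abs(y_true - y_pred)))

def auc(labels, scores):
    """ROC AUC for binary labels: probability that a random positive example
    is scored above a random negative one (ties count as 1/2)."""
    labels = np.asarray(labels, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    # All (positive, negative) score differences at once
    diff = scores[labels][:, None] - scores[~labels][None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

def improvement(metric_baseline, metric_finetuned):
    """Hypothetical helper: absolute difference between the two models' scores."""
    return abs(metric_finetuned - metric_baseline)
```

For example, if the baseline reaches an AUC of 0.80 and the fine-tuned model 0.85, `improvement` returns about 0.05. For MAE, lower is better, so the direction of the change has to be read alongside the absolute difference.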

4. Dealing with domain shift: Causal Representation Learning (CRL)

Now let's look at our second problem, reusing our models when the domain of our data changes.

4.1. Application: Generative Causal Representation Learning for Out-of-Distribution Motion Forecasting

This section will focus on the work of Bagi et al., 2023.

4.1.1. Motivation

Human trajectory prediction can help improve areas like infrastructure design, intelligent transport systems and planning for evacuations (Kothari et al., 2021). Because human behavior can differ drastically between environments, this problem poses a greater challenge than trajectory prediction for vehicles, which usually abide by traffic laws. Most models proposed in the literature rely on statistical inference, which, as mentioned earlier, has shortcomings regarding explainability, domain shifts and noise in the data (Liu et al., 2022). The main goal of this paper is to leverage causality to facilitate knowledge transfer under domain shifts.

4.1.2. Methodology

Image 10 depicts the general concept of the proposed causal model. Here, X and Y represent the past and future trajectories, while Z and S denote the invariant and variant variables, respectively. The selection variable E selects a specific environment and thereby the values of the variant variables.

 

The proposed model architecture is depicted in Image 11. It receives a time series of past trajectories in x and y coordinates as input and feeds it into two encoder and interaction module blocks. These blocks can be implemented with any neural network that handles sequence-to-sequence temporal data, such as Recurrent Neural Networks (RNNs) or temporal Convolutional Neural Networks (CNNs). After an additional fully connected network, we obtain a distribution for Z and one for S. These are then passed through decoders and further fully connected networks to generate the future trajectories and a reconstruction of the original input. Additionally, coupling layers (Dinh et al., 2017) turn simple priors, such as normal distributions, into richer ones.
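The role of the coupling layers can be illustrated with a single affine coupling layer in the style of Real NVP (Dinh et al., 2017). In this NumPy sketch, half of the input passes through unchanged and parameterizes a scale and shift for the other half, which keeps the transform invertible with a cheap log-determinant. The class name and the single linear maps used for scale and shift are hypothetical simplifications; Bagi et al. (2023) may stack and parameterize these layers differently.

```python
import numpy as np

class AffineCoupling:
    """Minimal Real NVP-style affine coupling layer (illustrative sketch).

    s(.) and t(.) are hypothetical single linear maps (tanh on the scale);
    the original work uses deeper networks."""

    def __init__(self, dim, rng):
        self.d = dim // 2  # size of the pass-through half
        self.w_s = 0.1 * rng.standard_normal((dim - self.d, self.d))
        self.w_t = 0.1 * rng.standard_normal((dim - self.d, self.d))

    def _scale_shift(self, x1):
        return np.tanh(x1 @ self.w_s.T), x1 @ self.w_t.T

    def forward(self, x):
        x1, x2 = x[..., :self.d], x[..., self.d:]
        s, t = self._scale_shift(x1)
        y2 = x2 * np.exp(s) + t          # transform only the second half
        log_det = s.sum(axis=-1)         # log|det Jacobian| of the transform
        return np.concatenate([x1, y2], axis=-1), log_det

    def inverse(self, y):
        y1, y2 = y[..., :self.d], y[..., self.d:]
        s, t = self._scale_shift(y1)     # y1 == x1, so s and t can be recomputed
        x2 = (y2 - t) * np.exp(-s)
        return np.concatenate([y1, x2], axis=-1)

# Usage: push samples from a simple N(0, I) base distribution through the layer
rng = np.random.default_rng(0)
layer = AffineCoupling(5, rng)
base = rng.standard_normal((4, 5))
rich, log_det = layer.forward(base)
```

Stacking several such layers, and swapping which half is transformed, turns samples from a simple base distribution into a richer one. The density of a transformed point y follows from the change-of-variables formula: log p(y) = log p_base(x) - log_det, with x = `inverse(y)` and `log_det` taken from `forward(x)`.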


4.1.3. Results

With \alpha denoting the noise level in the data, Image 12 shows that the GCRL approach is more robust to noise than the baseline models: its performance does not change as the noise increases. At the lowest noise level, however, the proposed method performs second-worst. The evaluation metrics are Average Displacement Error (ADE) and Final Displacement Error (FDE) (Pellegrini et al., 2009; Alahi et al., 2016).
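Both metrics are simple to compute. The sketch below assumes predictions and ground truth are given as arrays of shape (num_agents, T, 2) in the same coordinate frame; it returns ADE as the mean Euclidean error over all agents and future time steps, and FDE as the mean error at the last time step. The function name is hypothetical.

```python
import numpy as np

def ade_fde(pred, gt):
    """pred, gt: arrays of shape (num_agents, T, 2) with (x, y) positions.
    Returns (ADE, FDE) for a single prediction per agent."""
    # Euclidean distance per agent and time step -> shape (num_agents, T)
    dist = np.linalg.norm(np.asarray(pred, dtype=float) - np.asarray(gt, dtype=float), axis=-1)
    ade = dist.mean()          # averaged over all agents and time steps
    fde = dist[:, -1].mean()   # final time step only, averaged over agents
    return float(ade), float(fde)
```

Many trajectory benchmarks report the minimum ADE/FDE over K sampled predictions per agent; this sketch scores a single prediction.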

The adaptation speed to a new domain, visualized in Image 13, shows that the approach fine-tunes faster (using fewer batches) while reaching a lower Average Displacement Error (ADE) than the selected benchmark IM (Liu et al., 2022).



5. Thinking outside the box: Causal-Effect Look at Context Generation for Boosting Multi-modal Language Models

This approach is treated separately because it differs from the others: instead of training a new model, it improves the performance of an existing model without any retraining or fine-tuning. The idea behind this section is to encourage you to think outside the box when incorporating causality and to show a fundamentally different approach that still applies to the realm of generative models. We will focus on the work of Zhao et al., 2023.

Consider a Visual Question Answering (VQA) task like the one depicted in Image 14. The Multi-modal Language Model (MLM) has to answer a question about the image and receives additional context that should help it do so. The key point is that this context was generated by the same language model that ultimately answers the question.

Because the context is essentially a description of the image (this is also what we ask the model to generate), its usefulness depends on the question we want answered. As seen in Image 15, a descriptive context can help answer a question like "What is in the image?" but adds limited value to "Is the actor inside the red bounding box named Frank Morgan?". This is where we use causality as a filter: we determine the causal effect the context has on the final answer and, depending on that effect, decide whether to use the inference with or without the context.

We use the inference including the context if the Total Indirect Effect (TIE) is larger than the Natural Direct Effect (NDE). These are calculated as follows:

\begin{split} \text{Total Effect (TE)} &= \mathbb{E}[Y(I,C,Q)-Y(Q)] \\ \text{Natural Direct Effect (NDE)} &= \mathbb{E}[Y(I,Q)-Y(Q)] \\ \text{Total Indirect Effect (TIE)} &= \text{TE} - \text{NDE} \end{split}

 \mathbb{E}[\cdot] represents the expectation operation and Y represents the answer obtained in the VQA task. The subtraction signifies a comparison between two types of outcomes. This comparison is instantiated using the Jensen-Shannon Divergence (JSD) to compare the difference between the two distributions. Image 16 visualizes the relationship between the different effects we calculated (* denotes that this input was not used in the MLM when doing VQA). 
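The decision rule can be sketched as follows, assuming we already have the model's answer distributions for three inputs: image, context and question; image and question; and question only. Following the definitions above, each effect is measured as a Jensen-Shannon divergence between two answer distributions, and the context is kept only if TIE > NDE. The function names and toy distributions are hypothetical, and how Zhao et al. (2023) obtain and aggregate these distributions in practice may differ.

```python
import numpy as np

def jsd(p, q, eps=1e-12):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p, q = p / p.sum(), q / q.sum()
    m = 0.5 * (p + q)

    def kl(a, b):
        # eps avoids log(0); entries with a == 0 contribute 0
        return np.sum(a * np.log((a + eps) / (b + eps)))

    return float(0.5 * kl(p, m) + 0.5 * kl(q, m))

def causal_context_filter(p_icq, p_iq, p_q):
    """p_icq: answers from image + context + question, p_iq: image + question,
    p_q: question only. Returns (distribution to answer with, context used?)."""
    te = jsd(p_icq, p_q)   # TE:  Y(I, C, Q) compared with Y(Q)
    nde = jsd(p_iq, p_q)   # NDE: Y(I, Q) compared with Y(Q)
    tie = te - nde         # TIE = TE - NDE
    use_context = tie > nde
    return (p_icq if use_context else p_iq), use_context
```

For example, with `p_q = [0.5, 0.5]`, `p_iq = [0.5, 0.5]` and `p_icq = [0.9, 0.1]`, the image alone does not move the answer (NDE = 0) while the context does, so the filter keeps the context-based answer.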

 

5.1. Results

Image 17 shows the performance difference not only between the original model and the proposed method, but also between adding the context naively and using the causal filtering approach. Adding the context to the VQA inference improves the model's performance on most benchmarks, but it falls behind the approach that combines the context with the causal filter. The Ensemble and One-shot baselines were run on LLaVA (Liu et al., 2023) directly. The ensemble values were obtained by averaging the likelihood distributions of 5 different prompts. One-shot uses one-shot in-context learning on LLaVA, similar to Flamingo (Alayrac et al., 2022).

6. Conclusion and discussion

The areas that causal models try to improve are among the most important shortcomings that current machine learning solutions still suffer from. When training our models, we need to consider the biases they might inherit from their training data and the performance drops that occur as soon as they are confronted with out-of-distribution data. Forcing these models to develop a deeper understanding of the data they learn from, and moving away from purely statistical representations, should enable more robust generative approaches which, through their interpretability, can also be fairer and less susceptible to biases. Research in this area is still very young, and multiple approaches have emerged that incorporate causality into the training or inference process in different ways, as shown in this post. Current research is, however, by no means limited to the approaches shown here. The methods we looked at all assumed a predefined SCM, but research on learning these models from data is also worth a look (Poinsot et al., 2024; Ryan et al., 2022), although it is outside the scope of this post. I could imagine that in the coming years the multitude of approaches will condense to a few that prove to work best, similar to the adoption of the Transformer architecture for language models. In the end, causal models aim to enforce a behavior our models were built for in the first place: capturing the cause-effect relationship between influential variables in a system.

However, many approaches in the academic literature make a Markovian assumption about their data (Ribeiro et al., 2023), meaning that there are no unobserved confounding factors and all causal effects must be identifiable from the observed data. This is one area future research should focus on. While causal generative models are not yet the silver bullet that will solve all our problems, they are a step toward more mature machine learning approaches.

7. References

Alahi, A., Goel, K., Ramanathan, V., Robicquet, A., Fei-Fei, L., & Savarese, S. (2016). Social LSTM: Human Trajectory Prediction in Crowded Spaces. 961–971. https://openaccess.thecvf.com/content_cvpr_2016/html/Alahi_Social_LSTM_Human_CVPR_2016_paper.html

Alayrac, J.-B., Donahue, J., Luc, P., Miech, A., Barr, I., Hasson, Y., Lenc, K., Mensch, A., Millican, K., Reynolds, M., Ring, R., Rutherford, E., Cabi, S., Han, T., Gong, Z., Samangooei, S., Monteiro, M., Menick, J., Borgeaud, S., … Simonyan, K. (2022). Flamingo: A Visual Language Model for Few-Shot Learning (arXiv:2204.14198). arXiv. https://doi.org/10.48550/arXiv.2204.14198

Bagi, S. S. G., Gharaee, Z., Schulte, O., & Crowley, M. (2023). Generative Causal Representation Learning for Out-of-Distribution Motion Forecasting (arXiv:2302.08635). arXiv. https://doi.org/10.48550/arXiv.2302.08635

Bareinboim, E., Correa, J. D., Ibeling, D., & Icard, T. (2022). On Pearl’s Hierarchy and the Foundations of Causal Inference. In Probabilistic and Causal Inference: The Works of Judea Pearl (1st ed., Vol. 36, pp. 507–556). Association for Computing Machinery. https://doi.org/10.1145/3501714.3501743

Building Evacuation Plans (Faculty and Staff). (n.d.). Faculty and Staff (Penn State College of Agricultural Sciences). Retrieved June 12, 2024, from https://agsci.psu.edu/faculty-staff/safety/procedures/evacuation-plans

Burda, Y., Grosse, R., & Salakhutdinov, R. (2016). Importance Weighted Autoencoders (arXiv:1509.00519). arXiv. https://doi.org/10.48550/arXiv.1509.00519

Causality. (2024, June 5). In Cambridge Dictionary. https://dictionary.cambridge.org/de/worterbuch/englisch/causality

Chen, J., Zhu, D., Shen, X., Li, X., Liu, Z., Zhang, P., Krishnamoorthi, R., Chandra, V., Xiong, Y., & Elhoseiny, M. (2023). MiniGPT-v2: Large language model as a unified interface for vision-language multi-task learning (arXiv:2310.09478). arXiv. https://doi.org/10.48550/arXiv.2310.09478

Dinh, L., Sohl-Dickstein, J., & Bengio, S. (2017). Density estimation using Real NVP (arXiv:1605.08803). arXiv. https://doi.org/10.48550/arXiv.1605.08803

Gupta, A., Johnson, J., Fei-Fei, L., Savarese, S., & Alahi, A. (2018). Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks (arXiv:1803.10892). arXiv. https://doi.org/10.48550/arXiv.1803.10892

Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models (arXiv:2006.11239). arXiv. https://doi.org/10.48550/arXiv.2006.11239

Kingma, D. P., Salimans, T., Jozefowicz, R., Chen, X., Sutskever, I., & Welling, M. (2017). Improving Variational Inference with Inverse Autoregressive Flow (arXiv:1606.04934). arXiv. https://doi.org/10.48550/arXiv.1606.04934

Kingma, D. P., & Welling, M. (2022). Auto-Encoding Variational Bayes (arXiv:1312.6114). arXiv. https://doi.org/10.48550/arXiv.1312.6114

Komanduri, A., Wu, X., Wu, Y., & Chen, F. (2024). From Identifiable Causal Representations to Controllable Counterfactual Generation: A Survey on Causal Generative Modeling (arXiv:2310.11011). arXiv. https://doi.org/10.48550/arXiv.2310.11011

Kosaraju, V., Sadeghian, A., Martín-Martín, R., Reid, I., Rezatofighi, S. H., & Savarese, S. (2019). Social-BiGAT: Multimodal Trajectory Forecasting using Bicycle-GAN and Graph Attention Networks (arXiv:1907.03395). arXiv. https://doi.org/10.48550/arXiv.1907.03395

Liu, H., Li, C., Wu, Q., & Lee, Y. J. (2023). Visual Instruction Tuning (arXiv:2304.08485). arXiv. https://doi.org/10.48550/arXiv.2304.08485

Liu, Y., Cadei, R., Schweizer, J., Bahmani, S., & Alahi, A. (2022). Towards Robust and Adaptive Motion Forecasting: A Causal Representation Perspective (arXiv:2111.14820). arXiv. https://doi.org/10.48550/arXiv.2111.14820

Mangalam, K., Girase, H., Agarwal, S., Lee, K.-H., Adeli, E., Malik, J., & Gaidon, A. (2020). It Is Not the Journey but the Destination: Endpoint Conditioned Trajectory Prediction (arXiv:2004.02025). arXiv. https://doi.org/10.48550/arXiv.2004.02025

Overview of GAN Structure | Machine Learning. (n.d.). Google for Developers. Retrieved June 7, 2024, from https://developers.google.com/machine-learning/gan/gan_structure

Pearl, J. (2009). Causality: Models, Reasoning and Inference (2nd ed.). Cambridge University Press.

Pearl, J. (2019). The seven tools of causal inference, with reflections on machine learning. Communications of the ACM, 62(3), 54–60. https://doi.org/10.1145/3241036

Pellegrini, S., Ess, A., Schindler, K., & van Gool, L. (2009). You’ll never walk alone: Modeling social behavior for multi-target tracking. 2009 IEEE 12th International Conference on Computer Vision, 261–268. https://doi.org/10.1109/ICCV.2009.5459260

Poinsot, A., Leite, A., Chesneau, N., Sébag, M., & Schoenauer, M. (2024). Learning Structural Causal Models through Deep Generative Models: Methods, Guarantees, and Challenges (arXiv:2405.05025). arXiv. https://doi.org/10.48550/arXiv.2405.05025

Ribeiro, F. D. S., Xia, T., Monteiro, M., Pawlowski, N., & Glocker, B. (2023). High Fidelity Image Counterfactuals with Probabilistic Causal Models (arXiv:2306.15764). arXiv. https://doi.org/10.48550/arXiv.2306.15764

Ryan, O., Bringmann, L. F., & Schuurman, N. K. (2022). The Challenge of Generating Causal Hypotheses Using Network Models. Structural Equation Modeling: A Multidisciplinary Journal, 29(6), 953–970. https://doi.org/10.1080/10705511.2022.2056039

Sønderby, C. K., Raiko, T., Maaløe, L., Sønderby, S. K., & Winther, O. (2016). Ladder Variational Autoencoders (arXiv:1602.02282). arXiv. https://doi.org/10.48550/arXiv.1602.02282

Unsplash. (2021, March 26). Photo by Pascal Bernardon on Unsplash. https://unsplash.com/photos/people-walking-on-street-during-daytime-3jlJjmMUjEA

Variational autoencoder. (2024). In Wikipedia. https://en.wikipedia.org/w/index.php?title=Variational_autoencoder&oldid=1226413743

Zhang, L., Zhai, X., Zhao, Z., Zong, Y., Wen, X., & Zhao, B. (2024). What If the TV Was Off? Examining Counterfactual Reasoning Abilities of Multi-modal Language Models (arXiv:2310.06627). arXiv. https://doi.org/10.48550/arXiv.2310.06627

Zhao, S., Li, Z., Lu, Y., Yuille, A., & Wang, Y. (2023). Causal-CoG: A Causal-Effect Look at Context Generation for Boosting Multi-modal Language Models (arXiv:2312.06685). arXiv. https://doi.org/10.48550/arXiv.2312.06685

