Introduction


Machine learning models are increasingly applied in high-impact areas to automate decision-making processes such as loan applications or medical diagnosis. At the same time, the decisions these systems make still lack transparency and are difficult for users to understand. To address this problem, the research field of explainable AI (xAI) develops methods that explain these decisions. In the medical domain, where both patients and doctors need to trust the models as non-experts in AI, providing explanations is particularly important. While some models in xAI are designed to be inherently interpretable, most models can still be considered "black-box" models, which need to be explained using a post-hoc approach [13]. Furthermore, we can distinguish between local methods, which explain the prediction for a single data point, and global methods, which try to uncover mechanisms or biases of the whole model.

To address the lack of explainability, Wachter et al. [15] proposed the idea of counterfactual explanations. This kind of explanation demonstrates how the world would have to change in order to yield a different, desirable outcome. An example can be found in medical diagnostics: assume a model predicts a specific treatment plan from blood statistics. We can then ask: which change in the blood statistics would be necessary for the model to predict a specific alternative treatment plan? In this example, medical professionals could understand why the model does not suggest an alternative medication, which results in enhanced trust in the decision. Counterfactual explanations can thus be seen as a local, post-hoc explanation method.

In a second application, counterfactuals can be used not only for explanations but also for deriving saliency maps. These can be created by comparing the counterfactual with the original and highlighting the perturbed areas. Fontanella et al. [5] use such counterfactual-based saliency maps to improve a classifier for brain and lung CT scans. Similarly, Wang et al. [16] use symmetric priors to first create counterfactuals and then saliency maps to improve classification.

While most papers in this area focus on image or tabular data, these models have in recent years also been applied to explain text data, time series data, or recommendation systems [14].


Desired Properties of Counterfactual Explanations

The idea of counterfactuals is closely related to adversarial perturbations, as both approaches manipulate the input in order to change a model's outcome [6]. In contrast to adversarial attacks, however, counterfactuals should be based on interpretable and semantically meaningful alterations. Therefore, several desiderata for counterfactuals can be defined. The models used to create the counterfactuals are constructed to fulfill these desired properties [14].


Figure 1: Two possible counterfactuals for instance X. CF2 is closer to the data manifold than CF1 [14].

  • Validity: The generated counterfactual is valid if the given model labels it as the desired counterfactual class.
  • Actionability: The explainer model should only change actionable features. Immutable features should stay the same, so that the user can actually infer next steps and decisions from the given explanation. Furthermore, some models also consider causality, which describes the interdependence between features.
  • Similarity: The counterfactual should be close to the given instance with respect to a given distance function.
  • Sparsity: Explanations that change only a few features are usually more understandable. Consequently, the input should only be sparsely changed to create the counterfactual.
  • Data Manifold Closeness: A counterfactual can only be realistic if it lies close to the observed data distribution.

The extent to which the generated counterfactuals fulfill these properties is often used to evaluate the method that produced them.
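As a rough illustration, the following minimal Python sketch computes three commonly reported quantities (validity, proximity, sparsity). The function name and the scikit-learn-style predict() interface are assumptions for illustration, and data-manifold closeness or actionability usually require additional, dataset-specific machinery.

```python
import numpy as np

def evaluate_counterfactual(model, x, x_cf, target_class):
    """Toy metrics for a single counterfactual of a tabular sample.
    `model` is assumed to expose a scikit-learn style predict()."""
    # Validity: does the model assign the counterfactual to the target class?
    validity = model.predict(x_cf[None])[0] == target_class

    # Similarity (proximity): L1 distance between original and counterfactual.
    proximity = np.abs(x - x_cf).sum()

    # Sparsity: fraction of features that were changed at all.
    changed_fraction = np.mean(~np.isclose(x, x_cf))

    return {"validity": bool(validity),
            "proximity": float(proximity),
            "changed_fraction": float(changed_fraction)}
```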


Overview of approaches

In each approach, the explainer method is separate from the model that should be explained. Depending on the data type and task requirements, different approaches can be selected to build the explainer model. It is important to note that the explainer does not need to be a deep learning model itself.

Figure 2: The Counterfactual explainer model receives the desired counterfactual property and the original data point as inputs. 

The different explainer methods can be grouped into four different strategies [6]:

  • Optimization: A loss function is constructed with respect to the desired properties and then optimized per sample.
  • Generative Model: A generative model is trained to generate images of the counterfactual class conditioned on the desired output class of the black box model.
  • Heuristic Search Strategy: Counterfactuals are found by making local choices based on a heuristic to iteratively minimize a cost function.
  • Instance-Based: Counterfactuals are retrieved from the training data by choosing the most similar sample of the counterfactual class.

While older models are often instance-based, most recent papers use optimization-based approaches or generative models, which is why the rest of this blog post focuses on these two strategies.


Counterfactual images using a generative model

The process of generating counterfactuals with a generative model can be seen as an image-to-image translation problem in which the image generator is additionally conditioned on the desired output class of the model that should be explained.

Jung et al. [8] create counterfactuals for brain MR images that show increasing signs of Alzheimer's disease using a conditional GAN. Mertes et al. [10], whose model is described in more detail below, extend the CycleGAN architecture to generate counterfactuals. The main objective of the second paper is to provide meaningful explanations for a binary classifier that detects pneumonia in X-ray images of lungs.

Model

The authors derive the process of generating counterfactuals from adversarial learning and thus use GANs for image generation. Specifically, they extend the CycleGAN architecture first introduced by Zhu et al. [17] to the counterfactual modelling process.

Figure 3: CycleGAN architecture with the addition of the classifier to generate counterfactuals [10].

The architecture can be described as a cycle of two GANs. The first GAN, G, receives images showing pneumonia (domain X) and learns to translate them into images without pneumonia (domain Y). The second GAN, F, translates from domain Y back to domain X. To train the model, a loss function consisting of several components is constructed.

The first component of the loss function is the original GAN loss \mathcal{L}_{GAN}. It ensures that the resulting images are close to the given data manifold and is thus necessary for interpretability.

Because this cycle is highly underconstrained, the authors also introduce a cycle-consistency loss \mathcal{L}_{Cycle}, which penalizes the reconstruction error after a whole cycle. This loss ensures the desired property of minimality.

The counter loss \mathcal{L}_{counter} for each GAN is constructed such that the given classifier unambiguously classifies the generated image as the desired class, which enforces validity. This is done by using the classifier's output and pushing the counterfactual class probability towards one and the original class probability towards zero.

\mathcal{L} = \mathcal{L}_{GAN} + \mathcal{L}_{Cycle} + \mathcal{L}_{counter}
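To make the interplay of the three terms more concrete, the following PyTorch-style sketch assembles the generator objective for one direction of the cycle (pneumonia → no pneumonia). All names (netG, netF, netD_Y, clf) and the simple weighting are placeholders for illustration and simplify the original GANterfactual formulation, which also trains the symmetric direction Y → X.

```python
import torch
import torch.nn.functional as F

def generator_loss_x_to_y(x, netG, netF, netD_Y, clf,
                          lambda_cycle=10.0, lambda_counter=1.0):
    """Hypothetical generator objective for one cycle direction.
    netG: X -> Y, netF: Y -> X, netD_Y: discriminator for domain Y,
    clf: the frozen classifier to be explained (assumed to output P(pneumonia))."""
    fake_y = netG(x)          # counterfactual candidate ("no pneumonia")
    rec_x = netF(fake_y)      # reconstruction after a full cycle

    # GAN loss: the discriminator should accept fake_y as a real domain-Y image,
    # which keeps the counterfactual close to the data manifold.
    d_out = netD_Y(fake_y)
    loss_gan = F.binary_cross_entropy_with_logits(d_out, torch.ones_like(d_out))

    # Cycle-consistency loss: translating back should recover the original,
    # which keeps the overall change minimal.
    loss_cycle = F.l1_loss(rec_x, x)

    # Counter loss: push the classifier's pneumonia probability for the
    # counterfactual towards zero so it is unambiguously "no pneumonia".
    p_pneumonia = clf(fake_y)
    loss_counter = F.binary_cross_entropy(p_pneumonia, torch.zeros_like(p_pneumonia))

    return loss_gan + lambda_cycle * loss_cycle + lambda_counter * loss_counter
```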


The resulting aggregate loss function is optimized adversarially by training both discriminator and generator. Examples of the resulting images are shown below:

Figure 4: Upper images: examples of class "No Pneumonia" with counterfactuals of class "Pneumonia". Lower images: translation in the other direction [10].

Results

To evaluate the counterfactual images, the authors tested how many of the counterfactual images in the test set are classified as the counterfactual class, resulting in an accuracy of 94.68%. Moreover, a user study was conducted to assess the quality and interpretability of the generated counterfactuals as well as to demonstrate their utility in practice. This study focused on five evaluation aspects: explanation satisfaction, mental models, trust, emotions, and self-efficacy. The participants were significantly more satisfied with the counterfactual explanations than with the classical xAI methods LIME and LRP. Additionally, the system was perceived as more trustworthy.

Personal evaluation

In my opinion, the main limitation of this approach is that transferring the model to a multi-class problem would require training a separate CycleGAN for every class combination, which results in a high computational effort. This could potentially be avoided by using other methods such as conditional GANs. Nonetheless, this paper is one of very few containing such an extensive user study, which is particularly important in xAI. However, in my opinion comparisons not just to other xAI methods but also to other counterfactual explainers would have been a valuable addition to the paper.

Direct per-sample optimization

In the paper Adversarial Counterfactual Attention for Classification and Detection in Medical Imaging [5], the authors deploy a counterfactual explainer to generate saliency maps. In addition, they show that these saliency maps can also improve the performance of a second classifier in a downstream task. In comparison to the explainer from the previous section, the counterfactuals are created through a direct optimization procedure and do not rely on a generative model. The method is applied to the IST-3 dataset, which contains brain scans with stroke lesions, as well as the MosMed dataset, which consists of COVID-19 lung CT scans.

Method

First, the authors train a classifier f to distinguish k different classes. To create the counterfactual images, the authors then directly formulate a loss function for a single input image, which is subsequently optimized by gradient descent with respect to the input.

The main goal of the counterfactual model is to perturb the input in the direction of the respective counterfactual class in order to ensure validity. For this purpose, the logit of the counterfactual class of classifier f can be accessed directly. This loss alone, however, would just produce an adversarial attack, as the change would not necessarily be interpretable. To solve this problem, the authors apply an approach similar to Cohen et al. [2]. Instead of applying gradient descent to the image directly, the optimization is conducted in the latent space of an autoencoder that was trained to reconstruct the training data. The input x is first mapped to the latent space using the encoder, E(x) = z. Then a cross-entropy loss is used to increase the probability of the classifier f labeling the decoded image D(z) as class k, where t_k is the one-hot encoded target vector of class k. To ensure minimality, this loss function is extended by the distance to the original image in the latent space.

\mathcal{L}(z, t) = -\sum_{k=1}^{K} t_{k} \log\big(f_{k}(D(z))\big) + \alpha \, \lVert z - E(x) \rVert_{1}


By optimizing this loss via gradient descent, a counterfactual for each class can be found for every given input image. The saliency maps are generated by highlighting the pixels with the largest differences between counterfactual and original.
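As a concrete illustration, the following PyTorch sketch performs this per-sample optimization in the latent space of a pretrained autoencoder (encoder E, decoder D) with a frozen classifier f. The optimizer, step count, and the simple absolute-difference saliency map are assumptions for illustration rather than the authors' exact implementation.

```python
import torch
import torch.nn.functional as F

def counterfactual_saliency(x, E, D, f, target_class, alpha=0.1, steps=200, lr=0.01):
    """Optimize a latent code so that the decoded image is classified as
    `target_class`, while staying close to the encoding of the original."""
    z_orig = E(x).detach()
    z = z_orig.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([z], lr=lr)

    target = torch.tensor([target_class], device=x.device)
    for _ in range(steps):
        optimizer.zero_grad()
        x_cf = D(z)
        # Cross-entropy towards the counterfactual class (validity) ...
        loss = F.cross_entropy(f(x_cf), target)
        # ... plus an L1 penalty on the latent shift (minimality).
        loss = loss + alpha * (z - z_orig).abs().sum()
        loss.backward()
        optimizer.step()

    x_cf = D(z).detach()
    # Saliency map: highlight the pixels that changed the most.
    saliency = (x_cf - x).abs()
    return x_cf, saliency
```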

Figure 5: Saliency maps for the IST-3 dataset generated using four different methods.

The authors additionally show that the saliency maps created by this method can be used to build a second, more robust classifier. This classifier receives the original image as well as the per-class saliency maps of the sample and uses the saliency maps as attention maps in the classification process.
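The paper's exact architecture is not reproduced here, but the following toy sketch illustrates the general idea of using per-class saliency maps as soft attention over image features. The layer choices and the gated pooling are purely illustrative assumptions, not the ACAT architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SaliencyGatedClassifier(nn.Module):
    """Toy illustration: image features are reweighted by per-class
    saliency maps before classification (hypothetical design)."""
    def __init__(self, num_classes, channels=16):
        super().__init__()
        self.backbone = nn.Conv2d(1, channels, kernel_size=3, padding=1)  # single-channel scans
        self.head = nn.Linear(channels * num_classes, num_classes)

    def forward(self, image, saliency):        # saliency: [B, K, H, W]
        feats = F.relu(self.backbone(image))    # [B, C, H, W]
        pooled = []
        for k in range(saliency.shape[1]):
            att = saliency[:, k:k + 1]          # attention map for class k
            pooled.append((feats * att).mean(dim=(2, 3)))  # gated pooling
        return self.head(torch.cat(pooled, dim=1))
```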

Results


Since the counterfactuals are not used as explanations in this case, the authors did not validate the resulting images with respect to the desired properties described at the beginning of this blog post. Instead, they calculated the percentage of images for which the pixel with the highest value in each CT scan lies in the correct region of the saliency map. On the IST-3 dataset, the authors report a score of 65.05%, compared to 58.39% using the method of Cohen et al. [2], which just applies a fixed shift in the latent space.

The second classifier, which was trained using the created saliency maps as additional input, achieved a higher accuracy on both the IST-3 and MosMed datasets compared to other classification architectures using saliency maps (Fig. 6).

Figure 6: Results for the classifier using original images and saliency maps.

Moreover, the authors show that the improvement of this classifier was particularly high for smaller lesions in the IST-3 dataset, which intuitively results from the classifier considering more local features through the use of saliency maps (Fig. 7).

Figure 7: The classifier performs particularly well on small lesions in the IST-3 dataset.


Does the classifier need to be accessible?


In my opinion, the main problem with this approach to creating counterfactuals is that the classifier appears directly in the loss function and thus needs to be accessible and differentiable. To avoid this restriction, an alternative solution was proposed by Looveren and Klaise [9]. The intuition of their approach is similar to that of prototypical networks [12]. Instead of including the logit output of the counterfactual class in the loss function, they utilize prototypes. A prototype is the average encoding of the images that the classifier maps to the counterfactual class. The distance between the current encoding z and the counterfactual class prototype can then replace the loss term based on the classifier logit, so the classifier may also be non-differentiable. However, it is important to keep in mind that by exchanging these loss terms the shift in the latent space is reduced to a first-order relationship.
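A minimal sketch of how such a prototype term could replace the classifier-dependent part of the latent-space loss is given below. The encoder E, the grouping of training images by predicted class, and the squared-distance form are assumptions for illustration; the full method of Looveren and Klaise contains further terms.

```python
import torch

def class_prototype(E, images_of_class):
    """Prototype = mean latent encoding of training images that the
    classifier assigns to the counterfactual class."""
    with torch.no_grad():
        return E(images_of_class).mean(dim=0)

def prototype_loss(z, proto, z_orig, alpha=0.1):
    # Pull the latent code towards the counterfactual prototype instead of
    # using the (possibly non-differentiable) classifier's logits ...
    loss_proto = ((z - proto) ** 2).sum()
    # ... while still penalizing large shifts away from the original encoding.
    return loss_proto + alpha * (z - z_orig).abs().sum()
```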


Uncertainty estimation and out-of-distribution detection of counterfactuals

Delaney et. al. [3] focus on the uncertainty of counterfactuals and out-of-distribution detection. The motivation behind this can be seen in an example from the medical domain. Assuming a deep learning framework predicts the expected lung capacity given a specific treatment, a counterfactual would suggest alternative treatments, that would increase the lung capacity. For these alternative medications it is important that they are plausible and also not out-of-distribution.

The authors explicitly criticize counterfactual models that use the classifier logits directly when creating counterfactuals, as these logits are often poorly calibrated and do not encode meaningful uncertainty. Additionally, they point out that generative models, such as the CycleGAN described earlier, are often not feasible for complex datasets and too computationally expensive in practice, which is why direct optimization methods should be preferred.

To measure the uncertainty of counterfactuals, they propose to use Monte Carlo dropout, assuming dropout layers are available in the network, as well as Trust Scores as introduced by Jiang et al. [7]. The Trust Score is computed as the ratio of the distance from the sample to the closest high-density set of the original class to the distance to the high-density set of the counterfactual class in the training data. Therefore, counterfactuals with higher Trust Scores are more compliant with the training distribution and come with a lower epistemic uncertainty. In their experiments the authors tested three different counterfactual approaches with regard to their Trust Scores and MC dropout on the MNIST and FashionMNIST datasets.
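A simplified version of such a ratio can be sketched with a nearest-neighbour search over the training data; note that the actual Trust Score of Jiang et al. additionally filters out low-density points, which is omitted in this illustrative sketch.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def simple_trust_score(x_cf, X_train, y_train, original_class, cf_class):
    """Ratio of the distance to the original class over the distance to the
    counterfactual class; larger values mean the counterfactual lies closer
    to training data of its target class."""
    nn_orig = NearestNeighbors(n_neighbors=1).fit(X_train[y_train == original_class])
    nn_cf = NearestNeighbors(n_neighbors=1).fit(X_train[y_train == cf_class])

    d_orig, _ = nn_orig.kneighbors(x_cf.reshape(1, -1))
    d_cf, _ = nn_cf.kneighbors(x_cf.reshape(1, -1))
    return float(d_orig[0, 0] / (d_cf[0, 0] + 1e-12))
```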


Figure 8: Achieved MC-dropout and Trust Scores [3]

The first method, NUN-CF, searches for the closest counterfactual sample in the training data [1]. The second, W-CF, is directly derived from Wachter et al. [15] and is very similar to the method used in the previous section. The third is the approach by Looveren and Klaise [9] using guided prototypes, which was also explained previously. The authors find that the second and third approach reach almost the same result, with relatively low Trust Scores and low MC-dropout values. In comparison, the first method, which retrieves the closest training sample, reaches significantly higher values. At this point the authors do not discuss the results in detail, but suggest that further methods for uncertainty estimation and evaluation of counterfactuals are needed.

Why are many newer methods more uncertain?


In my opinion, the results of this paper show one of the main pitfalls in counterfactual modelling. The low Trust Scores of the second and third approach could be a consequence of the construction of the loss function: in both approaches the loss tries to enforce the desired minimality by penalizing the distance between the original sample and the counterfactual, which automatically keeps the created counterfactual close to the decision boundary of the classifier. Consequently, the uncertainty of the classification of this counterfactual will, on average, be relatively high when measured by Trust Scores.

In contrast, when taking the closest sample from the training data, low uncertainty and high Trust Scores can be expected, as the sample is taken directly from the training data distribution, which yields the lowest possible epistemic uncertainty. In this sense, the results of this paper also show that there is a trade-off between minimality and the counterfactual being well grounded in the training data. This trade-off can potentially be adjusted by varying parameters in the loss function, which in my opinion would have resulted in more meaningful results in the paper.


Schut et al. [11] further highlight the importance of uncertainty for counterfactuals. They show that when using classifiers that allow measuring uncertainty, such as deep neural network ensembles, the creation of a counterfactual can be derived from minimizing the aleatoric and epistemic uncertainty of the counterfactual. The resulting optimization turns out to be very efficient and also yields well interpretable results.


Conclusion


While all of the previously described methods generate counterfactuals fulfilling several of the desired properties, most of the papers still rely on application-specific evaluation methods or even only on a qualitative analysis. These evaluations are often only used to compare counterfactual explanations to other xAI techniques, and different methods for counterfactual generation are only rarely compared with each other. In my opinion, further research and a more generic evaluation measure are needed to compare different counterfactual explainers.

Another issue, which was also discussed in this blog post, is the uncertainty of the generated counterfactuals, which has to be investigated further.

Nonetheless, counterfactuals offer great potential for application in the medical domain, as trust in decision systems is especially important in this sector, and user studies such as the one described above show that counterfactual explanations are often more intuitive than other xAI methods.


Bibliography

[1] Nugent C and Cunningham P (2005). A Case-Based Explanation System for Black-Box Systems. Artificial Intelligence Review 2005; 24:163–78. doi:10.1007/s10462-005-4609-5
[2] Cohen JP, Brooks R, En S, Zucker E, Pareek A, Lungren MP, and Chaudhari A (2021). Gifsplanation via Latent Shift: A Simple Autoencoder Approach to Progressive Exaggeration on Chest X-rays. CoRR 2021; abs/2102.09475. arXiv:2102.09475. Available from: https://arxiv.org/abs/2102.09475
[3] Delaney E, Greene D, and Keane MT (2021). Uncertainty Estimation and Out-of-Distribution Detection for Counterfactual Explanations: Pitfalls and Solutions. CoRR 2021; abs/2107.09734. arXiv:2107.09734. Available from: https://arxiv.org/abs/2107.09734
[4] Elliott A, Law S, and Russell C (2019). Adversarial Perturbations on the Perceptual Ball. CoRR 2019; abs/1912.09405. arXiv: 1912.09405. Available from: http://arxiv.org/abs/1912.09405
[5] Fontanella A, Antoniou A, Li W, Wardlaw J, Mair G, Trucco E, and Storkey A (2023). ACAT: Adversarial Counterfactual Attention for Classification and Detection in Medical Imaging. 2023. arXiv: 2303.15421 [eess.IV]
[6] Guidotti R (2022). Counterfactual explanations and how to find them: literature review and benchmarking. Data Mining and Knowledge Discovery 2022 Apr :1–55. doi: 10.1007/s10618-022-00831-6
[7] Jiang H, Kim B, Guan MY, and Gupta M (2018). To Trust Or Not To Trust A Classifier. 2018. arXiv: 1805.11783 [stat.ML]
[8] Jung E, Luna M, and Park SH (2023). Conditional GAN with 3D discriminator for MRI generation of Alzheimer’s disease progression. Pattern Recognition 2023; 133:109061. doi: https://doi.org/10.1016/j.patcog.2022.109061. Available from: https://www.sciencedirect.com/science/article/pii/S0031320322005416
[9] Looveren AV and Klaise J (2019). Interpretable Counterfactual Explanations Guided by Prototypes. CoRR 2019; abs/1907.02584. arXiv: 1907.02584. Available from: http://arxiv.org/abs/1907.02584
[10] Mertes S, Huber T, Weitz K, Heimerl A, and André E (2022). GANterfactual—Counterfactual Explanations for Medical Non-experts Using Generative Adversarial Learning. Front. Artif. Intell. 2022; 5. Available from: https://doi.org/10.3389/frai.2022.825565
[11] Schut L, Key O, McGrath R, Costabello L, Sacaleanu B, Corcoran M, and Gal Y (2021). Generating Interpretable Counterfactual Explanations By Implicit Minimisation of Epistemic and Aleatoric Uncertainties. CoRR 2021; abs/2103.08951. arXiv:2103.08951. Available from: https://arxiv.org/abs/2103.08951
[12] Snell J, Swersky K, and Zemel RS (2017). Prototypical Networks for Few-shot Learning. CoRR 2017; abs/1703.05175. arXiv:1703.05175. Available from: http://arxiv.org/abs/1703.05175
[13] Tjoa E and Guan C (2019). A Survey on Explainable Artificial Intelligence (XAI):Towards Medical XAI. CoRR 2019; abs/1907.07374. arXiv: 1907.07374. Available from: http://arxiv.org/abs/1907.07374
[14] Verma S, Dickerson JP, and Hines K (2020). Counterfactual Explanations for Machine Learning: A Review. CoRR 2020; abs/2010.10596. arXiv: 2010.10596. Available from: https://arxiv.org/abs/2010.10596
[15] Wachter S, Mittelstadt BD, and Russell C (2017). Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR. CoRR 2017; abs/1711.00399. arXiv: 1711.00399. Available from: http://arxiv.org/abs/1711.00399
[16] Wang C, Li J, Zhang F, Sun X, Dong H, Yu Y, and Wang Y (2020). Bilateral Asymmetry Guided Counterfactual Generating Network for Mammogram Classification. CoRR 2020; abs/2009.14406. arXiv: 2009.14406. Available from: https://arxiv.org/abs/2009.14406
[17] Zhu J, Park T, Isola P, and Efros AA (2017). Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks. CoRR 2017; abs/1703.10593. arXiv: 1703.10593. Available from: http://arxiv.org/abs/1703.10593

