This is my humble attempt at helping readers understand the paper titled "Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking".
I thank the authors Michael Sejr Schlichtkrull, Nicola De Cao, and Ivan Titov for this beautiful work of art [4].
Motivation
Have you ever wondered why you made a particular decision in your life? What was the reasoning behind it? If you know the answer to these questions, you probably understand yourself better and will trust your future decisions more. For many years, I have wondered how a deep learning model arrives at a prediction. What exactly happens inside the black box we call a model, where all the magic happens and, voilà, a prediction comes out at the end? This paper beautifully describes a method for answering these questions for a particular class of deep learning models: Graph Neural Networks (GNNs).
Why do we need Interpretation?
Real-world tasks are complex, and so is the data associated with them. When such complex systems are part of an experiment, some amount of uncertainty is involved, and uncertainty leads to trust issues with the model. How can we trust a model when we don't understand what's happening inside it? Classification accuracy gives only an incomplete description of a model's behaviour: it tells us how often the predictions are correct, but not why they were made. To interpret the model's predictions, we need to understand why each prediction was made [1].
Introduction
As defined by Miller [3], "Interpretability is the degree to which a human can understand the cause of a decision". Post-hoc interpretability refers to applying interpretation methods after a model has been trained; in this work, a post-hoc method is used to interpret the predictions of GNNs by identifying unnecessary edges. Deep learning models are usually not very transparent: their predictions are computed from millions of weights, which makes them hard to interpret. The main idea of this work is to train a classifier in a fully differentiable fashion that predicts, for every edge in every layer, whether that edge can be dropped. For this purpose, the authors employ stochastic gates and encourage sparsity through the L0 norm, and they apply the proposed attribution method to tasks such as question answering and semantic role labeling [4].
Erasure Search
Erasure search is a straightforward approach to interpretation. Applied to a graph, it searches for irrelevant edges and erases them. More generally, erasure search looks for a subset of features that can be removed entirely without changing the model's prediction. This guarantees that the model does not rely on any information in the deleted features. This differs from attention-based methods and back-propagation-based techniques, which do not guarantee that the model ignores low-scoring elements. However, erasure search is susceptible to hindsight bias, which can be mitigated by amortization [4].
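A naive erasure search fits in a few lines of plain Python. This is a minimal sketch, assuming a hypothetical `predict` callable that maps the set of kept edges to the model's prediction; it returns the largest set of edges that can be removed without changing that prediction. Because it enumerates subsets, its cost grows exponentially with the number of edges, which illustrates why exact erasure search does not scale to real graphs.

```python
from itertools import combinations
from typing import Callable, FrozenSet, Hashable, Iterable, Tuple

Edge = Tuple[Hashable, Hashable]


def erasure_search(edges: Iterable[Edge],
                   predict: Callable[[FrozenSet[Edge]], Hashable]) -> FrozenSet[Edge]:
    """Return the largest set of edges whose removal leaves the prediction unchanged.

    Brute force over removal sets, largest first: up to 2^|E| calls to `predict`.
    """
    all_edges = frozenset(edges)
    reference = predict(all_edges)                      # prediction on the full graph
    ordered = sorted(all_edges, key=repr)               # deterministic tie-breaking
    for size in range(len(ordered), -1, -1):
        for removed in combinations(ordered, size):
            kept = all_edges - frozenset(removed)
            if predict(kept) == reference:              # prediction preserved without these edges
                return frozenset(removed)
    return frozenset()                                  # not reached: removing nothing always works


# Usage: a toy "model" whose prediction depends only on the edge a -> c.
edges = [("a", "c"), ("b", "c"), ("d", "c")]
removable = erasure_search(edges, lambda kept: ("a", "c") in kept)
```

In the toy usage, the edges from `b` and `d` can be erased, while the edge from `a` cannot.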
[15]
What is Hindsight Bias?
In the context of medical applications, a study [2] shows that, once a tumor has been diagnosed, 90 percent of lung cancers and 70 percent of breast cancers that had previously gone unrecognized can, in hindsight, be seen on earlier radiographs. As an example, assume that a trainee doctor overlooked a tumor on a radiograph and the tumor was discovered later during a review. A supervising doctor who knows that the tumor exists, and who was not involved in the initial diagnosis, would be hindsight biased and would unfairly evaluate the trainee's ability to diagnose cancer. In other words, knowing the location of the tumor, the supervisor overestimates the likelihood that she herself would have found it initially. Consequently, she might underestimate the abilities of the young trainee against the backdrop of her own inflated sense of ability [8]. This is an example of hindsight bias in a real-world application, a phenomenon also known as the curse of knowledge. This highly pervasive bias, a distortion of cognition and information processing, clouds a decision maker's ability to recall prior expectations about outcomes: outcomes that seemed impossible to anticipate suddenly seem obvious once the truth is revealed [5].
Related Work
GNNExplainer [6] is one of the recent works on interpretability techniques for GNNs and is related to Graph Mask, the method proposed in this paper. However, GNNExplainer optimizes separately for each example, which leads to hindsight bias and a decrease in faithfulness. The work closest to Graph Mask is [7], which serves as a baseline in the experiments of this paper. Graph Mask is designed to mitigate the drawbacks of these previous works.
Proposed Method - The Graph Mask
The Graph Mask defined in this work achieves the same benefits as erasure search, but in a scalable manner. Graph Mask is a differentiable form of subset erasure: instead of finding an optimal subset to erase for every given example, it learns an erasure function that predicts, for every edge (u, v) at every layer k, whether that connection should be retained. Sparse stochastic gates are used to enable gradient-based optimization. The goal of Graph Mask is to find superfluous edges (u, v) and to replace the corresponding messages $m_{u,v}^{(k)}$ with a learned baseline $b^{(k)}$ through a binary choice $z_{u,v}^{(k)}$. The resulting message is given by
$$\widetilde{m}_{u,v}^{(k)} = z_{u,v}^{(k)} \cdot m_{u,v}^{(k)} + b^{(k)} \cdot \left(1 - z_{u,v}^{(k)}\right)$$
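The gating equation translates almost literally into PyTorch. This is a minimal sketch, assuming the messages of one layer are stacked into a tensor of shape `[num_edges, dim]` and the gates into a vector of shape `[num_edges]`; the function name `gate_messages` is hypothetical.

```python
import torch


def gate_messages(messages: torch.Tensor, gates: torch.Tensor,
                  baseline: torch.Tensor) -> torch.Tensor:
    """Compute z * m + (1 - z) * b for every edge of one layer.

    messages: [num_edges, dim]; gates: [num_edges] with values in [0, 1];
    baseline: [dim], the learned baseline b^(k) shared by all edges of the layer.
    """
    z = gates.unsqueeze(-1)               # [num_edges, 1] broadcasts over the message dimension
    return z * messages + (1.0 - z) * baseline


# Usage: the first edge keeps its message, the second is replaced by the baseline.
m = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
z = torch.tensor([1.0, 0.0])
b = torch.nn.Parameter(torch.zeros(2))
out = gate_messages(m, z, b)
```

With a gate of 1 the original message passes through unchanged; with a gate of 0 only the baseline remains, so the edge no longer carries any example-specific information.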
To avoid hindsight bias and to keep the search tractable, $z_{u,v}^{(k)}$ is not optimized separately for each example; instead, it is predicted by a parameterized erasure function:
$$z_{u,v}^{(k)} = g_{\pi}\left(h_{u}^{(k)}, h_{v}^{(k)}, m_{u,v}^{(k)}\right)$$
Here π denotes the parameters of g, which is implemented as a single layer neural network.
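A minimal sketch of such an erasure function in PyTorch, assuming g concatenates $h_u^{(k)}$, $h_v^{(k)}$ and $m_{u,v}^{(k)}$ and maps them to one scalar per edge with a small MLP (one hidden layer). The class name, the hidden size, and the final sigmoid are illustrative assumptions: in Graph Mask the scalar serves as the location parameter of a stochastic gate rather than being passed through a plain sigmoid.

```python
import torch
from torch import nn


class ErasureFunction(nn.Module):
    """Hypothetical g_pi: one gate value per edge from (h_u, h_v, m_uv)."""

    def __init__(self, node_dim: int, msg_dim: int, hidden_dim: int = 32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * node_dim + msg_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, h_src: torch.Tensor, h_dst: torch.Tensor,
                messages: torch.Tensor) -> torch.Tensor:
        # h_src, h_dst: [num_edges, node_dim]; messages: [num_edges, msg_dim]
        x = torch.cat([h_src, h_dst, messages], dim=-1)
        logits = self.mlp(x).squeeze(-1)       # [num_edges] unconstrained scores
        return torch.sigmoid(logits)           # deterministic gate values in (0, 1)


# Usage with random inputs for 5 edges.
torch.manual_seed(0)
g = ErasureFunction(node_dim=8, msg_dim=4)
z = g(torch.randn(5, 8), torch.randn(5, 8), torch.randn(5, 4))
```

Because g only sees quantities computed by the original model, the same parameters π can be applied to any example without per-example optimization.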
Erasure Function Architecture
The erasure function architecture is built from the following components:
- Message function
- Node representation
- Edge representation at layer k
- Scalar location parameters for the hard concrete distribution, computed from the edge representation
- Vertex embedding matrix
- A faster alternative for the scalar location parameters, based on a bilinear product
- Representation matrix for the vertices in the masked model
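The faster bilinear alternative can be illustrated with a sketch, but its exact parameterization is not reproduced here, so the following rests on an assumption: that the score for an edge (u, v) is the bilinear form $h_u^\top W h_v$, computed for all edges at once from the vertex embedding matrix H instead of running an MLP on a concatenated vector per edge. The class name and the `edge_index` layout (row 0 holds sources, row 1 holds targets) are hypothetical.

```python
import torch
from torch import nn


class BilinearEdgeScorer(nn.Module):
    """Hypothetical scorer: alpha_uv = h_u^T W h_v for every edge (u, v)."""

    def __init__(self, node_dim: int):
        super().__init__()
        self.W = nn.Parameter(torch.randn(node_dim, node_dim) * 0.1)

    def forward(self, H: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # H: [num_nodes, node_dim] vertex embedding matrix for one layer
        # edge_index: [2, num_edges]; row 0 = source u, row 1 = target v
        HW = H @ self.W                          # project every vertex once
        src, dst = edge_index
        return (HW[src] * H[dst]).sum(-1)        # [num_edges] scores h_u^T W h_v


torch.manual_seed(0)
H = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2], [3, 3, 3]])
scorer = BilinearEdgeScorer(3)
alpha = scorer(H, edge_index)
```

The saving comes from the single matrix product `H @ self.W` over all vertices; the per-edge work is reduced to an element-wise product and a sum.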
Amortization
Instead of solving each problem independently, one at a time, amortization first defines a family of related problems and learns a single solution function over the whole family, which can then be applied to new problems. The amortized version of Graph Mask is used in this work to prevent hindsight bias; the non-amortized version, by contrast, is susceptible to hindsight bias, and the performance of both versions is compared.
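The difference between the two variants lies in what is optimized. A minimal sketch, assuming a toy dataset where each example is just a matrix of edge features: the non-amortized variant keeps a free gate logit for every edge of every example and fits each example separately, while the amortized variant trains one shared scorer (here a hypothetical linear layer) over the whole dataset and can be applied to unseen examples without further optimization.

```python
import torch
from torch import nn

torch.manual_seed(0)
feat_dim = 6
# Toy dataset: each example is a matrix of edge features [num_edges, feat_dim].
dataset = [torch.randn(n, feat_dim) for n in (3, 5, 4)]

# Non-amortized: one free logit per edge per example; every new example adds parameters.
per_example_logits = [nn.Parameter(torch.zeros(x.shape[0])) for x in dataset]
non_amortized_params = sum(p.numel() for p in per_example_logits)

# Amortized: a single scorer shared by all examples; its size does not grow with the data.
shared_scorer = nn.Linear(feat_dim, 1)
amortized_params = sum(p.numel() for p in shared_scorer.parameters())

# A new example gets gates directly, with no per-example fitting.
new_example = torch.randn(7, feat_dim)
new_gates = torch.sigmoid(shared_scorer(new_example)).squeeze(-1)
```

Informally, because the shared scorer has to work for every example, it has less freedom to exploit the idiosyncrasies of a single example's output than per-example logits do, which is the intuition behind amortization reducing hindsight bias.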
Contributions
In erasure search, optimization happens individually for each example. This can result in a form of overfitting where even non-superfluous edges are aggressively pruned. The authors address this issue by amortizing parameter learning over a training dataset. This strategy avoids hindsight bias as well as the readout bottleneck introduced in [7].
The contributions of this paper are as follows:
- The paper presents a novel interpretation method for GNNs that can potentially be applied to any end-to-end neural model that has a GNN as a component.
- The authors use synthetic data to demonstrate the drawbacks of the closest existing method and show how their method addresses those drawbacks, producing explanations that are more faithful to the model.
- The paper describes the use of Graph Mask to analyze GNN models for semantic role labeling and multi-hop question answering.
Analyze a Data Point using Graph Mask
- Train the erasure function g.
- Execute the original model over the data point to obtain h_{u}^{(k)}, h_{v}^{(k)}, and m_{u,v}^{(k)}.
- Compute gates for every edge at every layer and execute a sparsified version of the model.
- For the first layer, the messages of the original model are gated according to \widetilde{m}_{u,v}^{(k)}.
- For subsequent layers, masked messages are aggregated to obtain vertex embeddings {h}_{v}^{'(k)} which are further used to obtain the next set of masked messages.
- The parameters of the original model are kept constant.
- π and the baseline vectors $b^{(1)}, \ldots, b^{(L)}$ are the only learned parameters of Graph Mask.
- Messages that are masked out can then be interpreted as superfluous.
[9]
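The steps above can be sketched as a masked forward pass. This is a minimal sketch under several assumptions: a hypothetical toy GNN layer whose messages are $W h_u$ and which adds the incoming messages to the target vertex, one single-linear-layer gate network and one baseline per layer, and deterministic sigmoid gates in place of the stochastic ones. All class and function names are hypothetical. The original model is frozen, so only the gate networks (π) and the baselines can receive gradients.

```python
import torch
from torch import nn


class ToyGNNLayer(nn.Module):
    """Hypothetical layer: m_uv = W h_u, and h_v <- h_v + sum of incoming messages."""

    def __init__(self, dim: int):
        super().__init__()
        self.W = nn.Linear(dim, dim, bias=False)

    def messages(self, h, edge_index):
        return self.W(h[edge_index[0]])                  # one message per edge: [num_edges, dim]

    def aggregate(self, h, msgs, edge_index):
        return h.index_add(0, edge_index[1], msgs)       # add each message to its target vertex


def masked_forward(layers, gate_nets, baselines, h0, edge_index):
    src, dst = edge_index
    # Step 1: run the original model and record h^(k) and m^(k) for every layer.
    hs, ms, h = [], [], h0
    with torch.no_grad():
        for layer in layers:
            m = layer.messages(h, edge_index)
            hs.append(h)
            ms.append(m)
            h = layer.aggregate(h, m, edge_index)
    # Step 2: predict a gate for every edge at every layer from (h_u, h_v, m_uv).
    gates = [torch.sigmoid(g(torch.cat([hk[src], hk[dst], mk], dim=-1))).squeeze(-1)
             for g, hk, mk in zip(gate_nets, hs, ms)]
    # Step 3: sparsified model. Messages are recomputed from the masked embeddings h'
    # (at the first layer they coincide with the original messages) and then gated.
    h = h0
    for layer, z, b in zip(layers, gates, baselines):
        m = layer.messages(h, edge_index)
        z = z.unsqueeze(-1)
        h = layer.aggregate(h, z * m + (1 - z) * b, edge_index)
    return h, gates


torch.manual_seed(0)
dim, num_layers = 4, 2
layers = nn.ModuleList([ToyGNNLayer(dim) for _ in range(num_layers)])
layers.requires_grad_(False)                                   # the original model stays frozen
gate_nets = nn.ModuleList([nn.Linear(3 * dim, 1) for _ in range(num_layers)])   # pi
baselines = nn.ParameterList([nn.Parameter(torch.zeros(dim)) for _ in range(num_layers)])
h0 = torch.randn(3, dim)
edge_index = torch.tensor([[0, 1], [2, 2]])                    # edges v0 -> v2 and v1 -> v2
h_masked, gates = masked_forward(layers, gate_nets, baselines, h0, edge_index)
```

Training would compare the output computed from `h_masked` with the original output and update only `gate_nets` and `baselines`.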
Proposed Objective Function
GNN: f
No. of layers: L
Graph: G
Input Embeddings: $\chi$
Informative sub-graphs: $G_{S} = \{G_{S}^{(1)}, \ldots, G_{S}^{(L)}\}$
Assumption: $G_{S}^{(k)} \subseteq G \;\; \forall k \in \{1, \ldots, L\}$
Constraint: $f(G_{S}, \chi) \approx f(G, \chi)$
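$f(G_{S}, \chi)$ denotes a forward pass where the graph may differ from layer to layer ($G_{S}^{(k)}$ at layer k).
$f(G, \chi)$ denotes a forward pass where the graph G is the same across all layers.
Since f is differentiable everywhere (and hence continuous), removing edges will in general change its output at least slightly, so exact equality between $f(G_{S}, \chi)$ and $f(G, \chi)$ cannot be guaranteed. The constraint is therefore relaxed: the divergence between the two outputs must stay below a tolerance β.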
The proposed objective function is given by
$$\max_{\lambda} \, \min_{\pi, b} \; \sum_{k=1}^{L} \sum_{(u,v) \in \varepsilon} \mathbf{1}_{[\mathbb{R} \neq 0]}\left(z_{u,v}^{(k)}\right) + \lambda \left( D_{*}\left[ f(G, \chi) \,\|\, f(G_{S}, \chi) \right] - \beta \right)$$
where $\varepsilon$ is the set of edges of G, $\mathbf{1}$ is the indicator function, $\lambda \in \mathbb{R}_{\geq 0}$ is the Lagrange multiplier, $D_{*}$ is a divergence, and $\beta \in \mathbb{R}_{> 0}$ is the tolerance level.
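The max-min structure can be optimized by alternating gradient steps: descent on the gate parameters and projected ascent on λ, whose gradient is simply $D_* - \beta$. A minimal sketch under stated assumptions: a toy "model" whose output logits are a gated sum of per-edge contributions, free gate logits standing in for (π, b), the sum of sigmoid gate values as a differentiable stand-in for the count of non-zero gates, and KL divergence as $D_*$. All names and numbers are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Toy "model": output logits are a gated sum of per-edge contributions.
contrib = torch.tensor([[2.0, -2.0], [0.0, 0.0], [0.0, 0.0]])   # only edge 0 matters
full_logp = F.log_softmax(contrib.sum(0), dim=-1)               # f(G, chi), all gates open
gate_logits = torch.zeros(3, requires_grad=True)                 # stand-in for (pi, b)
opt = torch.optim.SGD([gate_logits], lr=0.5)
lam, lr_lam, beta = 0.0, 1.0, 0.01

for step in range(300):
    z = torch.sigmoid(gate_logits)                               # relaxed gates in (0, 1)
    masked_logp = F.log_softmax((z.unsqueeze(-1) * contrib).sum(0), dim=-1)   # f(G_S, chi)
    # KL(f(G) || f(G_S)) plays the role of the divergence D_*
    kl = F.kl_div(masked_logp, full_logp, log_target=True, reduction="sum")
    l0_surrogate = z.sum()                                       # differentiable count of open gates
    loss = l0_surrogate + lam * (kl - beta)
    opt.zero_grad()
    loss.backward()
    opt.step()                                                   # min over the gate parameters
    lam = max(0.0, lam + lr_lam * (kl.item() - beta))            # projected max over lambda >= 0

probs = torch.sigmoid(gate_logits).detach()
```

The irrelevant edges 1 and 2 are only pushed by the sparsity term, so their gates close, while λ grows until keeping edge 0 open is worth its cost.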
Limitations of the Objective Function
- The L0 norm is discontinuous and has zero derivatives almost everywhere.
- Outputting a binary value requires a discontinuous activation.
- As a consequence, the objective function is not differentiable, and gradient-based optimization cannot be used directly.
Mitigation steps
- Use of a mixed discrete-continuous distribution on the closed interval [0,1] which assigns a non-zero probability to exact zeros and admits continuous outcomes in the unit interval.
- Reparameterization trick to obtain unbiased, low-variance gradient estimates.
In the end, Graph Mask converges to a distribution where the scores, in expectation, take near-binary values.
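Such a mixed discrete-continuous distribution is the hard concrete distribution mentioned in the architecture list. A minimal sketch of sampling it with the reparameterization trick and of its closed-form probability of being non-zero; the stretch interval (−0.1, 1.1) and temperature 2/3 are common defaults from the L0-regularization literature and are assumed here, not taken from the paper.

```python
import math

import torch

# Hard concrete hyperparameters: stretch interval (GAMMA, ZETA) and temperature BETA (assumed).
GAMMA, ZETA, BETA = -0.1, 1.1, 2.0 / 3.0


def sample_hard_concrete(log_alpha: torch.Tensor) -> torch.Tensor:
    """Reparameterized gate sample in [0, 1] with non-zero probability of exact 0 and 1."""
    u = torch.rand_like(log_alpha).clamp(1e-6, 1.0 - 1e-6)      # parameter-free noise
    s = torch.sigmoid((torch.log(u) - torch.log(1.0 - u) + log_alpha) / BETA)
    s = s * (ZETA - GAMMA) + GAMMA                               # stretch (0, 1) to (GAMMA, ZETA)
    return s.clamp(0.0, 1.0)                                     # values outside [0, 1] become exact 0 or 1


def expected_l0(log_alpha: torch.Tensor) -> torch.Tensor:
    """P(z != 0) per gate: a differentiable stand-in for the indicator in the L0 term."""
    return torch.sigmoid(log_alpha - BETA * math.log(-GAMMA / ZETA))


torch.manual_seed(0)
log_alpha = torch.zeros(10_000, requires_grad=True)   # location parameters, e.g. predicted by g
z = sample_hard_concrete(log_alpha)
frac_zero = (z == 0).float().mean().item()
frac_one = (z == 1).float().mean().item()
(z.mean() + expected_l0(log_alpha).mean()).backward()  # gradients reach log_alpha through both terms
```

With a location parameter of 0, roughly a sixth of the samples land exactly on 0 and another sixth exactly on 1, while the rest are continuous; the noise is drawn independently of the parameters, which is what makes the gradient with respect to `log_alpha` well defined.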
Experiments
Synthetic Data
- Experimental Setup:
- Star graph G with a single centroid vertex v_{0}
- Leaf vertices: v_{1},...,v_{n}
- Edges: (v_{1}, v_{0}),....,(v_{n}, v_{0})
- Every edge (u, v) is assigned a color c_{u,v} \in C
- Query: <x,y> \in C \times C
- Task: predict whether the number of edges colored x exceeds the number of edges colored y
- Examples are randomly generated with 6 to 12 leaves
- A one-layer R-GCN [20] is applied
- Trained model perfectly classifies every example
- Gold standard for faithfulness: for a query <x, y>, all edges of color x or y should be retained, and all others should be discarded
- Graph Mask is compared to four baselines.
- Erasure search [10], integrated gradients [11], information bottleneck [7], and GNNExplainer [6]
- Integrated gradients and information bottleneck approaches were adapted for this experiment as they were not designed for graphs.
- GNNExplainer and Information Bottleneck do not make hard (binary) predictions, so their soft scores were turned into gates with a threshold.
- For Integrated Gradients, the attributions were normalized to the interval [-1, 1], and their absolute values were thresholded. The threshold t \in \{0.1, \ldots, 0.9\} was chosen to maximize the F1 score on validation data.
- Observations:
- Erasure search, GNNExplainer, and non-amortized Graph Mask preserve the model's predictions perfectly but are not faithful to the original model's behaviour
- With integrated gradients, the scalar attribution scores vary greatly across examples, so no single threshold works well.
- This type of overfitting to the objective was solved by amortization.
- Results (figures): configuration, training, testing, and an example.
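The synthetic task is simple enough to generate directly. A minimal sketch in plain Python, assuming a hypothetical set of four colors, uniformly random edge colors, and two distinct colors for the query; each example also carries the gold-standard mask from the setup above (keep exactly the edges colored x or y).

```python
import random
from typing import Any, Dict

COLORS = ["red", "green", "blue", "yellow"]   # hypothetical color set C


def make_example(rng: random.Random) -> Dict[str, Any]:
    n = rng.randint(6, 12)                                  # 6 to 12 leaves, inclusive
    edges = [(i, 0) for i in range(1, n + 1)]               # leaf v_i -> centroid v_0
    colors = [rng.choice(COLORS) for _ in edges]            # c_{u,v} for every edge
    x, y = rng.sample(COLORS, 2)                            # query <x, y>
    label = colors.count(x) > colors.count(y)               # more x-edges than y-edges?
    gold_keep = [c in (x, y) for c in colors]               # faithful explanation keeps only x/y edges
    return {"edges": edges, "colors": colors, "query": (x, y),
            "label": label, "gold_keep": gold_keep}


rng = random.Random(0)
data = [make_example(rng) for _ in range(1000)]
```

Comparing a method's kept edges with `gold_keep` gives the per-edge precision, recall, and F1 used to judge faithfulness on this task.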
Question Answering
- Dataset: WikiHop Dataset [12] and preprocessing script from [13]
- Experimental Setup: Amortized Graph Mask is used to analyze the predictions of a real model
- The task is complex, and there is no human gold standard for attribution
- Given a query sentence and a set of context documents
- Find the entity within the context which best answers the query
- Nodes: mentions of entities within the query and context
- Edges: 4 types
- String match (MATCH)
- Document-level co-occurrence (DOC-BASED)
- Coreference resolution (COREF)
- Absence of any other edge (COMPLEMENT)
- Model:
- 2 layer BiLSTM reading the query, 3 layers of R-GCN with shared parameters
- Node representations at the bottom layer: the query representation is concatenated to the embedding of each mention
- A GloVe-based model is used
- Max-Pooling: combine mention representations to entity representations
- Observations:
- With amortization, Graph Mask retains 27% of the edges, and the majority of the retained edges are in the bottom layer
- Without amortization, of the retained edges, 91% are from the top layer and only 0.4% from the bottom layer
- The proportion of edges that lie on paths from mentions of the query increases drastically by layer: from 11.8% at layer 0, to 42.7% at layer 1, and up to 73.8% at the top layer.
- The GNN used is responsible for two things here:
- propagating evidence to the predicted answer through the graph
- propagating evidence to alternate candidates
- The majority of paths were seen to take one of two forms:
- COMPLEMENT edge followed by either a MATCH or a DOC-BASED edge (22%)
- COMPLEMENT edge followed by two MATCH or DOC-BASED edges (52%)
- MATCH or DOC-BASED edges in the bottom layer tend to represent one-hop paths rather than being the first edge on a long path.
- An observation about the sub-graphs retained by Graph Mask is that an edge and its inverse are both judged to be either superfluous or non-superfluous (separately in each layer). This kind of 'undirected' exchange between mentions, which results in enriched representations, is an important aspect to consider when applying this approach to other datasets.
- Results:
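The four edge types can be illustrated with a small graph-construction sketch in plain Python. It is not the preprocessing script from [13]; it assumes each mention is already annotated with a document id, its surface string, and an optional coreference cluster id (hypothetical field names), treats MATCH as an exact match after lowercasing (an assumption), and adds every edge together with its inverse.

```python
from itertools import combinations
from typing import Any, Dict, List, Tuple

Mention = Dict[str, Any]   # keys: "doc", "text", "coref" (hypothetical schema)


def build_edges(mentions: List[Mention]) -> List[Tuple[int, int, str]]:
    edges = []
    for i, j in combinations(range(len(mentions)), 2):
        a, b = mentions[i], mentions[j]
        types = []
        if a["text"].lower() == b["text"].lower():
            types.append("MATCH")
        if a["doc"] == b["doc"]:
            types.append("DOC-BASED")
        if a["coref"] is not None and a["coref"] == b["coref"]:
            types.append("COREF")
        if not types:
            types.append("COMPLEMENT")          # no other relation between the two mentions
        for t in types:
            edges += [(i, j, t), (j, i, t)]     # each edge together with its inverse
    return edges


mentions = [
    {"doc": 0, "text": "Paris", "coref": None},
    {"doc": 1, "text": "paris", "coref": 7},
    {"doc": 1, "text": "the city", "coref": 7},
    {"doc": 2, "text": "France", "coref": None},
]
edges = build_edges(mentions)
```

In this toy input, mentions 1 and 2 are linked by both a DOC-BASED and a COREF edge, and every pair without any relation is connected by COMPLEMENT.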
Semantic Role Labeling
- Dataset: English CoNLL-2009 shared task dataset [14]
- Experimental Setup: Identify the arguments of a given predicate and assign them to semantic roles (e.g., agent or patient)
- Labeling: assigns labels to words or phrases – indicate their semantic role in the sentence
- Semantic Role: underlying relationship that a subject has with the main verb in a clause
- Models:
- BiLSTM + GNN:
- Edges often directly connect the predicate to the predicted roles
- Or edges connect predictions to tokens close to the predicate, easily reachable via the LSTM
- Limitations:
- The LSTM struggles to propagate information when the roles to be predicted are far away from the predicate
- Reliance on paths decreases as distance to the predicate increases (only for nominal predicates)
- Mitigation steps:
- Use a GNN-only model
- GNN-only model:
- Use paths in the graph
- Relying on the entire path or partially relying on the last several edges in the path
- Observations:
- Reliance on paths increases as the distance to the predicate increases
- Longer paths are useful in both the models
- At a lower rate for nominal predicates in the BiLSTM + GNN model
- Results:
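The question of whether a prediction relies on a path from the predicate can be made concrete with a reachability check over the retained edges. A minimal sketch in plain Python, assuming a hypothetical representation where `kept[k]` is the set of directed edges (u, v) retained at layer k, and assuming each vertex always keeps its own information from one layer to the next: information from a source can reach v after layer k if v already had it, or if a retained layer-k edge (u, v) starts at a vertex that had it.

```python
from typing import Hashable, List, Set, Tuple


def reachable_after_layers(source: Hashable,
                           kept: List[Set[Tuple[Hashable, Hashable]]]) -> Set[Hashable]:
    """Vertices that can receive information from `source` through retained edges.

    A layer-k edge can only be used at layer k, so a path must use its edges
    in increasing layer order, one hop per layer.
    """
    reached = {source}
    for layer_edges in kept:                                   # layers in forward order
        reached = reached | {v for (u, v) in layer_edges if u in reached}
    return reached


# Usage: a chain p -> a -> b -> c retained across three layers.
chain = [{("p", "a")}, {("a", "b")}, {("b", "c")}]
forward = reachable_after_layers("p", chain)
backward = reachable_after_layers("p", chain[::-1])
```

The same retained edges support a path from `p` to `c` only when they appear in the right layer order, which is why path analyses under Graph Mask have to look at each layer separately.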
Conclusion
Graph Mask, a novel post-hoc interpretation method applicable to any GNN model (domain-agnostic), has been introduced.
By learning end-to-end differentiable hard gates for every message and amortizing over the training data, it was shown that Graph Mask is:
- Faithful to the studied model
- Scalable to modern GNN models
- Capable of identifying both how edges and paths influence predictions.
Graph Mask was applied to analyze the predictions of two NLP models, outperforming previous work in this domain, addressing its limitations, and enabling a thorough analysis:
- Question Answering model
- Semantic Role Labeling model
Graph Mask was used to answer these two important questions:
- Which edge types do these models rely on?
- How do they employ paths while making predictions?
My Review of The Paper
Personally, I found the paper to be a bold attempt at fixing hindsight bias, and after looking at the results of the experiments, I am of the opinion that the authors have succeeded. The difference in metrics such as the F1 score between amortized Graph Mask and GNNExplainer on the same task is, in my view, a testament to the robustness and scalability of the proposed method. I was particularly impressed with the results obtained in the question answering and semantic role labeling experiments, as such models are often quite difficult to interpret. One result that caught my attention was that in the sub-graphs retained by Graph Mask, an edge and its inverse are both judged to be either superfluous or non-superfluous. I am not sure whether the same holds for non-textual datasets. The authors do not go into much detail about this claim, and it would have been great if they had discussed whether this aspect of their results is indeed domain-agnostic, as they claim the method to be. I would also have found it helpful to have more details about how amortization is implemented in Graph Mask. There are works similar to this one, and it would have been great if the authors had drawn comparisons with such related work and explained why Graph Mask outperforms previous related works (other than the ones mentioned in the paper). Other than that, I found the topics to be coherent throughout the paper. It is a well-structured and well-written paper with references to a lot of related work. The erasure function architecture and the process of analyzing a data point with Graph Mask through a new objective function were certainly the highlights of the paper. The authors have made the source code available on GitHub [16], which was helpful for understanding their approach better.
Since this method can be used with any network that contains a GNN, its applications could also be explored in the field of medical image analysis. It was a pleasure to read this paper.
Future Work
Since this novel interpretation method can be applied to any network with a GNN, I plan to pursue this topic further and apply it to medical applications involving brain MRI images and EEG signals. In particular, I would like to create a graph of the brain to study use cases in neurodegenerative disorders such as Multiple Sclerosis, Parkinson's disease, and Alzheimer's disease. Applications in psychiatry could also be explored: comparatively little is known about the brain, and deep learning combined with the required interdisciplinary knowledge of medicine can extend the frontiers of research in this field.
I plan to explore the following data sets for my projects.
https://www.humanconnectome.org/study/hcp-young-adult/data-releases
https://www.humanconnectome.org/study/hcp-young-adult/document/extensively-processed-fmri-data-documentation
https://www.ukbiobank.ac.uk/
http://www.miccai.org/about-miccai/student-board/educational-challenge/
Software and Hardware Requirements
- Python 3.7+ [18]
- PyTorch Geometric [17]
- PyTorch [19]
- Single Titan-X GPU or any other GPU
Acknowledgement
I want to thank Dr. Anees Kazi for guiding me in the right direction for the seminar.
References
[1] Molnar, Christoph. "Interpretable machine learning. A Guide for Making Black Box Models Explainable", 2019. https://christophm.github.io/interpretable-ml-book/.
[2] Leonard Berlin. Statement of Leonard Berlin, M.D., to the U.S. Senate Committee on Health, Education, Labor, and Pensions: Mammography Quality Standards Act Reauthorization. 2003.
[3] Miller, Tim. "Explanation in artificial intelligence: Insights from the social sciences." arXiv Preprint arXiv:1706.07269. (2017)
[4] Michael Sejr Schlichtkrull, Nicola De Cao, & Ivan Titov (2021). Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking. In International Conference on Learning Representations.
[5] Shaudi Mahdavi, "Hindsight Bias Impedes Learning." Proceedings of Machine Learning Research 58 (2016) 111-127
[6] Rex Ying, Dylan Bourgeois, Jiaxuan You, Marinka Zitnik, and Jure Leskovec. "GNNExplainer: Generating Explanations for Graph Neural Networks." 33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.
[7] Karl Schulz, Leon Sixt, Federico Tombari, and Tim Landgraf (2020). "Restricting the Flow: Information Bottlenecks for Attribution." In International Conference on Learning Representations.
[8] Madarász, Kristóf. "Information Projection: Model and Applications." The Review of Economic Studies, vol. 79, no. 3, 2012, pp. 961–985. JSTOR, www.jstor.org/stable/23261376.
[9] https://ai.googleblog.com/2020/03/more-efficient-nlp-model-pre-training.html
[10] Jiwei Li, Will Monroe, and Dan Jurafsky. Understanding neural networks through representation erasure. arXiv preprint arXiv:1612.08220, 2016
[11] Mukund Sundararajan, Ankur Taly, and Qiqi Yan. Axiomatic attribution for deep networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 3319–3328. JMLR. org, 2017.
[12] Johannes Welbl, Pontus Stenetorp, and Sebastian Riedel. Constructing datasets for multi-hop reading comprehension across documents. Transactions of the Association for Computational Linguistics, 6:287–302, 2018
[13] Nicola De Cao, Wilker Aziz, and Ivan Titov (2019). Question Answering by Reasoning Across Documents with Graph Convolutional Networks. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers) (pp. 2306–2317). Association for Computational Linguistics.
[14] Jan Hajič et al. (2009). The CoNLL-2009 Shared Task: Syntactic and Semantic Dependencies in Multiple Languages. In Proceedings of the Thirteenth Conference on Computational Natural Language Learning (CoNLL 2009): Shared Task (pp. 1–18). Association for Computational Linguistics.
[15] https://www.outsystems.com/blog/posts/graph-neural-networks/
[16] https://github.com/MichSchli/GraphMask
[17] https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html
[18] https://www.python.org/downloads/
[20] Michael Schlichtkrull, Thomas N. Kipf, Peter Bloem, Rianne van den Berg, Ivan Titov, & Max Welling. (2017). Modeling Relational Data with Graph Convolutional Networks.
Personal Profile
Yadunandan Kini - Google Scholar Profile
Yadunandan Kini - Linkedin Profile