In a Nutshell
In this paper, interpreting Graph Neural Networks (GNNs) means identifying which edges in the graph a GNN relies on, and at which layer they are used. Given a trained GNN, we therefore learn a classifier for every edge in every layer that predicts whether that edge can be dropped. To train the classifier in a fully differentiable fashion, we employ stochastic gates and an L_0 norm penalty to encourage sparsity. The basic idea is shown in the following figure.
Figure 1: Given a trained GNN, for an edge (u, v) in the k-th layer, we concatenate the representations of the head entity, the tail entity, and the relation, and pass them to a classifier that decides whether the edge should be dropped or retained. Repeating this procedure for every edge in every layer yields a sparsified version of the original GNN that can still achieve the same performance. The remaining edges then explain how the original GNN arrives at its predictions.
Introduction
Motivation
In NLP tasks, graphs can be used to represent syntactic and semantic trees, co-reference structures, knowledge bases, etc. Graph Neural Networks (GNNs) have shown great power in handling graph-structured data thanks to their ability to incorporate relational information, and they have been applied to a range of NLP tasks, including relation extraction, question answering, syntactic and semantic parsing, summarization, and machine translation.
For such complex models, it is difficult to understand the reasoning behind a prediction. In NLP tasks, however, we want to know which linguistic information contributes most to the prediction. Only when we have insight into the mechanism, and the explanation seems plausible to us, are we willing to use and trust products built on GNN models. In addition, such information makes it easier to focus on the relevant part of the model when analysing errors and improving it. The paper therefore proposes a post-hoc method to analyse trained GNNs by inspecting which edges at each layer are used to make predictions.
In addition, the authors formulate three criteria that an interpretation method should satisfy. It should be:
- able to identify relevant paths, since paths reveal the reasoning behind the predictions for individual data instances
- computationally efficient, otherwise it cannot be applied to large datasets
- as faithful as possible, because we want a true explanation; the hindsight bias example below illustrates what an unfaithful interpretation looks like
Contribution
- present a novel interpretation method for GNNs, potentially applicable to any end-to-end neural model that has a GNN as a component
- demonstrate the shortcomings of existing methods on synthetic data and show how GraphMask addresses them and improves faithfulness
- apply GraphMask to analyse GNN models for two NLP tasks: multi-hop question answering and semantic role labeling
Some Definitions
- hindsight bias
Suppose we want to predict whether the center node in Figure 2 has more red edges than blue edges. All red and blue edges should then be retained, while the green edge is irrelevant. However, if we keep only the red edges, or even just a single red edge, the prediction stays the same as before. Clearly, this sub-graph cannot be regarded as an explanation of the original model. We call this phenomenon hindsight bias; it is a kind of overfitting, typically caused by "knowing the future", i.e. the explanation method has access to the prediction it is supposed to explain. The algorithm can find an alternative, smaller sub-graph that preserves the original prediction, but this spuriously optimal sub-graph does not faithfully show which edges actually contribute.
Figure 2: Hindsight bias
Methodology
Graph Neural Network
In this part we first review the basic idea of GNNs. For a given graph \mathcal{G} = \langle \mathcal{V}, \mathcal{E} \rangle, a GNN computes a representation h_u^{(k)} for each node u \in \mathcal{V} at the k-th layer, using information from its neighboring nodes and itself. The two key components are a message function M, which computes a message along each edge from the node states and the relation r_{u,v}, and an aggregation function A, which gathers the incoming messages of a node to update its representation.
(1) m_{u,v}^{(k)} = M^{(k)} \left( h_u^{(k-1)}, h_v^{(k-1)}, r_{u, v} \right)
(2) h_v^{(k)} = A^{(k)} \left( \left\{ m_{u, v}^{(k)} : u \in \mathcal{N} (v) \right\} \right)
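To make eqs. (1) and (2) concrete, here is a minimal Python sketch of one message-passing layer. It assumes scalar node states, a hypothetical message function that scales the sender's state by a relation-specific weight, and a plain sum as the aggregation function A; real GNNs use learned, vector-valued functions.

```python
from typing import Dict, List, Tuple

# An edge is (u, v, relation): a message flows from u to v.
Edge = Tuple[int, int, str]


def message(h_u: float, h_v: float, r: str, rel_weight: Dict[str, float]) -> float:
    """Hypothetical M^{(k)} (eq. 1): relation-specific scaling of the sender
    state plus a small term from the receiver state."""
    return rel_weight[r] * h_u + 0.1 * h_v


def gnn_layer(h: List[float], edges: List[Edge],
              rel_weight: Dict[str, float]) -> List[float]:
    """One GNN layer: compute m_{u,v} for every edge (eq. 1), then let
    A^{(k)} (eq. 2) be a plain sum over each node's incoming messages.
    Add a self-loop edge (v, v, r) if a node should also see its own state."""
    new_h = [0.0] * len(h)
    for u, v, r in edges:
        new_h[v] += message(h[u], h[v], r, rel_weight)
    return new_h


h1 = gnn_layer([1.0, 2.0, 3.0],
               [(0, 2, "a"), (1, 2, "b"), (2, 0, "a")],
               {"a": 1.0, "b": 0.5})
print(h1)  # node 1 has no incoming edge, so its new state is 0.0
```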
GraphMask
Derivation process
We would like to assign a binary gate z_{u, v}^{(k)} to every edge (u, v) at layer k that indicates whether the edge is superfluous (z_{u, v}^{(k)} = 0) or not. The retained edges form the optimal subset \mathcal{G}_S^* of the original graph \mathcal{G}. However, if we change the graph structure by dropping edges (which lowers the average node degree), the prediction can be affected substantially. Instead of dropping edges completely, we therefore replace the message of a dropped edge with a learned baseline b^{(k)}:
(3) \widetilde{m}_{u, v}^{(k)} = z_{u, v}^{(k)} \cdot m_{u, v}^{(k)} + b^{(k)} \cdot (1 - z_{u, v}^{(k)})
To find the optimal subset \mathcal{G}_S^*, enumerating all possible subsets is infeasible due to the combinatorial explosion. Instead, we compute z_{u, v}^{(k)} with a simple function g_{\pi}, called the erasure function, which is learned once per task across all data points. \pi denotes the parameters of g, which is implemented as a single-layer neural network.
(4) z_{u, v}^{(k)} = g_{\pi}\left( h_u^{(k)}, h_v^{(k)}, m_{u,v}^{(k)} \right)
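The gate and the baseline substitution in eqs. (3) and (4) can be sketched as follows. This is a simplified illustration, not the authors' implementation: it assumes a single linear layer for g_{\pi} and a hard 0.5 threshold on its sigmoid output to produce a binary gate, whereas during training GraphMask samples the gate from the Hard Concrete distribution described in the Sparse Relaxation part.

```python
from typing import List


def erasure_gate(h_u: List[float], h_v: List[float], m_uv: List[float],
                 weights: List[float], bias: float) -> int:
    """Hypothetical single-layer g_pi (eq. 4): a linear score over the
    concatenation [h_u; h_v; m_uv]. sigmoid(score) >= 0.5 exactly when
    score >= 0, so thresholding the score gives the binary gate z in {0, 1}."""
    features = h_u + h_v + m_uv
    score = sum(w * x for w, x in zip(weights, features)) + bias
    return 1 if score >= 0.0 else 0


def masked_message(z: int, m_uv: List[float], baseline: List[float]) -> List[float]:
    """Eq. (3): keep the message when z = 1, otherwise substitute the
    learned baseline b^{(k)} so the receiving node still gets a message."""
    return [z * m + (1 - z) * b for m, b in zip(m_uv, baseline)]


z = erasure_gate([1.0], [0.0], [2.0], weights=[0.5, 0.5, -1.0], bias=0.0)
print(z, masked_message(z, [2.0], [0.3]))  # 0 [0.3]
```

Replacing a dropped message by b^{(k)} rather than by nothing keeps the number of incoming messages per node unchanged, which is the motivation given above.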
How to avoid hindsight bias
- amortization: the parameters \pi are trained on many data points, instead of optimizing the gate values z_{u, v}^{(k)} separately for each prediction, which can easily lead to overfitting.
- no look-ahead: each z_{u, v}^{(k)} is computed only from information in the original model, i.e. node representations and messages, without knowing the corresponding prediction, so the method is less likely to "cheat" by searching for any subgraph that reproduces the answer.
Note: the alternative to amortization is to choose \pi independently for each gate, without any parameter sharing across gates. We call this strategy the non-amortized version of GraphMask.
Inference process
Now we want to analyse a data point x with the trained erasure function g:
- execute the original model on x to get the node representations h_u^{(k)}, h_v^{(k)} and messages m_{u,v}^{(k)};
- compute the gates for every edge at every layer, from the first layer to the top layer;
- let each upper layer aggregate the masked messages coming from the layer below;
- execute this sparsified version of the model and obtain the new prediction y';
- if y' \approx y, the masked messages can be interpreted as superfluous.
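The inference steps above can be traced end to end on a toy model. In this sketch the GNN (message = sender state, update = own state plus summed messages), the zero baseline, and the hand-written gate rule standing in for a trained g are all assumptions for illustration; the gates for layer k use h^{(k)} and m^{(k)} from the original run, as in eq. (4).

```python
from typing import Callable, List, Optional, Tuple

Edge = Tuple[int, int]  # message flows from u to v


def run(h0: List[float], edges: List[Edge], num_layers: int,
        gates: Optional[List[List[int]]] = None, baseline: float = 0.0):
    """Toy GNN: message = sender state, update = own state + sum of messages.
    If gates are given, a message with gate 0 is replaced by the baseline (eq. 3).
    Returns the states [h^(0), ..., h^(L)] and the messages of each layer."""
    states, messages = [list(h0)], []
    for k in range(num_layers):
        h = states[-1]
        msgs = [h[u] for u, _ in edges]
        if gates is not None:
            msgs = [z * m + (1 - z) * baseline for z, m in zip(gates[k], msgs)]
        new_h = list(h)
        for (_, v), m in zip(edges, msgs):
            new_h[v] += m
        messages.append(msgs)
        states.append(new_h)
    return states, messages


def explain(h0, edges, num_layers, gate_fn: Callable, readout: Callable, tol=1e-6):
    # Step 1: run the original model and record states and messages.
    states, messages = run(h0, edges, num_layers)
    y = readout(states[-1])
    # Step 2: gates for every edge at every layer, computed only from the
    # original model's states and messages (no look-ahead at y).
    gates = [[gate_fn(states[k + 1][u], states[k + 1][v], messages[k][i])
              for i, (u, v) in enumerate(edges)]
             for k in range(num_layers)]
    # Steps 3-4: run the sparsified model, layer by layer, on masked messages.
    masked_states, _ = run(h0, edges, num_layers, gates=gates)
    y_prime = readout(masked_states[-1])
    # Step 5: if y' is close to y, the masked messages are deemed superfluous.
    return gates, y, y_prime, abs(y - y_prime) <= tol


gates, y, y_prime, faithful = explain(
    [1.0, 0.0, 0.0], [(0, 1), (1, 2), (2, 0)], num_layers=2,
    gate_fn=lambda h_u, h_v, m: 1 if abs(m) > 0 else 0,  # stand-in for a trained g
    readout=lambda h: h[2])
print(gates, y, y_prime, faithful)  # [[1, 0, 0], [1, 1, 0]] 1.0 1.0 True
```

In this toy run, three of the six edge-layer pairs are masked and the prediction is unchanged, so those messages would be read as superfluous.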
Training (Parameters Estimation)
Mathematical task definition
- given a GNN f with L layers, a graph \mathcal{G}, and input embeddings \mathcal{X}, the task is to identify a set \mathcal{G}_S = \{ \mathcal{G}_S^{(1)}, \dots, \mathcal{G}_S^{(L)} \mid \mathcal{G}_S^{(k)} \subseteq \mathcal{G} \quad \forall k \in 1, \dots, L\} with as few edges as possible such that f(\mathcal{G}_S, \mathcal{X}) \approx f(\mathcal{G}, \mathcal{X})
Loss function
- we formulate the problem as constrained optimization and use Lagrangian relaxation, which turns it into an objective that can be optimized with gradient descent
objective: we want to minimize the number of non-zero gates predicted by g, which can be formulated as minimizing the L_0 norm, i.e. the total number of edges that are kept (recall that \pi and b are the model parameters)
(5) \min_{\pi, b} \sum_{\mathcal{G}, \mathcal{X} \in \mathcal{D}} \left( \sum_{k=1}^L \sum_{(u,v) \in \mathcal{E}} 1_{[\mathbb{R}\neq 0]} (z_{u,v}^{(k)}) \right)
constraint: we also want to keep the difference between f(\mathcal{G}, \mathcal{X}) and f(\mathcal{G}_S, \mathcal{X}), denoted D_{*}[f(\mathcal{G}, \mathcal{X}) \| f(\mathcal{G}_{S}, \mathcal{X})], small, which can be expressed as a constraint with threshold \beta, i.e.
(6) \text{s.t.} \quad D_{*}[f(\mathcal{G}, \mathcal{X}) \| f(\mathcal{G}_{S}, \mathcal{X})] - \beta \leq 0
The appendix reports that \beta = 0.03 gives the best results.
the complete loss function is formulated as
(7) \max_{\lambda} \min_{\pi, b} \sum_{\mathcal{G}, \mathcal{X} \in \mathcal{D}} \left( \sum_{k=1}^L \sum_{(u,v) \in \mathcal{E}} 1_{[\mathbb{R}\neq 0]} (z_{u,v}^{(k)}) \right) + \lambda \left( D_{*}[f(\mathcal{G}, \mathcal{X}) \| f(\mathcal{G}_{S}, \mathcal{X})] - \beta \right)
where 1 is the indicator function and \lambda \geq 0 denotes the Lagrange multiplier.
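Eq. (7) is a saddle-point problem: \pi and b minimize it while \lambda maximizes it. A common way to handle the max over \lambda is projected gradient ascent, sketched below for a single example. The function names, the learning rate, and the use of plain ascent are assumptions; the L_0 count here is the non-differentiable version, which training replaces by the sparse relaxation described next.

```python
from typing import List


def lagrangian(gates: List[float], divergence: float, lam: float,
               beta: float = 0.03) -> float:
    """Inner objective of eq. (7) for one example: the L0 count of non-zero
    gates plus lambda times the constraint violation (D - beta)."""
    l0 = sum(1 for z in gates if z != 0)
    return l0 + lam * (divergence - beta)


def update_lambda(lam: float, divergence: float, beta: float = 0.03,
                  lr: float = 1.0) -> float:
    """One projected gradient-ascent step on lambda (the max in eq. 7).
    The derivative of the objective w.r.t. lambda is (D - beta);
    projecting onto [0, inf) keeps lambda a valid Lagrange multiplier."""
    return max(0.0, lam + lr * (divergence - beta))


print(lagrangian([1, 0, 1, 1], divergence=0.05, lam=2.0))  # 3 kept edges + 2 * 0.02
print(update_lambda(2.0, divergence=0.05))  # constraint violated: lambda grows
print(update_lambda(0.01, divergence=0.0))  # constraint satisfied: 0.0
```

The sign of D - \beta drives \lambda: it grows while the sparsified model deviates too much, and decays to zero once the constraint holds, letting the sparsity term dominate.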
Sparse Relaxation
The issue is that the L_0 norm is discontinuous, as shown in the figure, and has zero gradient almost everywhere, so the loss function cannot be optimized by gradient descent. We therefore use a sparse relaxation, specifically the Hard Concrete distribution.
The advantages of the Hard Concrete distribution:
- assigns a non-zero probability to exact zeros
- can be back-propagated via the reparameterization trick
Derivation of Hard Concrete distribution:
- a Binary Concrete distribution p_C, shown as the blue curve in (a), is parameterized by \gamma and \tau (\gamma controls whether the distribution is skewed towards 0 or 1, and the temperature \tau determines how concentrated it is)
- p_C is stretched beyond the interval [0, 1], as shown by the yellow curve p_{SC} in (a)
- p_{SC} is rectified (clipped) back to the interval [0, 1], resulting in the final Hard Concrete distribution p_{HC}
- p_{HC} is a distribution over the closed interval [0, 1] with non-zero probability mass at exactly 0 and 1
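The stretch-and-rectify construction can be written in a few lines. The sketch below follows the Hard Concrete sampler of Louizos et al. as we understand it; the stretch interval (l, r) = (-0.1, 1.1) and the temperature 2/3 are values commonly used with that distribution and are assumptions here, not settings reported in this review. It also computes P(z \neq 0) in closed form, which is what makes the L_0 term differentiable: summing it over all gates gives the expected number of retained edges.

```python
import math
import random
from typing import Optional


def sample_hard_concrete(gamma: float, tau: float, l: float = -0.1, r: float = 1.1,
                         rng: Optional[random.Random] = None) -> float:
    """Draw z ~ HardConcrete(gamma, tau) via the reparameterization trick.
    gamma: location (skews the mass towards 0 or 1); tau: temperature;
    (l, r): assumed stretch interval, with l < 0 and r > 1."""
    rng = rng if rng is not None else random.Random()
    u = rng.uniform(1e-6, 1.0 - 1e-6)  # noise; bounds avoid log(0)
    # Binary Concrete sample s in (0, 1): a deterministic, differentiable
    # function of gamma once the noise u is fixed.
    s = 1.0 / (1.0 + math.exp(-(math.log(u) - math.log(1.0 - u) + gamma) / tau))
    # Stretch: map (0, 1) to (l, r), so part of the mass falls outside [0, 1].
    s_bar = s * (r - l) + l
    # Rectify: clip to [0, 1]; the tails become point masses at exactly 0 and 1.
    return min(1.0, max(0.0, s_bar))


def prob_nonzero(gamma: float, tau: float, l: float = -0.1, r: float = 1.1) -> float:
    """P(z != 0) in closed form; summing it over gates gives a differentiable
    expected L0 count (Louizos et al.)."""
    return 1.0 / (1.0 + math.exp(-(gamma - tau * math.log(-l / r))))


rng = random.Random(0)
zs = [sample_hard_concrete(0.0, 2 / 3, rng=rng) for _ in range(10_000)]
print(sum(z == 0.0 for z in zs) / len(zs), sum(z == 1.0 for z in zs) / len(zs))
print(prob_nonzero(0.0, 2 / 3))
```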
Full version of the erasure function g:
inside g we actually perform the following steps:
pass the information along the edge (u, v) at layer k to a multi-layer perceptron to obtain \gamma_{u,v}^{(k)}:
\gamma_{u, v}^{(k)} = \text{MLP} \left( h_u^{(k)}, h_v^{(k)}, m_{u, v}^{(k)} \right)
sample z_{u, v}^{(k)} from the Hard Concrete distribution parameterized by \gamma_{u,v}^{(k)} and \tau:
z_{u, v}^{(k)} \sim \text{HardConcrete} \left( \gamma_{u, v}^{(k)}, \tau \right)
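Putting the two steps together, a sketch of the full erasure function g might look like this. The one-hidden-layer MLP, its hand-set weights, the stretch interval, and the default temperature are hypothetical choices for illustration; note that the earlier description calls g a single-layer network, so the exact depth should be checked against the paper.

```python
import math
import random
from typing import List, Optional, Sequence


def mlp_gamma(features: Sequence[float], w1: List[List[float]], b1: List[float],
              w2: List[float], b2: float) -> float:
    """Hypothetical one-hidden-layer MLP: gamma = w2 . ReLU(W1 x + b1) + b2."""
    hidden = [max(0.0, sum(w * x for w, x in zip(row, features)) + b)
              for row, b in zip(w1, b1)]
    return sum(w * h for w, h in zip(w2, hidden)) + b2


def hard_concrete(gamma: float, tau: float, rng: random.Random,
                  l: float = -0.1, r: float = 1.1) -> float:
    """Reparameterized Hard Concrete sample: sigmoid, stretch to (l, r), clip to [0, 1]."""
    u = rng.uniform(1e-6, 1.0 - 1e-6)
    s = 1.0 / (1.0 + math.exp(-(math.log(u) - math.log(1.0 - u) + gamma) / tau))
    return min(1.0, max(0.0, s * (r - l) + l))


def erasure_function(h_u: List[float], h_v: List[float], m_uv: List[float],
                     params, tau: float = 2 / 3,
                     rng: Optional[random.Random] = None) -> float:
    """g: concatenate the edge's inputs, predict gamma_{u,v}^{(k)}, sample z_{u,v}^{(k)}."""
    rng = rng if rng is not None else random.Random()
    gamma = mlp_gamma(h_u + h_v + m_uv, *params)
    return hard_concrete(gamma, tau, rng)


keep_params = ([[1.0, 0.0, 0.0]], [0.0], [20.0], 0.0)   # gamma = 20 * ReLU(h_u[0])
drop_params = ([[1.0, 0.0, 0.0]], [0.0], [-20.0], 0.0)  # gamma = -20 * ReLU(h_u[0])
rng = random.Random(0)
print(erasure_function([1.0], [0.0], [0.0], keep_params, rng=rng))  # 1.0
print(erasure_function([1.0], [0.0], [0.0], drop_params, rng=rng))  # 0.0
```

With a very large positive (or negative) \gamma, every sample is clipped to exactly 1 (or 0), which is how the relaxation still yields exact keep/drop decisions.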
Experiments
Synthetic Experiment
setting
- dataset: synthetic data, a star graph \mathcal{G}, i.e. a center node and surrounding leaf nodes, with colored edges c \in C, where C is the set of all colors
- task: for a query \langle x, y \rangle \in C \times C, determine whether the center node has more edges of color x than edges of color y; the example query is \langle \text{black}, \text{blue} \rangle
- model: one-layer R-GCN
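The synthetic task is easy to reproduce in spirit. The generator below is a hypothetical reconstruction (the number of leaves, the color set, and uniform color sampling are assumptions), but it shows why the ground-truth explanation is exactly the set of x- and y-colored edges.

```python
import random
from typing import List, Set, Tuple


def random_star(colors: List[str], num_leaves: int, rng: random.Random) -> List[str]:
    """Edge i connects leaf i to the center node and gets a random color."""
    return [rng.choice(colors) for _ in range(num_leaves)]


def label_and_gold(edge_colors: List[str], query: Tuple[str, str]) -> Tuple[bool, Set[int]]:
    """Label: does the center have more x-colored than y-colored edges?
    Gold explanation: every x- or y-colored edge, since the answer depends
    on the counts of both colors; all other edges are superfluous."""
    x, y = query
    label = edge_colors.count(x) > edge_colors.count(y)
    gold = {i for i, c in enumerate(edge_colors) if c in (x, y)}
    return label, gold


edges = random_star(["black", "blue", "green", "red"], num_leaves=8, rng=random.Random(0))
print(edges, label_and_gold(edges, ("black", "blue")))
```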
Results
- the amortized version of GraphMask gets the best result: all relevant edges are retained and all other edges are dropped;
- hindsight bias appears in erasure search, GNNExplainer, and the non-amortized version of GraphMask;
- Integrated Gradients and Information Bottleneck still keep superfluous edges;
- from the table we can see that the recall of GraphMask is 100%, which means that all necessary edges are retained.
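Precision and recall here compare the retained edges with the ground-truth relevant edges, i.e. those colored x or y. A minimal sketch, where the edge ids and the empty-set convention are assumptions, shows how hindsight bias appears in the numbers: keeping a single black edge preserves the prediction but gives low recall.

```python
from typing import Set, Tuple


def precision_recall(retained: Set[int], gold: Set[int]) -> Tuple[float, float]:
    """Precision: share of retained edges that are relevant.
    Recall: share of relevant edges that are retained.
    Empty sets are scored 1.0 by convention (an assumption)."""
    tp = len(retained & gold)
    precision = tp / len(retained) if retained else 1.0
    recall = tp / len(gold) if gold else 1.0
    return precision, recall


gold = {0, 1, 2}  # e.g. the black and blue edges for the query <black, blue>
print(precision_recall({0}, gold))           # hindsight bias: perfect precision, recall 1/3
print(precision_recall({0, 1, 2, 3}, gold))  # superfluous edge kept: (0.75, 1.0)
print(precision_recall({0, 1, 2}, gold))     # faithful: (1.0, 1.0)
```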
Figure 6: model comparison from synthetic dataset
Figure 7: precision and recall of different models in the synthetic task
Question Answering
Setting
- dataset: WikiHop
- task:
given a query sentence and a set of context documents, find the entity within the context that best answers the query.
Figure 8 shows an example from the dataset: we are given several support documents and a query asking which country Thorildsplan belongs to. We also have a set of candidate answers, which are entities mentioned in those support documents. We take the candidate with the highest score as the answer.
- model:
- edge types: DOC-BASED (solid lines, mentions co-occurring in the same document), MATCH (dashed lines, mentions matched by string matching), COREF (red, co-reference), COMPLEMENT (added to make the whole graph connected)
- nodes: mentions of entities, represented with a query-dependent mention encoding \hat{x}_i = f_x (q, x_i), with h_i^{(0)} = \hat{x}_i
- message function: u_i^{(l)} = f_s(h_i^{(l)}) + \frac{1}{|\mathcal{N}_i|} \sum_{j \in \mathcal{N}_i} \sum_{r \in \mathcal{R}_{ij}} f_r (h_j^{(l)})
- gate: a_i^{(l)} = \sigma (f_a ([u_i^{(l)}, h_i^{(l)}]))
- aggregation function: h_i^{(l+1)} = \phi(u_i^{(l)}) \odot a_i^{(l)} + h_i^{(l)} \odot (1 - a_i^{(l)})
- prediction: \max_{i \in \mathcal{M}_c} f_o([q, h_i^{(L)}])
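The gated update of the question answering model can be sketched with scalar node states. The linear forms of f_s, f_r, and f_a and the choice of tanh for \phi are assumptions made for readability; only the structure (relation-wise messages averaged over neighbors, then a gate mixing the new and old states) follows the equations above.

```python
import math
from typing import Dict, List, Set, Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def entity_gcn_layer(h: List[float], neighbors: List[List[Tuple[int, Set[str]]]],
                     w_self: float, w_rel: Dict[str, float],
                     w_gate: Tuple[float, float]) -> List[float]:
    """One layer of the gated update, with scalar states for readability.
    neighbors[i] lists (j, R_ij): neighbor j and the relation types linking i and j.
    Assumptions: f_s, f_r and f_a are scalar linear maps, phi is tanh."""
    new_h = []
    for i, h_i in enumerate(h):
        nbrs = neighbors[i]
        # message: f_s(h_i) + (1 / |N_i|) * sum_j sum_{r in R_ij} f_r(h_j)
        rel_sum = sum(w_rel[r] * h[j] for j, rels in nbrs for r in rels)
        u_i = w_self * h_i + (rel_sum / len(nbrs) if nbrs else 0.0)
        # gate: a_i = sigmoid(f_a([u_i, h_i]))
        a_i = sigmoid(w_gate[0] * u_i + w_gate[1] * h_i)
        # update: phi(u_i) * a_i + h_i * (1 - a_i)
        new_h.append(math.tanh(u_i) * a_i + h_i * (1.0 - a_i))
    return new_h


h = [1.0, 0.0]
neighbors = [[(1, {"MATCH"})], [(0, {"MATCH", "DOC-BASED"})]]
print(entity_gcn_layer(h, neighbors, w_self=1.0,
                       w_rel={"MATCH": 1.0, "DOC-BASED": 1.0}, w_gate=(0.0, 0.0)))
```

Because the gate interpolates between \phi(u_i) and h_i, a node whose gate is close to 0 simply keeps its previous state, regardless of its incoming edges.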
Figure 8: example data format in WikiHop dataset
Figure 9: graph structure in the question answering model
Results
- using only 27% of the original edges, the sparsified model behaves almost identically to the original, with a 0.4% drop in accuracy;
- all of the retained edges contribute to the final prediction: when we continue dropping them, the accuracy decreases quickly, as shown in Figure 10;
- Figure 12 shows that most of the retained edges are in the bottom layer, which suggests that these edges matter more than those in the upper layers. This can be verified by removing individual layers and checking how the performance changes; the result is shown in Figure 11. Dropping the first layer affects the model accuracy the most, which again indicates that the bottom layer plays the most important role in the GNN prediction;
- among the edge types, COREF and COMPLEMENT are the most interesting. COREF edges seem largely superfluous, since only 4.4% of them are retained, and only in the first layer. COREF and MATCH edges sometimes connect the same entities, and the result above suggests that the model relies more on MATCH edges than on COREF edges, since far more MATCH edges are retained in the sparsified model. In addition, although all COMPLEMENT edges in the second and final layers are thrown away, a much higher percentage of COMPLEMENT edges is retained in the first layer, indicating that they contribute substantially to the prediction. MATCH and DOC-BASED edges are distributed fairly evenly across layers;
- considering the percentages of the different edge types, it appears that most reasoning paths start with COMPLEMENT edges leading to other support documents, and then follow MATCH and/or DOC-BASED edges in the same document to find the correct answer.
Figure 10: accuracy when we drop each 1/4 of the retained edges
Figure 11: accuracy changes when individual layer is removed
Figure 12: percentage of retained edges in each layer
Semantic Role Labeling
Setting
- dataset: CoNLL-2009
- task:
Given a sentence, the task consists of analyzing the propositions expressed by some target verbs of the sentence. In particular, for each target verb all the constituents in the sentence which fill a semantic role of the verb have to be extracted. Typical semantic arguments include Agent, Patient, Instrument, etc. and also adjuncts such as Locative, Temporal, Manner, Cause, etc.
For example, in Figure 13 we focus on the target verb "makes". "Sequa" and "engines" are closely related to the verb, serving as its subject and object respectively, or as agent and patient in the terms of the task definition quoted above. The model uses the syntactic dependency tree, represented as a graph, as a GNN layer on top of the LSTM output to improve on the LSTM-only model.
- model
- syntactic dependency trees represented as graph
- word representations are passed through BiLSTM to get initial embeddings for GNN
- aggregation function with gates:
g_{u, v}^{(k)} = \sigma \left( h_u^{(k)} \cdot \hat{v}_{dir(u,v)}^{(k)} + \hat{b}_{L(u,v)}^{(k)} \right)
h_v^{(k+1)} = ReLU \left( \sum_{u \in \mathcal{N}(v)} g_{u, v}^{(k)} \left( W_{dir(u,v)}^{(k)} h_u^{(k)} + b_{L(u, v)}^{(k)} \right) \right)
- prediction: p(r | t_i, t_p, l) \propto \exp(W_{l, r} (t_i \circ t_p) )
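The gated syntactic GCN update can likewise be sketched with scalar states. The parameter names mirror the equations; the direction names "along" and "opposite" are hypothetical, and a self-loop could be added as a further direction if a token should see its own state.

```python
import math
from typing import Dict, List, Tuple

# (u, v, direction, label): edge from u to v with direction dir(u, v) and label L(u, v)
Edge = Tuple[int, int, str, str]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def syntactic_gcn_layer(h: List[float], edges: List[Edge],
                        v_hat: Dict[str, float], b_hat: Dict[str, float],
                        W: Dict[str, float], b: Dict[str, float]) -> List[float]:
    """Scalar sketch of the gated syntactic GCN update.
    Direction-specific parameters: v_hat, W. Label-specific: b_hat, b."""
    new_h = [0.0] * len(h)
    for u, v, direction, label in edges:
        # edge gate g_{u,v} = sigmoid(h_u * v_hat_dir + b_hat_label)
        gate = sigmoid(h[u] * v_hat[direction] + b_hat[label])
        new_h[v] += gate * (W[direction] * h[u] + b[label])
    return [max(0.0, x) for x in new_h]  # ReLU


out = syntactic_gcn_layer(
    [1.0, 2.0],
    [(0, 1, "along", "SBJ"), (1, 0, "opposite", "SBJ")],
    v_hat={"along": 0.0, "opposite": 0.0}, b_hat={"SBJ": 0.0},
    W={"along": 1.0, "opposite": -1.0}, b={"SBJ": 0.0})
print(out)  # [0.0, 0.5]
```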
Results
- we examine two models, LSTM+GNN and GNN-only. For the LSTM+GNN model, keeping only 4% of the edges yields an F1 score just 0.62% lower than before. For the GNN-only model, we sacrifice 0.79% F1 to drop 84% of all edges. As in the question answering task, continuing to drop the remaining edges gradually makes the model performance decrease dramatically, as shown in Figure 16;
- in the aggregation function of the GNN model, each edge is weighted by a gate value, which one might intuitively read as the edge's contribution. It turns out that this is not always the case. The gate values of the trained model have a mean of 0.16 and a standard deviation of 0.07. Dropping all edges with a gate value smaller than \mu - \sigma, i.e. the 42% of edges that are least important according to their gate values, lowers the F1 score by 16.1%. Yet the result above shows that we can remove even more edges without affecting model performance. This comparison reveals that edges with low gate values cannot simply be ignored, and that some edges with high gate values are actually superfluous. In other words, gate weights are not necessarily attribution scores, which shows the need for a faithful interpretation method;
- nominal and verbal predicates behave differently. Most of the retained edges going out from nominal predicates are NMOD (nominal modifier) edges, while those from verbal predicates are mainly SBJ and OBJ edges. This is not surprising, since these are the most direct syntactic relations. Even when a retained edge is not directly linked to the predicate, it connects a token close to the predicate: in Figure 14, for example, "rebound" links to "is", which is next to the predicate "expected";
- splitting the paths that connect predicates and arguments by path length and predicate type gives Figure 17. For paths that connect a predicate and an argument directly, i.e. of length 1, almost all edges are retained, as these edges indicate the semantic roles immediately;
- these observations are consistent with common sense and linguistic syntax, which supports the faithfulness of the proposed interpretation method.
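The gate-value baseline in the second bullet is a simple threshold rule. A sketch, assuming the population standard deviation is used; with the reported \mu = 0.16 and \sigma = 0.07, the cut-off would be 0.16 - 0.07 = 0.09.

```python
import statistics
from typing import List


def low_gate_edges(gate_values: List[float]) -> List[int]:
    """Indices of edges whose gate value is below mu - sigma, the heuristic
    cut-off used for the gate-value comparison (population std assumed).
    gate_values must be non-empty."""
    mu = statistics.mean(gate_values)
    sigma = statistics.pstdev(gate_values)
    return [i for i, g in enumerate(gate_values) if g < mu - sigma]


print(low_gate_edges([0.05, 0.1, 0.2, 0.29]))  # [0]
```

This rule only looks at the magnitude of each gate in isolation, which is exactly why it can discard edges the model still needs.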
Figure 16: performance when continue dropping retained edges gradually
Figure 17: percentages of retained paths connecting predicate and arguments, split by path length and predicate type
Discussion
Conclusion
This paper introduces GraphMask, a post-hoc interpretation method for trained GNN models that works by dropping superfluous edges. Sparsity is encouraged with an L_0 penalty, and training is made possible by sparse relaxation and by amortizing over the training data. An artificial dataset shows that GraphMask is more faithful than other methods. Applying it to a QA model and an SRL model reveals which edge types these models rely on and how they use paths when making predictions.
Personal Review
From my point of view, the paper follows a smooth line of reasoning, the proposed model GraphMask is elegant, and the experiments are informative, despite some drawbacks.
- The paper clearly states which problem it wants to solve and how the authors came up with the idea: why erasure search doesn't work and how they solve its issues, and how they modify the details of the model step by step to enable differentiable training, comparing many potential alternatives before settling on sparse relaxation. This procedure could give you some hints about how to come up with your own ideas.
- GraphMask is a simple, straightforward, yet effective idea: just a classifier on top of each message predicting whether it should be dropped. Because it is model-agnostic, it can be applied to any GNN model. What's more, it outperforms other methods by dealing with hindsight bias.
- I like the synthetic experiment very much, even though it does not reflect a real scenario. With a simple task, it intuitively shows that GraphMask performs best among a range of interpretation methods in terms of faithfulness. For the real NLP tasks, question answering and semantic role labeling, the authors do not just report how many edges they dropped and how close the performance remains; they also conduct many supplementary experiments to gain deeper insight into the model mechanism. For example, they reveal that the reasoning paths in the question answering model start from COMPLEMENT edges, followed by MATCH/DOC-BASED edges. In the SRL model, the edges connecting predicates and arguments directly are retained for prediction. These further illustrations of the behavior of the masked GNN are consistent with common sense, which again supports the faithfulness of GraphMask.
- However, some parts of the explanations of these experiments are quite messy and not well organized. I had to read the sentences several times alongside the corresponding figure to arrive at a reasonably correct understanding, especially in the last part of the SRL task.
- In my opinion, the paper covers too little background on question answering and semantic role labeling, especially semantic role labeling, so I had to go through other papers to understand what the tasks really involve.
- Also, the mathematical part related to optimization is not presented in an intuitive way, which confused me a lot until I read through several related papers on L_0 norm optimization and constraint relaxation. But I have to admit that this broadened my knowledge of optimization.
- My last comment is about the idea of interpretation itself, where I am also a newcomer. I chose this paper because it combines NLP and GNNs, which are closely related to my experience. As we all know, neural networks are still regarded as black boxes, so I am a bit doubtful whether it is reasonable to use one black box to explain another. Or maybe this view is too harsh, since this method does reveal some insight into GNN models. I would also like to see inherently interpretable, built-in methods in the future.
Reference
- Interpreting Graph Neural Networks for NLP with Differentiable Edge Masking, Schlichtkrull et al.
- Question Answering by Reasoning Across Documents with Graph Convolutional Networks, Cao et al.
- Encoding Sentences with Graph Convolutional Networks for Semantic Role Labeling, Marcheggiani et al.
- Understanding Neural Networks through Representation Erasure, Li et al.
- Learning Sparse Neural Networks through L_0 Regularization, Louizos et al.
- The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables, Maddison et al.
- Axiomatic Attribution for Deep Networks, Sundararajan et al.
- Restricting the Flow: Information Bottlenecks for Attribution, Schulz et al.
- GNNExplainer: Generating Explanations for Graph Neural Networks, Ying et al.