Introduction
Motivation
Graphs offer a powerful way to structure connected data and explore the relationships within it. Once data is represented as a graph, a wide range of methods and algorithms can be used to analyze it and gain insights. This opens up applications in social networks [1], molecule classification [2, 13], and the analysis of interactions between proteins [2, 3].
In recent years, there has been significant interest in using graph neural networks (GNNs) for graph classification. In this setting, node embeddings are transformed and aggregated into a single graph embedding, which can then be fed into any prediction layer for classification. However, existing methods for aggregating node embeddings are flat, so they do not exploit the natural hierarchical structure of graphs during pooling. In this paper, the authors address this problem with learnable clustering that iteratively pools the input graph.
Related Work
Graph neural networks are used for link prediction, node classification, and graph classification tasks. The approach is usually very similar across all three tasks. First, the GNN generates node embeddings using the message passing framework. This framework iteratively updates each node's representation from the features of its neighboring nodes via a differentiable aggregation function. For graph classification, the node embeddings are then combined into an overall representation of the graph. Several approaches already exist for this aggregation of node embeddings:
- Global sum/mean/max pooling: Sums up, averages, or takes the element-wise maximum of all node embeddings.
- SortPool [12]: Sorts the embeddings in ascending order and selects the top-k nodes, with k as a hyperparameter.
- Set2Set [11]: Global pooling operator based on iterative content-based attention.
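The message passing step mentioned above can be illustrated with a minimal Python sketch. It is not GraphSAGE or any library's layer. The function name `mean_aggregate_layer` and the choice to simply average a node's own features with the mean of its neighbors' features are assumptions made for brevity. A real GNN layer would additionally apply learned weights and a nonlinearity.

```python
from typing import Dict, List

Vector = List[float]


def mean_aggregate_layer(
    adjacency: Dict[int, List[int]], features: Dict[int, Vector]
) -> Dict[int, Vector]:
    """One simplified message passing step (hypothetical, no learned weights)."""
    updated: Dict[int, Vector] = {}
    for node, own in features.items():
        neighbors = adjacency.get(node, [])
        if neighbors:
            # Aggregate: element-wise mean of the neighbors' feature vectors.
            neighbor_mean = [
                sum(features[n][k] for n in neighbors) / len(neighbors)
                for k in range(len(own))
            ]
        else:
            # Isolated node: fall back to its own features.
            neighbor_mean = list(own)
        # Update: combine self and neighborhood information by averaging.
        updated[node] = [(a + b) / 2 for a, b in zip(own, neighbor_mean)]
    return updated


# Path graph 0 - 1 - 2 with one scalar feature per node.
adj = {0: [1], 1: [0, 2], 2: [1]}
x = {0: [1.0], 1: [0.0], 2: [3.0]}
h = mean_aggregate_layer(adj, x)
print(h)  # {0: [0.5], 1: [1.0], 2: [1.5]}
```

Stacking several such layers lets information propagate across multi-hop neighborhoods before the node embeddings are pooled.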
Graph classification methods existed in the literature long before GNNs. For many years, kernel-based methods were predominant; they use the kernel trick to perform classification efficiently with an SVM. Kernel-based approaches differ in how they define graph similarity. The most prominent methods are Shortest-Path [8], Weisfeiler-Lehman [9], and the Weisfeiler-Lehman Optimal Assignment Kernel [7].
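To make the idea of a graph kernel more concrete, here is a minimal sketch of the 1-WL subtree kernel idea. Nodes are repeatedly relabeled with their own label plus the sorted labels of their neighbors, and the kernel value is the dot product of the two graphs' label histograms. The function names are hypothetical. Using the concatenated strings directly as labels, instead of compressing them to integer IDs, is a simplification for illustration. In practice, the resulting kernel matrix would be passed to an SVM.

```python
from collections import Counter
from typing import Dict, List, Tuple

Graph = Tuple[Dict[int, List[int]], Dict[int, str]]  # (adjacency, node labels)


def wl_label_histogram(graph: Graph, iterations: int) -> Counter:
    """Count all node labels seen over `iterations` rounds of 1-WL relabeling."""
    adjacency, labels = graph
    current = dict(labels)
    histogram = Counter(current.values())
    for _ in range(iterations):
        # New label = own label + sorted multiset of the neighbors' labels.
        current = {
            node: current[node]
            + "("
            + ",".join(sorted(current[n] for n in adjacency[node]))
            + ")"
            for node in adjacency
        }
        histogram.update(current.values())
    return histogram


def wl_subtree_kernel(g1: Graph, g2: Graph, iterations: int = 1) -> int:
    """Dot product of the two graphs' WL label histograms."""
    h1 = wl_label_histogram(g1, iterations)
    h2 = wl_label_histogram(g2, iterations)
    return sum(count * h2[label] for label, count in h1.items())


triangle = ({0: [1, 2], 1: [0, 2], 2: [0, 1]}, {0: "a", 1: "a", 2: "a"})
path = ({0: [1], 1: [0, 2], 2: [1]}, {0: "a", 1: "a", 2: "a"})
print(wl_subtree_kernel(triangle, path))      # 12
print(wl_subtree_kernel(triangle, triangle))  # 18
```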
Paper Contributions
The authors propose the following:
- A novel differentiable graph pooling module that can generate hierarchical representations of graphs and can be combined with various graph neural network architectures in an end-to-end fashion (DiffPool).
- Learning hierarchical structures in graphs with GNNs in contrast to finding clusters by deterministic algorithms.
Methodology
Graph Classification
We denote a graph G as (A,F), where A \in {0,1}^{n \times n} is the adjacency matrix and F \in \mathbb{R}^{n \times d} are the node features. In the graph classification problem, we are given a set of labeled graphs D = {(G_1,y_1),(G_2,y_2),...}, where y_i \in \mathcal{Y} is the label of graph G_i \in \mathcal{G}, and we try to learn the corresponding mapping f: \mathcal{G} \rightarrow \mathcal{Y}.
Global Pooling
The challenge of using GNNs in graph classification compared to node classification or link prediction is to move from node embeddings to a representation of the entire graph in the form of a graph embedding, which can then be used as input in standard machine learning methods for classification. As shown in Figure 1, a common approach is to first compute node embeddings using well-known GNN models (e.g. GraphSAGE) and then aggregate the node embeddings with a global pooling layer into a graph embedding. In the last step, linear models may be used to perform the classification.
Figure 1: Pipeline for graph classification with GNNs. A GNN generates the node embeddings, which are then aggregated into a graph embedding by a global pooling function. Finally, classical ML methods are used for the classification.
A simple implementation of the global pooling layer is to sum or average all node embeddings. A more advanced method for global pooling is SortPool. It sorts the node features in ascending order along the feature dimension and selects the sorted features of the top k nodes, ranked by their largest value. A disadvantage of these approaches is that they have no learnable parameters, so the aggregation itself cannot adapt to the task. Set2Set therefore uses a learnable LSTM to aggregate the node features as a set.
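The readouts discussed above can be sketched in a few lines. This minimal illustration works on a plain list-of-lists feature matrix, and the function names are hypothetical. The SortPool variant follows the description in the previous paragraph. Zero-padding for graphs with fewer than k nodes is an assumption. Note that none of these functions has any parameters to learn.

```python
from typing import List

Matrix = List[List[float]]  # one row per node, one column per feature channel


def global_sum_pool(h: Matrix) -> List[float]:
    return [sum(col) for col in zip(*h)]


def global_mean_pool(h: Matrix) -> List[float]:
    return [sum(col) / len(h) for col in zip(*h)]


def global_max_pool(h: Matrix) -> List[float]:
    return [max(col) for col in zip(*h)]


def sort_pool(h: Matrix, k: int) -> List[float]:
    """SortPool-style readout: sort each node's features in ascending order,
    rank nodes by their largest value, keep the top k and concatenate them.
    Zero-padding for graphs with fewer than k nodes is an assumption."""
    sorted_rows = [sorted(row) for row in h]
    top_k = sorted(sorted_rows, key=lambda row: row[-1], reverse=True)[:k]
    top_k += [[0.0] * len(h[0]) for _ in range(k - len(top_k))]
    return [value for row in top_k for value in row]


h = [[1.0, 4.0], [3.0, 2.0], [0.0, 5.0]]
print(global_sum_pool(h))  # [4.0, 11.0]
print(global_max_pool(h))  # [3.0, 5.0]
print(sort_pool(h, k=2))   # [0.0, 5.0, 1.0, 4.0]
```

All of these readouts treat the nodes as an unordered set. Whatever graph structure is not already encoded in the node embeddings is lost at this point.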
Regarding the methods mentioned above (sum/mean pooling, SortPool, Set2Set), the authors of the DiffPool paper criticize that the graph structure is no longer taken into account in the global pooling phase. As a result, hierarchical cluster structures cannot be exploited for classification in a meaningful way.
Differentiable Pooling via Learned Assignments
Figure 2: Visualization of DiffPool. At each layer, a GNN is run to obtain the embeddings and each node is assigned to a cluster. Clustered nodes are then collapsed into a single node with an aggregated cluster embedding.
The idea of DiffPool is to exploit the hierarchical structure of graphs during the pooling phase in an end-to-end fashion, by finding clusters that are then gradually merged. As shown in Figure 2, for each layer of the DiffPool model we identify similar nodes that can be merged without losing much structural information. For the merged nodes, we also compute new embeddings, called cluster embeddings, so that each cluster represents the nodes it merges from the previous layer as accurately as possible. DiffPool uses two separate GNN models at each layer: one to determine the cluster assignments and one to generate new node embeddings. DiffPool first computes new embeddings over the entire graph, then determines suitable clusters for the nodes, and finally merges the previously computed embeddings into clusters. The computations in each layer of the DiffPool model are as follows:
To compute the new embeddings Z^{(l)} at layer l, we apply the embedding GNN:
Z^{(l)} = \text{GNN}_{l,\text{embed}}(A^{(l)},X^{(l)})
To form the clusters that we want to collapse, we need to assign each node to a cluster. For this purpose, we define a cluster assignment matrix S^{(l)} \in \mathbb{R}^{n_l \times n_{l+1}} for each layer l, which is computed by a learnable GNN model. The output of this GNN is normalized row-wise with a softmax function, so that each row of S^{(l)} gives the probabilities with which a node belongs to each of the n_{l+1} clusters:
S^{(l)} = \text{softmax}(\text{GNN}_{l,\text{pool}}(A^{(l)},X^{(l)}))
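The row-wise softmax can be sketched as follows. The logits stand in for the output of the pooling GNN and are made up for illustration. The only point is that each row of the resulting matrix is a probability distribution over the clusters of the next layer. The helper name `row_softmax` is hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def row_softmax(logits: Matrix) -> Matrix:
    """Normalize each row of an n_l x n_{l+1} logit matrix into probabilities."""
    result = []
    for row in logits:
        m = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


# Made-up pooling-GNN outputs for 3 nodes and 2 clusters.
logits = [[2.0, 0.0], [0.0, 0.0], [-1.0, 3.0]]
S = row_softmax(logits)
for row in S:
    print([round(p, 3) for p in row])
# [0.881, 0.119]
# [0.5, 0.5]
# [0.018, 0.982]
```

Because the rows are soft distributions rather than hard choices, the assignment stays differentiable and can be trained end-to-end with the rest of the model.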
With the assignment matrix S^{(l)} and new node embeddings Z^{(l)}, we are able to build the coarsened graph for the next layer l+1. We obtain the new cluster embeddings by applying the following formula:
X^{(l+1)} = {S^{(l)}}^T Z^{(l)} \in \mathbb{R}^{n_{l+1} \times d}
We obtain the new adjacency matrix, whose weighted entries express the connectivity strength between the clusters, with the following formula:
A^{(l+1)} = {S^{(l)}}^T A^{(l)} S^{(l)} \in \mathbb{R}^{n_{l+1} \times n_{l+1}}
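The two coarsening formulas can be checked on a toy example. In this sketch, which uses hypothetical helper names, the graph consists of two triangles joined by a single edge. For readability, S^{(l)} is a hard one-hot assignment of each triangle to its own cluster; in DiffPool, S^{(l)} is the soft output of the softmax. The embeddings Z^{(l)} are made-up one-dimensional values.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def diffpool_coarsen(A: Matrix, Z: Matrix, S: Matrix) -> Tuple[Matrix, Matrix]:
    """Return X^{(l+1)} = S^T Z and A^{(l+1)} = S^T A S."""
    S_T = transpose(S)
    X_next = matmul(S_T, Z)
    A_next = matmul(matmul(S_T, A), S)
    return X_next, A_next


# Two triangles {0, 1, 2} and {3, 4, 5} joined by the edge 2-3.
A = [
    [0, 1, 1, 0, 0, 0],
    [1, 0, 1, 0, 0, 0],
    [1, 1, 0, 1, 0, 0],
    [0, 0, 1, 0, 1, 1],
    [0, 0, 0, 1, 0, 1],
    [0, 0, 0, 1, 1, 0],
]
Z = [[1.0], [1.0], [1.0], [2.0], [2.0], [2.0]]  # made-up node embeddings
S = [[1, 0], [1, 0], [1, 0], [0, 1], [0, 1], [0, 1]]  # hard assignment
X_next, A_next = diffpool_coarsen(A, Z, S)
print(X_next)  # [[3.0], [6.0]]
print(A_next)  # [[6, 1], [1, 6]]
```

Each entry of the coarsened adjacency matrix sums the edge weights between two clusters. The diagonal therefore counts every intra-cluster edge twice (once per direction), and the off-diagonal value 1 corresponds to the single edge 2-3 connecting the triangles.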
In DiffPool, the architectures of both the pooling GNN and the embedding GNN can be chosen freely. The authors prefer to use GraphSAGE for both models, but other architectures such as GCNs or GATs can also be used. Typically, several DiffPool layers are stacked on top of each other. Each layer shrinks the graph to a predefined percentage (e.g. 25%) of the nodes in the previous layer, so that the graph is gradually coarsened.
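The effect of stacking pooling layers with a fixed ratio can be sketched with simple arithmetic. The rounding rule (round up, keep at least one cluster) and the example graph size of 100 nodes are assumptions for illustration, not values from the paper. The helper names are hypothetical.

```python
import math
from typing import List, Tuple


def cluster_schedule(num_nodes: int, ratio: float, num_pool_layers: int) -> List[int]:
    """Number of nodes per layer when each DiffPool layer keeps `ratio` of the
    previous layer's nodes (rounded up, at least one cluster: an assumption)."""
    sizes = [num_nodes]
    for _ in range(num_pool_layers):
        sizes.append(max(1, math.ceil(sizes[-1] * ratio)))
    return sizes


def assignment_shapes(sizes: List[int]) -> List[Tuple[int, int]]:
    """Shape n_l x n_{l+1} of each assignment matrix S^{(l)}."""
    return list(zip(sizes, sizes[1:]))


sizes = cluster_schedule(100, 0.25, 2)
print(sizes)                     # [100, 25, 7]
print(assignment_shapes(sizes))  # [(100, 25), (25, 7)]
```

The shapes show how each assignment matrix S^{(l)} maps the n_l nodes of one layer onto the n_{l+1} clusters of the next.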
Auxiliary Link Prediction Objective and Entropy Regularization
While training DiffPool, the authors noticed that it is difficult to train the pooling GNN using only the gradient signal from the graph classification task. They therefore added an auxiliary link prediction objective, following the intuition that nearby nodes should be assigned to the same cluster. The link prediction objective for each layer l is defined as:
L_{LP}=\|A^{(l)} - S^{(l)}{S^{(l)}}^T\|_F
where \|\cdot\|_F denotes the Frobenius norm.
Another problem with the vanilla DiffPool model is that it tends to generate dense cluster assignment matrices, especially early in training, so nodes are often assigned to multiple clusters. Ideally, each node should be assigned to exactly one cluster, i.e. each row of the assignment matrix should be a one-hot vector. The authors addressed this problem by adding, for each layer l, a loss that regularizes the entropy of the cluster assignments:
L_E=\frac{1}{n}\sum_{i=1}^{n} H(S_i)
with H denoting the entropy function and S_i the i-th row of the cluster assignment matrix.
The link prediction and entropy losses from all layers are added to the overall classification loss. The authors report that these auxiliary objectives improve both the accuracy and the interpretability of the DiffPool model.
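Both auxiliary losses and their combination with the classification loss can be sketched as follows. This sketch makes several assumptions. The toy adjacency matrix includes self-loops, so that a perfect hard clustering of its two connected components yields zero link prediction loss. The entropy uses the natural logarithm, and all terms are added without weights. The classification loss value of 0.3 is made up, and the function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def link_prediction_loss(A: Matrix, S: Matrix) -> float:
    """L_LP = ||A - S S^T||_F for one layer."""
    n, k = len(S), len(S[0])
    total = 0.0
    for i in range(n):
        for j in range(n):
            ss_ij = sum(S[i][c] * S[j][c] for c in range(k))
            total += (A[i][j] - ss_ij) ** 2
    return math.sqrt(total)


def entropy_loss(S: Matrix) -> float:
    """L_E = mean entropy of the rows of S (natural logarithm)."""
    def entropy(row: List[float]) -> float:
        return sum(-p * math.log(p) for p in row if p > 0)
    return sum(entropy(row) for row in S) / len(S)


def total_loss(classification_loss: float, layers: Sequence[Tuple[Matrix, Matrix]]) -> float:
    """Classification loss plus the unweighted auxiliary losses of all layers."""
    return classification_loss + sum(
        link_prediction_loss(A, S) + entropy_loss(S) for A, S in layers
    )


# Nodes 0 and 1 are connected, node 2 is separate; self-loops included (assumption).
A = [[1, 1, 0], [1, 1, 0], [0, 0, 1]]
S_hard = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
S_soft = [[0.5, 0.5] for _ in range(3)]
print(link_prediction_loss(A, S_hard), entropy_loss(S_hard))  # 0.0 0.0
print(link_prediction_loss(A, S_soft))                        # 1.5
print(total_loss(0.3, [(A, S_hard)]))                         # 0.3
```

For the hard assignment, both auxiliary terms vanish. The uniform soft assignment is penalized by both: a link prediction loss of 1.5 and an entropy loss of ln 2 (about 0.693).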
Evaluation
Setup & Datasets
For the evaluation, the authors addressed two questions:
- How does DiffPool compare to other graph classification methods such as GNN-based approaches or kernel-based approaches?
- Does DiffPool provide meaningful and interpretable clusters in all layers of the model?
To answer the first question, the authors trained the model on five different datasets. These were the well-known bioinformatics datasets for protein prediction ENZYMES, PROTEINS [3, 5], and D&D [4], as well as the social network dataset REDDIT-MULTI-12K [6] and the scientific collaboration dataset COLLAB [6].
For their own model, the authors used a total of two DiffPool layers, with three GraphSAGE layers using the "mean" aggregator between the DiffPool layers. For the two smaller datasets, ENZYMES and COLLAB, a single DiffPool layer was sufficient to achieve results similar to those of the larger model. The authors also experimented with several variants of their model. DiffPool-Det is a variant in which the cluster assignment matrices are computed with a deterministic clustering algorithm. In DiffPool-NoLP, the auxiliary link prediction objective is switched off. Hyperparameter optimization was performed for DiffPool. The authors have published the code they used for the DiffPool benchmarks on GitHub.
Baseline methods
The authors compared their model against four kernel-based methods and five other GNN-based methods. The kernel-based algorithms included Graphlet [10], Shortest-Path [8], and Weisfeiler-Lehman (1-WL) [9], as well as the state-of-the-art kernel baseline Weisfeiler-Lehman Optimal Assignment (WL-OA) [7]. The authors used the C-SVM implementation of LibSVM to run the kernel-based benchmarks. For the GNN methods, the authors chose GraphSAGE as the base GNN model and included PatchySan, ECC, Set2Set, and SortPool as pooling layers in their benchmarks. The authors also tried global max/min-pooling, but settled on mean-pooling as the best-performing of the traditional global pooling approaches. The mean-pooling GNN model is denoted as GraphSAGE in their benchmarks.
Results
Table 1: Classification accuracies of DiffPool and baseline methods. The gain column shows the relative increase of DiffPool compared to GraphSAGE.
The results of the benchmarks performed by the authors are shown in Table 1. DiffPool achieves state-of-the-art performance with the highest accuracy in 4 of the 5 benchmarks. Compared to the GraphSAGE models, it achieves an average improvement in accuracy of 6.27%. For the COLLAB dataset, the simplified DiffPool-Det achieves higher accuracy than the DiffPool model. The authors state that many of the collaboration graphs in the dataset contain only single-layer community structures, which are better captured by the pre-computed graph algorithms.
Figure 3: Example visualization of cluster assignments in graphs from the COLLAB dataset. Figure (a) shows the hierarchical clustering over two layers, where the clusters from the first layer correspond to the nodes in the second layer. Figures (b) and (c) show first-layer cluster assignments for two other graphs.
The authors also investigated the interpretability of the clusters by examining the cluster assignments from the COLLAB dataset at different layers. Figure 3 shows selected examples of the cluster assignments, with the color of the node indicating its cluster membership. The membership of a node to a cluster was determined based on the argmax of its cluster assignment probabilities. In their analysis of the clusters, the authors made the following observations:
- Hierarchical cluster structure: Even when the cluster assignments are trained solely on the graph classification objective, DiffPool can still capture the hierarchical community structure. The auxiliary link prediction objective significantly improves the quality of the membership assignments.
- Dense vs. sparse subgraph structure: DiffPool tends to collapse densely connected subgraphs into clusters. Clique-like subgraphs are easily captured and collapsed. However, sparse subgraphs, such as path-, cycle- and tree-like structures, are often not captured by the GNN pooling approach.
- Assignment for nodes with similar representations: Since the assignments are computed by a GNN, nodes with the same or similar input features and neighborhood structure receive the same or similar cluster assignments, even if they are far apart in the graph.
- Sensitivity to the pre-defined maximum number of clusters: The assignment quality depends on the depth of the network and the number of clusters defined at each layer. With a higher number of clusters, the pooling GNN can capture more complex hierarchical structures, but the overall result contains more noise and is less efficient.
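The argmax-based membership and the observation about nodes with similar representations can be reproduced on a toy path graph. In this sketch, the "pooling GNN" is a single hypothetical layer that averages each node's features with its neighbors' features and applies a fixed, made-up weight matrix; nothing is learned. Nodes 0 and 4 lie at opposite ends of the path. Because they have identical features and structurally identical neighborhoods, they receive identical assignment rows and therefore the same cluster.

```python
import math
from typing import List

Matrix = List[List[float]]


def pooling_gnn_logits(adj: Matrix, x: Matrix, weights: Matrix) -> Matrix:
    """Hypothetical one-layer pooling GNN: mean over each node and its
    neighbors, followed by a fixed (not learned) linear map to cluster logits."""
    logits = []
    for i, row in enumerate(adj):
        group = [i] + [j for j, a in enumerate(row) if a]
        agg = [sum(x[j][f] for j in group) / len(group) for f in range(len(x[0]))]
        logits.append(
            [sum(agg[f] * weights[f][c] for f in range(len(agg))) for c in range(len(weights[0]))]
        )
    return logits


def softmax(row: List[float]) -> List[float]:
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


# Path graph 0-1-2-3-4; nodes 0 and 4 are far apart but look identical locally.
adj = [[1 if abs(i - j) == 1 else 0 for j in range(5)] for i in range(5)]
x = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
W = [[2.0, 0.0, 1.0], [0.0, 2.0, 1.2]]  # made-up weights: 2 features -> 3 clusters
S = [softmax(r) for r in pooling_gnn_logits(adj, x, W)]
clusters = [max(range(len(r)), key=r.__getitem__) for r in S]  # argmax per row
print(clusters)  # [2, 1, 1, 1, 2]
```

Nodes 0 and 4 end up in the same cluster even though they are four hops apart. This is exactly the behavior described in the third observation above.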
Conclusion
In this paper, the authors introduced DiffPool, a differentiable pooling method that learns the hierarchical structure of graphs and uses it for classification. In combination with existing GNN methods, DiffPool achieves state-of-the-art performance on numerous benchmarks, outperforming both GNN-based and kernel-based methods. The authors leave open how strongly the quality of the cluster assignments affects the results and whether it can be further improved by future research.
Personal Review
The authors present a fundamentally new approach to graph pooling and lay a foundation for subsequent work building on it. The methods and formulas used are explained both mathematically and visually in an understandable way. The new method is evaluated in detail in standardized benchmarks, using well-known graph classification datasets from medical and social settings. The reference implementation in PyTorch published by the authors also makes it easy to replicate the benchmarks.
For my own understanding and for demonstration purposes, I implemented the vanilla variant of DiffPool using DGL, i.e. without the auxiliary link prediction objective and entropy regularization. I followed the authors' implementation; in addition, the two well-known geometric deep learning frameworks DGL and PyTorch Geometric offer tutorials on the topic. I encountered no major difficulties during the implementation. All matrix multiplications, as well as the integration into the overall GNN model, are easy to follow, which speaks for the applicability of the paper. The source code of my implementation can be found on GitHub: https://github.com/d-stoll/diffpool
Even though the paper is very good in terms of content and writing, it has a few shortcomings. Drawbacks of the new method, such as significantly higher memory consumption and a larger number of model parameters, are unfortunately not mentioned in the paper. Neither the abstract nor the introduction makes it immediately clear that the paper addresses the graph classification problem. Here, the authors could avoid possible confusion with more precise wording. Also, the title and the chosen abbreviation of the method (DiffPool) might give the impression that this approach is the first to use differentiable methods for graph pooling, although other papers have used differentiable methods before.
References
- [1] Hamilton, W. L., Ying, R., & Leskovec, J. (2017). Inductive Representation Learning on Large Graphs. arXiv. https://doi.org/10.48550/ARXIV.1706.02216
- [2] Dai, H., Dai, B., & Song, L. (2016). Discriminative Embeddings of Latent Variable Models for Structured Data. arXiv. https://doi.org/10.48550/ARXIV.1603.05629
- [3] Borgwardt, K. M., Ong, C. S., Schönauer, S., Vishwanathan, S. V., Smola, A. J., & Kriegel, H. P. (2005). Protein function prediction via graph kernels. Bioinformatics (Oxford, England), 21 Suppl 1, i47–i56. https://doi.org/10.1093/bioinformatics/bti1007
- [4] Dobson, P. D., & Doig, A. J. (2003). Distinguishing enzyme structures from non-enzymes without alignments. Journal of Molecular Biology, 330(4), 771–783. https://doi.org/10.1016/s0022-2836(03)00628-4
- [5] Feragen, A., Kasenburg, N., Petersen, J., de Bruijne, M., & Borgwardt, K. (2013). Scalable kernels for graphs with continuous attributes. In C. J. Burges, L. Bottou, M. Welling, Z. Ghahramani, & K. Q. Weinberger (Eds.), Advances in Neural Information Processing Systems (Vol. 26). Curran Associates, Inc. https://proceedings.neurips.cc/paper/2013/file/a2557a7b2e94197ff767970b67041697-Paper.pdf
- [6] Yanardag, P., & Vishwanathan, S. V. N. (2015). A Structural Smoothing Framework For Robust Graph Comparison. In C. Cortes, N. Lawrence, D. Lee, M. Sugiyama, & R. Garnett (Eds.), Advances in Neural Information Processing Systems (Vol. 28). Curran Associates, Inc. https://proceedings.neurips.cc/paper/2015/file/7810ccd41bf26faaa2c4e1f20db70a71-Paper.pdf
- [7] Kriege, N. M., Giscard, P.-L., & Wilson, R. C. (2016). On Valid Optimal Assignment Kernels and Applications to Graph Classification. arXiv. https://doi.org/10.48550/ARXIV.1606.01141
- [8] Borgwardt, K. M., & Kriegel, H. P. (2005). Shortest-path kernels on graphs. In Fifth IEEE International Conference on Data Mining (ICDM’05), 8 pp. https://doi.org/10.1109/ICDM.2005.132
- [9] Togninalli, M., Ghisu, E., Llinares-López, F., Rieck, B., & Borgwardt, K. (2019). Wasserstein Weisfeiler-Lehman Graph Kernels. arXiv. https://doi.org/10.48550/ARXIV.1906.01277
- [10] Shervashidze, N., Vishwanathan, S. V. N., Petri, T., Mehlhorn, K., & Borgwardt, K. (2009). Efficient Graphlet Kernels for Large Graph Comparison. In Proceedings of the 12th International Conference on Artificial Intelligence and Statistics (AISTATS) (Vol. 5, pp. 488–495). Society for Artificial Intelligence and Statistics.
- [11] Vinyals, O., Bengio, S., & Kudlur, M. (2015). Order Matters: Sequence to sequence for sets. arXiv. https://doi.org/10.48550/ARXIV.1511.06391
- [12] Zhang, M., Cui, Z., Neumann, M., & Chen, Y. (2018). An End-to-End Deep Learning Architecture for Graph Classification. In Proceedings of the Thirty-Second AAAI Conference on Artificial Intelligence and Thirtieth Innovative Applications of Artificial Intelligence Conference and Eighth AAAI Symposium on Educational Advances in Artificial Intelligence.
- [13] Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., & Dahl, G. E. (2017). Neural Message Passing for Quantum Chemistry. arXiv. https://doi.org/10.48550/ARXIV.1704.01212