This is the blog post for the paper "Latent-Graph Learning for Disease Prediction" written by Luca Cosmo, Anees Kazi, Seyed-Ahmad Ahmadi, Nassir Navab, and Michael Bronstein.
Introduction
Recently, geometric deep learning and graph convolutional networks (GCNs) have been introduced as a novel framework for computer-aided disease classification and diagnosis (CADx) [1]. GCNs have already been applied to a multitude of problems in the medical domain: neurodevelopmental outcome prediction [2], brain analysis in large populations [3], protein interaction prediction [4], metric learning on brain connectivity networks [5], and representation learning for medical images [6]. This work focuses on CADx, i.e. disease prediction from multimodal patient data within a population.
To this end, we need to build a population model, and graphs provide a natural way of modelling populations and the similarities between their members [7]. In such a setting, vertices represent patients, edges represent connections between patients, and the graph adjacency matrix encodes pairwise patient similarities. The task is to learn a set of filters with optimal weights for CADx classification by aggregating patient features over local neighborhoods in the graph. These neighborhoods let information flow between similar patients in the population.
Related work
To compute the graph adjacency matrix, a similarity metric between patients has to be defined. Kipf and Welling first proposed semi-supervised node classification using spectral GCNs [15], and Parisot et al. adapted this approach to CADx, computing patient similarities from meta-features such as age, sex and other demographic features [3]. Because GCNs are sensitive to the graph structure, this metric has to be carefully tuned.
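To make the idea of a fixed, hand-designed similarity metric concrete, here is a minimal sketch that builds an adjacency matrix purely from meta-features. It is a hypothetical simplification in the spirit of [3], not the exact formulation used there. The function name `metafeature_graph`, the choice of sex and age as the only meta-features, and the age tolerance `age_tol` are all illustrative assumptions.

```python
import numpy as np

def metafeature_graph(sex, age, age_tol=2.0):
    """Fixed adjacency from meta-features (hypothetical simplification).

    a_ij is the fraction of meta-features on which patients i and j agree:
    same sex, and ages within `age_tol` years of each other.
    """
    sex = np.asarray(sex)
    age = np.asarray(age, dtype=float)
    same_sex = (sex[:, None] == sex[None, :]).astype(float)
    close_age = (np.abs(age[:, None] - age[None, :]) <= age_tol).astype(float)
    adjacency = (same_sex + close_age) / 2.0
    np.fill_diagonal(adjacency, 0.0)  # no self-loops
    return adjacency

A = metafeature_graph(sex=["F", "M", "F"], age=[70, 71, 80])
print(A)
```

Every design decision here (which meta-features, which tolerance, how to combine them) is fixed by hand before training, which is exactly the sensitivity to metric tuning described above.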
Several papers propose building multiple graphs separately from patient data, with each graph representing a different patient feature [9]. While the multiple-graph approach achieves high accuracy and robustness, it has two main problems:
- the high number of parameters limits the scalability of GCNs
- the fusion layer in such networks has to be carefully designed
How can we alleviate these problems? This paper proposes learning the graph (adjacency matrix) end-to-end.
Learning a graph can benefit other applications as well, especially in terms of interpretability. Existing graph learning methods differ in nature. Zhan et al. propose constructing multiple graph Laplacians and jointly optimizing data correlation from multiple features in an adaptive way [16]. Franceschi et al. propose learning a discrete probability distribution on the edges of the graph via the reparameterization trick and optimizing the graph structure via hypergradient descent [17]. Jang et al. propose a model for EEG classification that extracts an appropriate multi-layer graph structure (for brain connectivity) and signal features directly from a set of raw EEG signals [18].
However, in most works so far the graph has to be computed beforehand, which causes the following problems:
- graph construction requires defining a similarity metric (computed by a fixed function)
- each single-feature graph is computed independently and combined for learning (whereby the correlation between the features is ignored)
Contributions
In this paper, a graph convolutional network that can classify patients within a population is proposed, with the following properties:
- the underlying patient graph structure is learned by the model so that it is optimal for the disease prediction task
- a single graph is used and learned end-to-end
- the method is inductive, making it possible to introduce previously unseen patients to the current population
Methodology
The proposed method, which uses a single end-to-end learned graph instead of multiple graphs, has several advantages:
- higher classification accuracy
- reduced network complexity, since the node input features are embedded into a lower-dimensional Euclidean space, which also addresses the scalability problem
- better patient representation, since the graph is not restricted to basic patient meta-features
Latent-Graph Learning
First, the architecture of the graph learning model is explained. Its goal is to find the underlying patient graph structure that is optimal for the disease prediction task. To enable end-to-end training of the graph, the adjacency matrix is defined to be real-valued, i.e. \mathbf{A} \in [0, 1]^{N \times N}, which makes the output of the convolution layers differentiable with respect to the graph structure. A function \tilde{\mathbf{x}}_i = f_{\phi}(\mathbf{x}_i) is learned that embeds the input features into a lower-dimensional Euclidean space; a simple multilayer perceptron (MLP) is used as f_{\phi}. The weight of the edge connecting nodes i and j is then defined as:
a_{ij} = \frac{1}{1 + e^{-t\,(\lVert \tilde{\mathbf{x}}_i - \tilde{\mathbf{x}}_j \rVert_2 + \theta)}}
with the following learnable parameters:
- \theta, a soft threshold parameter applied to the distances between embedded features
- t, a temperature parameter, which pushes the values of a_{ij} towards 0 or 1
Because a_{ij} is differentiable with respect to \phi, \theta and t, every edge weight can be updated by backpropagating the loss. The output of the graph learning module is the population graph used by the GCN.
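The edge-weight formula can be sketched as a NumPy forward pass. This is a minimal sketch, not the authors' TensorFlow implementation: the two-layer MLP standing in for f_{\phi}, the layer sizes and the random weights are hypothetical, and in practice an autodiff framework would provide the gradients with respect to \phi, \theta and t. Read literally, the formula makes a_{ij} grow with distance when t > 0, so for nearby patients to be strongly connected the learned t has to end up negative, with a negative \theta acting as the distance threshold. The example therefore uses t = -5 and \theta = -1, which are purely illustrative values.

```python
import numpy as np

def mlp_embed(X, params):
    """Hypothetical f_phi: an MLP with ReLU between layers (no ReLU on the output)."""
    H = X
    for layer, (W, b) in enumerate(params):
        H = H @ W + b
        if layer < len(params) - 1:
            H = np.maximum(H, 0.0)
    return H

def edge_weights(X_tilde, t, theta):
    """a_ij = 1 / (1 + exp(-t * (||x~_i - x~_j||_2 + theta))) for all pairs i, j."""
    diff = X_tilde[:, None, :] - X_tilde[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)  # (N, N) pairwise Euclidean distances
    return 1.0 / (1.0 + np.exp(-t * (dist + theta)))

rng = np.random.default_rng(0)
N, d_in, d_hidden, d_emb = 5, 30, 16, 4          # illustrative sizes
phi = [
    (rng.normal(size=(d_in, d_hidden)) / np.sqrt(d_in), np.zeros(d_hidden)),
    (rng.normal(size=(d_hidden, d_emb)) / np.sqrt(d_hidden), np.zeros(d_emb)),
]
X = rng.normal(size=(N, d_in))                   # N patients, d_in features each
A = edge_weights(mlp_embed(X, phi), t=-5.0, theta=-1.0)
```

With these values, embeddings closer than 1 get a_{ij} > 0.5, farther ones get a_{ij} < 0.5, and a larger |t| makes that transition sharper.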
Classification Model
The latent-graph learning module described above is used within the patient classification model, whose task is to predict the labels y_i, i = 1, \dots, N, of all N patients with associated patient features \mathbf{X} \in \mathbb{R}^{N \times d_1}. The model consists of several consecutive graph convolutional layers, followed by fully connected layers for patient classification.
The output of the spatial graph convolution layer is defined as:
\mathbf{H}_{l+1} = \sigma(\mathbf{D}^{-1} \mathbf{A} \mathbf{H}_l \mathbf{W})
with:
- \sigma, a non-linear activation function,
- \mathbf{D}, a diagonal normalization (degree) matrix with d_{ii} = \sum_{j} a_{ij},
- \mathbf{A}, the adjacency matrix produced by the graph learning module from the feature vectors \mathbf{X},
- \mathbf{H}_l, the output of the previous layer, and
- \mathbf{W}, the filter weights to be learned.
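The layer equation above amounts to a row-normalized neighborhood average followed by a linear map and a non-linearity. The sketch below implements it in NumPy with ReLU as \sigma; the function name `graph_conv`, the small `eps` guard against empty rows, and the random weights in the shape check are assumptions made for illustration.

```python
import numpy as np

def graph_conv(A, H, W, eps=1e-12):
    """One spatial graph convolution: H_{l+1} = ReLU(D^{-1} A H_l W)."""
    deg = A.sum(axis=1)                          # d_ii = sum_j a_ij
    A_norm = A / np.maximum(deg, eps)[:, None]   # D^{-1} A: each row sums to 1
    return np.maximum(A_norm @ H @ W, 0.0)       # sigma = ReLU

# Toy check: with W = 1, each node gets the weighted mean of its neighbours' features.
A_toy = np.array([[1.0, 1.0, 0.0],
                  [1.0, 1.0, 1.0],
                  [0.0, 1.0, 1.0]])
H_toy = np.array([[1.0], [2.0], [3.0]])
out_toy = graph_conv(A_toy, H_toy, np.eye(1))    # rows: 1.5, 2.0, 2.5

# Shapes for a two-layer stack (30 input features -> 16 -> 8).
rng = np.random.default_rng(0)
N = 10
A = rng.random((N, N))
A = (A + A.T) / 2                                # stand-in for the learned soft adjacency
X = rng.normal(size=(N, 30))
H1 = graph_conv(A, X, rng.normal(size=(30, 16)))
H2 = graph_conv(A, H1, rng.normal(size=(16, 8)))  # shape (10, 8)
```

Because \mathbf{D}^{-1}\mathbf{A} is row-stochastic, soft edge weights close to 0 simply contribute little to a node's average, so no hard edge selection is needed.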
Figure 1 shows the graph learning pipeline.
Figure 1: Latent-graph learning pipeline [1]
Experiments & Results
Proof of concept
First, the graph learning module is tested in a synthetic setting. A graph G with N nodes is generated randomly. The feature matrix is initialized with the identity matrix, i.e. \mathbf{X} = \mathbf{I} \in \mathbb{R}^{N \times N}, so that the features carry no information beyond node identity and have minimal impact on the task. The ground-truth vector of each node i is defined as the sum of its neighbours' features, i.e. \mathbf{y}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j, where \mathcal{N}(i) is the neighbourhood of node i in G.
The task is then to predict the target vectors \mathbf{y}_i, i.e. to solve:
\arg\min_{\phi} \sum_i \lVert \mathbf{y}_i - \mathbf{A}(\phi, \mathbf{X})\, \mathbf{X}_{:i} \rVert_2^2
with:
- \mathbf{A}(\phi, \mathbf{X}), the adjacency matrix produced by the latent-graph learning module,
- \mathbf{X}, the node features, and
- \phi, the parameters being optimized.
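The synthetic setup can be reproduced in a few lines. This sketch only builds the problem and evaluates the objective for a given candidate adjacency; it does not train the graph learning module. The Erdős–Rényi generator with edge probability 0.2 and N = 20 are assumptions, since the text only says that the graph is generated randomly. With \mathbf{X} = \mathbf{I}, the target \mathbf{y}_i is simply row i of the ground-truth adjacency, so the true graph reaches zero loss.

```python
import numpy as np

rng = np.random.default_rng(42)
N = 20

# Random undirected ground-truth graph G without self-loops.
# Assumption: Erdos-Renyi with edge probability 0.2 (the text only says "randomly").
upper = np.triu(rng.random((N, N)) < 0.2, k=1)
A_gt = (upper | upper.T).astype(float)

X = np.eye(N)      # identity features: each node only "knows" its own index
Y = A_gt @ X       # row i is y_i = sum of the features of i's neighbours

def objective(A_pred, X, Y):
    """sum_i || y_i - A_pred X_{:i} ||^2, the quantity minimised over phi."""
    return sum(float(np.sum((Y[i] - A_pred @ X[:, i]) ** 2)) for i in range(X.shape[1]))

print(objective(A_gt, X, Y))              # 0.0: the true graph is a perfect fit
print(objective(np.zeros((N, N)), X, Y))  # equals the number of nonzero entries of A_gt
```

In the actual experiment, A_pred would be \mathbf{A}(\phi, \mathbf{X}) from the latent-graph module, and \phi would be fitted by gradient descent on this objective.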
Figure 2 shows the results of this optimization. The proposed graph learning technique predicts the ground-truth labels with high accuracy. Furthermore, the embedding dimension influences how well the underlying graph structure is recovered: the error decreases as the embedding size grows.
Figure 2: Left: comparison of ground-truth graphs and graphs learned by the proposed model. Right: error (MSE) between the ground-truth and predicted labels with respect to the number of nodes, for different embedding sizes [1]
Comparisons
In the following paragraphs, the proposed method is compared with baseline and state-of-the-art methods. The results are analyzed on two public medical datasets, TADPOLE and UKBB, which were chosen because they differ in size, which helps show how the proposed method adapts to different datasets. TADPOLE [10] includes data from 564 patients with 354 features, and the task is to predict whether a patient is cognitively healthy, cognitively impaired, or has Alzheimer's disease. For UKBB [11], a subset of 14,503 patients is used, with 440 brain features extracted from MRI images; the task is to predict the patient's age. Accuracy and multi-class AUC are used as evaluation metrics.
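Since both metrics are reported for a multi-class task, here is a minimal sketch of how they can be computed from predicted class probabilities. The paper's exact AUC averaging scheme is not stated here; macro-averaged one-vs-rest AUC is assumed, with each binary AUC computed as the probability that a random positive outranks a random negative (ties count one half). The toy probabilities are made up for illustration.

```python
import numpy as np

def binary_auc(scores, is_pos):
    """P(score of a random positive > score of a random negative), ties count 1/2."""
    pos, neg = scores[is_pos], scores[~is_pos]
    greater = (pos[:, None] > neg[None, :]).mean()
    ties = (pos[:, None] == neg[None, :]).mean()
    return greater + 0.5 * ties

def macro_ovr_auc(probs, labels):
    """Macro-averaged one-vs-rest AUC; every class needs positives and negatives."""
    probs, labels = np.asarray(probs, dtype=float), np.asarray(labels)
    return float(np.mean([binary_auc(probs[:, c], labels == c)
                          for c in range(probs.shape[1])]))

def accuracy(probs, labels):
    return float(np.mean(np.argmax(probs, axis=1) == np.asarray(labels)))

# Toy 3-class example (e.g. healthy / impaired / Alzheimer's).
probs = np.array([[0.8, 0.1, 0.1],
                  [0.2, 0.7, 0.1],
                  [0.1, 0.2, 0.7],
                  [0.6, 0.3, 0.1]])
labels = np.array([0, 1, 2, 1])
```

Here one prediction is wrong, so accuracy is 0.75, yet each class's true members still receive the highest score for that class, so the macro AUC is 1.0. This illustrates why the two metrics are reported together.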
Comparison with baseline methods
Table 1 shows the comparison of the proposed method with the following baseline methods:
- a ridge regression classifier, as a linear baseline
- Spectral-GCN, based on graph convolutions in the spectral (graph Fourier) domain
- Dynamic graph CNN (DGCNN), which constructs a local k-nearest-neighbour (k-NN) graph and applies convolution-like operations on the edges between each point and its neighbours [16]
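For contrast with the soft, learned adjacency, here is a minimal sketch of the k-NN graph construction step used by DGCNN-style methods. Only the graph construction is shown. DGCNN recomputes such a graph in the feature space of each layer and applies its edge convolution on top, which is omitted here. The function name `knn_graph` and the toy points are illustrative.

```python
import numpy as np

def knn_graph(features, k):
    """Binary k-NN adjacency: each node links to its k nearest other nodes (Euclidean)."""
    dist = np.linalg.norm(features[:, None, :] - features[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)                 # a node is not its own neighbour
    neighbours = np.argsort(dist, axis=1)[:, :k]   # indices of the k closest nodes
    A = np.zeros_like(dist)
    rows = np.arange(features.shape[0])[:, None]
    A[rows, neighbours] = 1.0
    return A  # generally directed: j in kNN(i) does not imply i in kNN(j)

points = np.array([[0.0], [1.0], [3.0], [10.0]])
A = knn_graph(points, k=1)
```

The top-k selection is a hard, discrete choice, so which neighbours are kept cannot be adjusted by gradients. The sigmoid edge weights of the latent-graph module avoid this.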
The method proposed in the paper outperforms all the baselines, by 8.32% on TADPOLE and 6% on UKBB. Spectral GCNs are also not memory efficient [12], which is why Spectral-GCN could not be tested on the UKBB dataset. Moreover, Spectral-GCN, which relies on a pre-defined graph, performed worse than DGCNN, which builds its graph from the data; this suggests that a pre-defined graph is likely suboptimal.
Table 1: Comparison with baseline methods on TADPOLE and UKBB [1]
Comparison with state-of-the-art methods
Table 2 depicts how the proposed method compares with state-of-the-art methods. The following state-of-the-art methods are considered:
- Multi-GCN, a multiple-graph GCN model that incorporates the information of each graph separately [9]
- InceptionGCN, a receptive-field-aware GCN for disease prediction, applied to TADPOLE [13]
- DGM (Differentiable Graph Module), a learnable function predicting the edge probability in the graph [8]
The proposed method outperforms all the other methods while requiring fewer parameters than the other models. Its standard deviation is also low, which indicates that the model is robust.
Table 2: Comparison with state-of-the-art methods on TADPOLE and UKBB [1]
Comparison with inductive methods
Table 3 presents a comparison with other inductive graph methods, DGCNN and DGM. The proposed method is also inductive, i.e. the model learns to predict the graph structure from the patients' input features. Therefore, new patients can be added at test time: their features are embedded into the same lower-dimensional Euclidean space, and their edges to the existing population follow from the learned edge-weight function. The proposed model surpasses the state of the art with a lower standard deviation, which further supports the robustness of the model.
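The inductive step can be sketched as follows. This is a minimal sketch under stated assumptions. A single linear map `W_phi` stands in for the trained f_{\phi}. The function names and the toy patients are hypothetical. The values t = -5 and \theta = -1 are illustrative, chosen so that nearby embeddings get large weights under the formula as written. Because f_{\phi}, \theta and t stay fixed, a new patient only requires recomputing the graph over the old and new embeddings, with no retraining.

```python
import numpy as np

def embed(X, W_phi):
    """Hypothetical trained f_phi; a single linear map keeps the sketch short."""
    return X @ W_phi

def soft_adjacency(Z, t, theta):
    """Same edge-weight formula as the latent-graph module, applied to embeddings Z."""
    dist = np.linalg.norm(Z[:, None, :] - Z[None, :, :], axis=-1)
    return 1.0 / (1.0 + np.exp(-t * (dist + theta)))

def add_patients(X_pop, X_new, W_phi, t, theta):
    """Embed new patients with the frozen f_phi and rebuild the graph over old + new patients."""
    Z = embed(np.vstack([X_pop, X_new]), W_phi)
    return soft_adjacency(Z, t, theta)

W_phi = np.eye(2)  # stand-in for learned weights
X_pop = np.array([[0.0, 0.0], [0.2, 0.1], [3.0, 3.0], [3.1, 2.9]])
X_new = np.array([[0.1, 0.0]])  # unseen patient, similar to the first two
A = add_patients(X_pop, X_new, W_phi, t=-5.0, theta=-1.0)
```

Row 4 of `A` holds the new patient's edges. It connects strongly to patients 0 and 1 and almost not at all to patients 2 and 3, and the GCN layers can then be run on the extended graph to predict its label.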
Table 3: Comparison with inductive methods on TADPOLE and UKBB [1]
Qualitative Results
To evaluate the graph learned by the model, an approximate reference graph for TADPOLE is generated as a weighted sum of single-feature graphs, following the approach of Multi-GCN [9]. Figure 3 shows that the overall structure of the learned graph is rather similar to that of the reference graph. Domain experts should investigate these results further.
Figure 3: Left: Ground-truth graph. Right: graph learned by the model [1]
Implementation Details
The GCN architecture includes two graph convolutional layers (16 → 8) followed by an MLP (32 → 16 → number of classes) that outputs the classification prediction. ReLU is used as the non-linear activation function, and a dropout rate of 0.9 is used after all layers except the last one. Furthermore, recursive feature elimination is used to reduce the dimensionality of the feature spaces (354 → 30 for TADPOLE and 440 → 200 for UKBB), and the features are normalized. For training, Adam is used as the optimizer with an initial learning rate of 0.01, which is reduced every 100 epochs down to 0.0001. Training stops after 600 epochs. The model is implemented in TensorFlow.
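The reported hyperparameters can be collected into a small configuration, together with one possible reading of the learning-rate schedule. The text only says that the rate starts at 0.01 and is reduced every 100 epochs down to 0.0001. The decay factor of 0.1 per step and the floor at 0.0001 are therefore assumptions, and the `CONFIG` keys are hypothetical names, not the authors' code.

```python
CONFIG = {
    "graph_conv_dims": (16, 8),
    "mlp_dims": (32, 16),          # followed by one output per class
    "activation": "relu",
    "dropout": 0.9,                # value as reported
    "selected_features": {"TADPOLE": 30, "UKBB": 200},
    "optimizer": "adam",
    "epochs": 600,
}

def learning_rate(epoch, initial=0.01, floor=1e-4, factor=0.1, step=100):
    """Assumed step decay: multiply by `factor` every `step` epochs, never below `floor`."""
    return max(initial * factor ** (epoch // step), floor)

schedule = [learning_rate(e) for e in range(CONFIG["epochs"])]
```

Under this assumption, the rate is 0.01 for epochs 0 to 99, 0.001 for epochs 100 to 199, and 0.0001 from epoch 200 until training stops at epoch 600.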
Discussion
The paper proposes a graph learning model that learns the underlying graph structure for the CADx task of patient classification, so the graph adjacency matrix does not have to be pre-defined. In all experiments, the model performed better than the baseline and state-of-the-art methods, and it also generalizes to the inductive setting. The model is trained end-to-end, learning the embedding parameters \phi together with the soft threshold \theta and the temperature t.
The proposed method learns a single global threshold for the entire population, which may not account for a heterogeneous embedding structure. As future work, the authors propose learning a neighborhood threshold for each patient (graph node). Furthermore, promising work suggests that hyperbolic rather than Euclidean space may be the appropriate isometric space for embedding complex networks, so hyperbolic embedding spaces may be considered [14].
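To illustrate what a hyperbolic variant could look like, the sketch below replaces the Euclidean distance in the edge-weight formula with the geodesic distance of the Poincaré ball model, d(\mathbf{u}, \mathbf{v}) = \operatorname{arcosh}\big(1 + 2\lVert \mathbf{u} - \mathbf{v} \rVert^2 / ((1 - \lVert \mathbf{u} \rVert^2)(1 - \lVert \mathbf{v} \rVert^2))\big). This is a hypothetical extension, not something evaluated in the paper. The function names, the `eps` guard and the requirement that embeddings lie strictly inside the unit ball are assumptions.

```python
import numpy as np

def poincare_distance(u, v, eps=1e-9):
    """Geodesic distance in the Poincare ball (all points must have norm < 1)."""
    sq = np.sum((u - v) ** 2, axis=-1)
    denom = (1.0 - np.sum(u ** 2, axis=-1)) * (1.0 - np.sum(v ** 2, axis=-1))
    return np.arccosh(1.0 + 2.0 * sq / np.maximum(denom, eps))

def hyperbolic_edge_weights(Z, t, theta):
    """Hypothetical variant: latent-graph edge weights with the Poincare distance."""
    dist = poincare_distance(Z[:, None, :], Z[None, :, :])
    return 1.0 / (1.0 + np.exp(-t * (dist + theta)))

origin, a, b = np.zeros(2), np.array([0.5, 0.0]), np.array([0.9, 0.0])
d_oa = poincare_distance(origin, a)   # ln 3, about 1.10 (Euclidean distance 0.5)
d_ab = poincare_distance(a, b)        # ln(19/3), about 1.85 (Euclidean distance 0.4)
W = hyperbolic_edge_weights(np.stack([origin, a, b]), t=-5.0, theta=-1.0)
```

Distances grow rapidly towards the boundary of the ball. This is what lets hyperbolic space embed tree-like, hierarchical structures with low distortion.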
Student Review
The paper is very well structured. It clearly motivates the need for models that can learn the underlying structure of the population graph, so that graph adjacencies do not need to be pre-defined. Most concepts are very well explained, and I believe that even readers without a computer science background (e.g. from medicine or neurobiology) can grasp the main ideas. The results could also be useful in other fields, not only in CADx. A basic memory analysis of the method is also given. Finally, the results are considerably better than those of the state-of-the-art methods and baselines, and the proposed method is novel rather than a derivative of previous work.
Adding more visualizations, and especially the code, could make it clearer how exactly the graph learning module and the classification model work, and why the soft thresholding technique was chosen. The experiments could also include state-of-the-art methods from other authors, as the ones reported were mostly proposed previously by the same authors. Furthermore, a runtime comparison with the baselines and other state-of-the-art methods could be given.
As future work, I would be interested to see the results analysed by experts from the medical domain. Apart from learning a neighbourhood threshold for each patient (graph node) and introducing hyperbolic embedding spaces, as proposed by the authors, self-supervised learning, few-shot learning, and directed graphs could also be considered.
Finally, interpretability could also be assessed by measuring the dissimilarity between the graph structures obtained from different training repetitions, over all pairwise combinations of repetitions [15]. A low dissimilarity indicates that the extracted graphs are consistent, and therefore more likely to be reliable and meaningful. Concretely, a distance function is applied to the adjacency matrices of each pair of graphs, and the distances over all pairs of repetitions are averaged and divided by the total number of possible edges. Such a measure could also give an indication of robustness to graph perturbations and adversarial attacks.
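The procedure described above can be sketched as follows. The distance function is not specified, so an L1 distance over the upper triangle of the adjacency matrices is assumed here. This treats the graphs as undirected and normalizes by the N(N-1)/2 possible edges. The function name `mean_pairwise_dissimilarity` is hypothetical.

```python
from itertools import combinations

import numpy as np

def mean_pairwise_dissimilarity(adjacencies):
    """Average normalised L1 distance between the graphs of all pairs of repetitions.

    Assumption: graphs are undirected, so only the upper triangle is compared,
    and each distance is divided by the N(N-1)/2 possible edges.
    """
    n = adjacencies[0].shape[0]
    possible_edges = n * (n - 1) / 2
    dists = [np.abs(np.triu(a - b, k=1)).sum() / possible_edges
             for a, b in combinations(adjacencies, 2)]
    return float(np.mean(dists))

A1 = np.array([[0.0, 1.0, 0.0],
               [1.0, 0.0, 0.0],
               [0.0, 0.0, 0.0]])
A0 = np.zeros((3, 3))
```

For example, three identical repetitions give 0.0. The repetitions `[A1, A0, A0]` give 2/9: two of the three pairs differ in one of three possible edges, and one pair is identical. The measure also works unchanged for the soft, real-valued adjacencies produced by the latent-graph module.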
References
[1] Luca Cosmo, Anees Kazi, Seyed-Ahmad Ahmadi, Nassir Navab, Michael Bronstein. Latent-Graph Learning for Disease Prediction. In International Conference on Medical Image Computing and Computer-Assisted Intervention, 2020
[2] J. Kawahara, C. J. Brown, S. P. Miller, Brian G. Booth, V. Chau, R. Grunau, J. Zwicker, G. Hamarneh. Convolutional neural networks for brain networks; towards predicting neurodevelopment. In NeuroImage, 2017
[3] Sarah Parisot, Sofia Ira Ktena, Enzo Ferrante, Matthew Lee, Ricardo Guerrero, Ben Glocker, Daniel Rueckert. Disease prediction using graph convolutional networks: Application to Autism Spectrum Disorder and Alzheimer's disease. In Medical Image Analysis, 2018
[4] P. Gainza, F. Sverrisson, F. Monti, E. Rodolà, D. Boscaini, M.M. Bronstein, B.E. Correia. Deciphering interaction fingerprints from protein molecular surfaces using geometric deep learning. In Nat. Methods, 2019
[5] Sofia Ira Ktena, Sarah Parisot, Enzo Ferrante, Martin Rajchl, Matthew Lee, Ben Glocker, Daniel Rueckert. Metric learning with spectral graph convolutions on brain connectivity networks. In NeuroImage, 2018
[6] Hendrik Burwinkel, Anees Kazi, Gerome Vivar, Shadi Albarqouni, Guillaume Zahnd, Nassir Navab, Seyed-Ahmad Ahmadi. Adaptive Image-Feature Learning for Disease Classification Using Inductive Graph Networks. In MICCAI, 2019
[7] Sarah Parisot, Sofia Ira Ktena, Enzo Ferrante, Matthew Lee, Ricardo Guerrero Moreno, Ben Glocker, Daniel Rueckert. Spectral Graph Convolutions for Population-based Disease Prediction. In MICCAI, 2017
[8] Anees Kazi, Luca Cosmo, Nassir Navab, Michael Bronstein. Differentiable graph module (DGM) for graph convolutional networks. arXiv preprint arXiv:2002.04999 (2020)
[9] Anees Kazi, S. Arvind Krishna, Shayan Shekarforoush, Karsten Kortuem, Shadi Albarqouni, Nassir Navab. Self-attention equipped graph convolutions for disease prediction. In 2019 IEEE 16th International Symposium on Biomedical Imaging (ISBI 2019), 2019
[10] Razvan V. Marinescu et al. TADPOLE Challenge: Prediction of Longitudinal Evolution in Alzheimer's Disease. arXiv preprint arXiv:1805.03909 (2018)
[11] K.L. Miller et al. Multimodal population brain imaging in the UK biobank prospective epidemiological study. In Nat. Neurosci 19(11), 1523 (2016)
[12] Damitha Senevirathne, Isuru Wijesiri, Suchitha Dehigaspitiya, Miyuru Dayarathna, Sanath Jayasena, Toyotaro Suzumura. Memory Efficient Graph Convolutional Network based Distributed Link Prediction. In 2020 IEEE International Conference on Big Data (Big Data), 2020
[13] Anees Kazi et al. InceptionGCN: receptive field aware graph convolutional network for disease prediction. In IPMI 2019. LNCS, vol. 11492, 2019
[14] Benjamin Paul Chamberlain, James Clough, Marc Peter Deisenroth. Neural Embeddings of Graphs in Hyperbolic Space. In 13th international workshop on mining and learning from graphs held in conjunction with KDD, 2017
[15] Thomas N. Kipf, Max Welling. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907 (2016)
[16] Kun Zhan et al. Adaptive structure discovery for multimedia analysis using multiple features. In IEEE Trans. Cybern. 49(5), 1826–1834 (2019)
[17] Luca Franceschi, Mathias Niepert, Massimiliano Pontil, Xiao He. Learning discrete structures for graph neural networks. In Proceedings of International Conference Machine Learning (ICML), 2019
[18] Soobeom Jang, Seong-Eun Moon, Jong-Seok Lee. Brain signal classification via learning connectivity structure. arXiv preprint arXiv:1905.11678 (2019)