Introduction
Generalization lies at the heart of all research in geometric deep learning. After all, the whole field stems from the goal of generalizing Convolutional Neural Networks, which natively operate on regular grids, to non-Euclidean domains such as graphs or manifolds [1]. But with the development of the first Graph Neural Networks (GNNs), this process of generalization was not yet finished: most methods left one important aspect out of consideration, namely time. One need only look at common real-life data sources to see that time greatly influences many of them. From financial use cases to social networks to patient graphs, time often plays a major role in understanding the data, and explicitly modeling these effects can only be of benefit [7]. To this end, Emanuele Rossi, Ben Chamberlain, Fabrizio Frasca, Davide Eynard, Federico Monti, and Michael Bronstein of the graph learning research group at Twitter introduced Temporal Graph Networks (TGNs) in their work "Temporal Graph Networks For Deep Learning on Dynamic Graphs" [2].
TGNs are a generic, inductive framework for graph deep learning on continuous-time dynamic graphs that generalizes many previous methods on both static and dynamic graphs. They employ a notion of memory that lets the model retain long-term information and generate up-to-date node embeddings regardless of the age of that information. At the same time, TGNs are computationally more efficient than comparable methods while still achieving state-of-the-art results.
Dynamic Graphs
Most common GNN models operate on static graphs, the simple form of a graph we are all familiar with. A graph \mathcal{G} = (\mathcal{V}, \mathcal{E}) consists of nodes \mathcal{V} = \{1,\dots, n\} and edges \mathcal{E} \subseteq \mathcal{V} \times \mathcal{V}. As the name 'static' implies, this graph stays constant during training and inference (except for nodes and corresponding edges being added at inference time in the inductive setting).
In contrast to static graphs, dynamic graphs can change over time. They come in two main flavors: discrete-time dynamic graphs (DTDGs) and continuous-time dynamic graphs (CTDGs). While DTDGs are simply sequences of static graph snapshots taken at fixed time intervals, CTDGs are a bit more complex. In the following, a CTDG is defined not by its edges and vertices, but by a sequence of events ordered by their timestamps. The temporal graph is therefore \mathcal{G} = \{x(t_1), x(t_2), \dots\}, where x(t) is an event of one of the two following types:
- Node-wise events: Characterized by \mathbf{v}_i(t), the feature vector of node i at time t. If node i did not previously exist, this event adds it to the graph; if it already exists, the event updates its representation.
- Interaction events: Represent a directed edge between two nodes i and j at time t, characterized by the edge's feature vector \mathbf{e}_{ij}(t).
As an example consider the sequence of events \{\mathbf{v}_1(0), \mathbf{v}_2(0.5), \mathbf{e}_{12}(0.9), \mathbf{v}_3(1.4), \mathbf{v}_2(1.9), \mathbf{e}_{12}(2.5), \mathbf{e}_{23}(2.5)\}. This sequence defines the time-dependent graph shown in Fig. 1.
Fig. 1: Example of a dynamic graph.
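To make the event-based view concrete, here is a minimal sketch of how such an event sequence could be represented in code. The class names and fields are my own illustration, not prescribed by the paper; feature vectors are left empty for brevity.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class NodeEvent:
    node: int        # node id i
    t: float         # timestamp
    features: list   # feature vector v_i(t)


@dataclass
class InteractionEvent:
    src: int         # source node i
    dst: int         # destination node j
    t: float         # timestamp
    features: list   # edge feature vector e_ij(t)


# The example sequence from Fig. 1.
events: List[object] = [
    NodeEvent(1, 0.0, []),
    NodeEvent(2, 0.5, []),
    InteractionEvent(1, 2, 0.9, []),
    NodeEvent(3, 1.4, []),
    NodeEvent(2, 1.9, []),
    InteractionEvent(1, 2, 2.5, []),
    InteractionEvent(2, 3, 2.5, []),
]
```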
Temporal Graph Networks
For dealing with dynamic graphs, Kazemi et al. introduce an encoder-decoder framework [3]. They propose that all models can be divided into an encoder, which takes the dynamic graph as input and outputs hidden representations of nodes or edges, and a decoder, which maps these embeddings to predictions, e.g. edge predictions. Following this framework, TGN only covers the encoder part, i.e. it produces node embeddings based on the graph. Any decoder, e.g. a simple MLP, can be plugged on top to obtain a functional model.
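As a toy illustration of this split, a decoder for edge prediction could be as simple as an MLP over a pair of node embeddings. The layer sizes below are illustrative and not taken from the paper.

```python
import torch
import torch.nn as nn


class EdgeDecoder(nn.Module):
    """Maps a pair of node embeddings to an edge probability."""

    def __init__(self, emb_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, z_src: torch.Tensor, z_dst: torch.Tensor) -> torch.Tensor:
        # Concatenate the two embeddings and predict the probability
        # that an edge between the two nodes exists at this time.
        return torch.sigmoid(self.mlp(torch.cat([z_src, z_dst], dim=-1)))
```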
A general overview of the model is shown in Fig. 2. In each epoch, training goes through all events in the training set in chronological order. Events can be collected into batches as long as all events in one batch happen before all events in its successor. The model processes these batches through five distinct modules. Given a batch of events, the Message Function computes a message for every node involved in each event. If multiple messages are sent to the same node, they are aggregated into a single message by the Message Aggregator. Based on the aggregated messages of the batch, the Memory of the involved nodes, which holds their long-term information, is updated by the Memory Updater module. Finally, the Embedding module generates the temporal embeddings of the nodes, which can then be passed to the decoder for any downstream task. All modules are explained in more detail in the following sections.
Fig.2: General overview of the flow of information in the TGN model architecture.
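In code, the processing of one batch could be sketched roughly as the following function. The five modules are assumed to be callables defined elsewhere; this is a simplification of the flow in Fig. 2, and the later section on training refines the order of these steps.

```python
def process_batch(batch, msg_fn, msg_agg, mem_updater, embedding, memory):
    """High-level sketch of how one batch of events flows through the modules.

    `memory` maps node id -> memory vector s_i; the module arguments are
    placeholders for the components described in the following sections.
    """
    # 1. Compute a message for every node involved in an event of the batch.
    messages = msg_fn(batch, memory)      # {node_id: [m_i(t_1), m_i(t_2), ...]}

    # 2. Aggregate multiple messages to the same node into a single one.
    aggregated = msg_agg(messages)        # {node_id: m_bar_i(t)}

    # 3. Update the memory of the involved nodes.
    for node_id, msg in aggregated.items():
        memory[node_id] = mem_updater(msg, memory[node_id])

    # 4. Compute temporal embeddings from memory and each node's neighborhood.
    z = embedding(batch, memory)          # {node_id: z_i(t)}
    return z
```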
Modules
Memory
At its core, the memory of node i holds all information from previous interactions and events up to time t, compressed into a single vector \mathbf{s}_i(t), very similar to the cell state in common Recurrent Neural Networks such as the LSTM [4]. When a new node is created, its memory is initialized with a zero vector; after that, it is updated with every event involving the node. It is important to note that, in contrast to the other modules, the memory is not a learnable parameter itself and is therefore also updated at test time, allowing the model to handle unseen nodes in an inductive setting.
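A minimal sketch of such a memory store, zero-initialized and updated outside of gradient-based learning, might look like this. The class and its methods are my own illustration, not the authors' implementation.

```python
import torch


class Memory:
    """Holds one state vector s_i per node, initialized to zeros."""

    def __init__(self, mem_dim: int):
        self.mem_dim = mem_dim
        self.state = {}        # node id -> memory vector s_i
        self.last_update = {}  # node id -> timestamp of last update

    def get(self, node_id: int) -> torch.Tensor:
        # Unseen nodes start with a zero memory, which is what makes the
        # model applicable to new nodes in the inductive setting.
        if node_id not in self.state:
            self.state[node_id] = torch.zeros(self.mem_dim)
            self.last_update[node_id] = 0.0
        return self.state[node_id]

    def set(self, node_id: int, new_state: torch.Tensor, t: float):
        self.state[node_id] = new_state
        self.last_update[node_id] = t
```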
Message Function
The message function computes a message for each event inside the batch. In case of a node-wise event \mathbf{v}_i(t) at time t involving node i, the message \mathbf{m}_i(t) = \text{msg}_n(\mathbf{s}_i(t^-), t, \mathbf{v}_i(t)) is generated. Analogously, for an interaction event \mathbf{e}_{ij}(t) between source node i and target node j, the messages \mathbf{m}_i(t) = \text{msg}_s(\mathbf{s}_i(t^-), \mathbf{s}_j(t^-), \Delta t(i), \mathbf{e}_{ij}(t)) and \mathbf{m}_j(t) = \text{msg}_t(\mathbf{s}_j(t^-), \mathbf{s}_i(t^-), \Delta t(j), \mathbf{e}_{ij}(t)) are computed. Here, \mathbf{s}_i(t^-) denotes the memory of node i at the time of its last interaction, \Delta t(i) is the time since the last event involving node i, and \text{msg}_n, \text{msg}_s, and \text{msg}_t can be learnable functions, e.g. MLPs. For simplicity, however, the authors use the identity function, i.e. the concatenation of the inputs, in their experiments.
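For the identity variant used in the experiments, the message function reduces to a concatenation of its inputs. A sketch for the interaction-event case follows; tensor shapes and function names are illustrative.

```python
import torch


def msg_source(s_i, s_j, delta_t, e_ij):
    """Identity message function for the source node of an interaction:
    concatenates the two memory states, the time delta, and the edge features."""
    return torch.cat([s_i, s_j, torch.as_tensor([float(delta_t)]), e_ij], dim=-1)


def msg_target(s_j, s_i, delta_t, e_ij):
    """Same for the target node, with the roles of i and j swapped."""
    return torch.cat([s_j, s_i, torch.as_tensor([float(delta_t)]), e_ij], dim=-1)
```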
Message Aggregator
As the events are processed in batches, a single node i may be involved in more than one event. In this case, all messages to node i are aggregated into a single message \bar{\mathbf{m}}_i(t) = \text{agg}(\mathbf{m}_i(t_1),\dots, \mathbf{m}_i(t_b)) for t_1, \dots, t_b \leq t. Once again, the aggregation function could be learnable; for the same reasons as before, however, the authors choose a function that simply keeps the most recent message.
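The "most recent message" aggregator is straightforward to sketch. The function below assumes the messages of the batch arrive tagged with their timestamps; the data layout is my own choice.

```python
def aggregate_last(messages):
    """Keep only the most recent message per node.

    `messages` maps node id -> list of (timestamp, message) pairs
    collected from the current batch.
    """
    return {
        node_id: max(msgs, key=lambda pair: pair[0])[1]
        for node_id, msgs in messages.items()
    }
```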
Memory Updater
With the aggregated message of the batch and the current memory state, the memory of node i can now be updated to represent the new state after the events of the batch: \mathbf{s}_i(t) = \text{mem}(\bar{\mathbf{m}}_i(t), \mathbf{s}_i(t^-)). Again, \text{mem} can be a learnable memory update function such as an LSTM or GRU [5]; this time, the authors indeed choose the latter.
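With a GRU as the update function, the memory update amounts to a single GRU-cell step with the aggregated message as input and the previous memory as hidden state. The dimensions below are illustrative; this is a sketch, not the authors' implementation.

```python
import torch
import torch.nn as nn


class GRUMemoryUpdater(nn.Module):
    def __init__(self, msg_dim: int, mem_dim: int):
        super().__init__()
        self.gru = nn.GRUCell(input_size=msg_dim, hidden_size=mem_dim)

    def forward(self, m_bar: torch.Tensor, s_prev: torch.Tensor) -> torch.Tensor:
        # The aggregated message acts as the input, the old memory as the
        # hidden state; the output is the updated memory s_i(t).
        return self.gru(m_bar.unsqueeze(0), s_prev.unsqueeze(0)).squeeze(0)
```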
Embedding
The final module generates the temporal embedding \mathbf{z}_i(t) of node i, which can then be used for prediction purposes. One could simply use the memory state directly as the embedding; after all, it holds all information about the node's history. However, a node may not participate in any events for a long time, and as the context of that node moves forward in time, its representation becomes outdated. This effect is called staleness. To remedy this problem, the embedding module learns how the representation evolves with time even when no observation of the node is made. It does so by considering the node's neighbors, i.e. its previous interaction partners. The authors suggest using L graph attention layers to aggregate the L-hop neighborhood information. Explaining this layer in detail is beyond the scope of this blog post; the interested reader is referred to [6] and [7] for an in-depth explanation.
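To give a rough feel for the idea, here is a much-simplified, single-head, single-layer sketch of attention over a node's temporal neighbors. It replaces the learned time encoding of [7] with a raw scalar time delta and omits multi-head attention, node features, and stacking of L layers, so it should be read as an illustration of the principle rather than the paper's embedding module.

```python
import torch
import torch.nn as nn


class SimpleTemporalAttention(nn.Module):
    """Simplified single-head attention over a node's temporal neighbors."""

    def __init__(self, dim: int):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim + 1, dim)  # +1 for the (naive) scalar time delta
        self.v = nn.Linear(dim + 1, dim)

    def forward(self, s_i, neighbor_states, delta_ts):
        # neighbor_states: (N, dim) memories of previous interaction partners
        # delta_ts: (N, 1) time since each interaction, a crude stand-in for
        # the learned time encoding used in the paper.
        query = self.q(s_i).unsqueeze(0)                        # (1, dim)
        kv_in = torch.cat([neighbor_states, delta_ts], dim=-1)  # (N, dim+1)
        keys, values = self.k(kv_in), self.v(kv_in)
        attn = torch.softmax(query @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return (attn @ values).squeeze(0)                       # z_i(t)
```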
Training
There is one major problem with the methodology as presented so far that renders the training of the whole model useless. If we use the interactions inside a batch to update the memory and then predict interactions based on this memory, we are predicting interactions that have already been shown to the model, resulting in information leakage. However, if we first predict the interactions and update the memory afterward, all memory-related modules, i.e. the lower path in Fig. 2, do not contribute to the loss in any way, receive no gradient, and are therefore not trained. This problem is solved by always updating the memory with the messages of the previous batch. All messages of a batch are saved in an additional module called the Raw Message Store, to be used in the next iteration to update the memory and produce a prediction. After the prediction is made, the messages of the current batch are saved in the Raw Message Store and the process begins anew. This improved information flow is illustrated in Fig. 3.
Fig.3: Improved flow of information in the TGN model. The numbers show the order of operations.
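The resulting order of operations per batch can be sketched as follows. The model, decoder, and raw message store interfaces (`compute_messages`, `aggregate`, `update_memory`, `embed`, `get`, `store`, and the batch attributes) are placeholders of my own; the point is the ordering of steps from Fig. 3, not a faithful implementation.

```python
def train_step(batch, raw_message_store, memory, model, decoder,
               optimizer, loss_fn):
    """One training iteration with the corrected information flow of Fig. 3."""
    # Update the memory from the *previous* batch's raw messages.
    messages = model.compute_messages(raw_message_store.get(batch.nodes), memory)
    aggregated = model.aggregate(messages)
    model.update_memory(aggregated, memory)

    # Compute embeddings and predict the *current* batch's interactions.
    z = model.embed(batch, memory)
    loss = loss_fn(decoder(z[batch.sources], z[batch.destinations]), batch.labels)

    # Replace the stored raw messages with those of the current batch,
    # to be used in the next iteration.
    raw_message_store.store(batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```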
Results & Discussion
The authors compare their method with three other state-of-the-art models for dynamic graph learning. Jodie [8] is a specific instance of TGN; the main difference is that it does not use attention to compute an embedding but employs the time projection \mathbf{z}_i(t) = (1 + \Delta t \, \mathbf{w}) \circ \mathbf{s}_i(t), where \mathbf{w} is a learnable vector. TGAT [7] is also a special case of TGN that does not use any memory-related modules but applies a very similar attention mechanism for the embedding. Finally, DyRep [9] also holds node memory, but in contrast to TGN, it additionally feeds an attention-based neighborhood representation into the message function and simply uses the memory state as the embedding where TGN employs neighborhood attention.
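Jodie's time projection is simple enough to state in one line of code; reading \circ as element-wise multiplication, a sketch could look like this (the function name is my own).

```python
import torch


def jodie_time_projection(s_i: torch.Tensor, delta_t: float,
                          w: torch.Tensor) -> torch.Tensor:
    """z_i(t) = (1 + delta_t * w) ∘ s_i(t): scales the memory by a learned,
    time-dependent factor instead of attending over the neighborhood."""
    return (1 + delta_t * w) * s_i
```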
Fig. 4: Comparison against baselines and ablation study.
Fig. 4 shows the precision of the methods for the task of edge prediction on the Wikipedia dataset [10], plotted against the time per epoch. The proposed model TGN-attn significantly outperforms all baselines. Additionally, this figure demonstrates the benefit of every TGN module. All models that use memory perform better than TGN-nomem, showing the importance of retaining long-term information. The memory staleness problem is also visible in this plot, as TGN-id, which simply uses the memory as the embedding, does not reach results comparable to the attention-based embeddings. As an interesting finding, while TGAT's performance drops by more than 10% when the number of attention layers is reduced from the usual two to one, the same change makes almost no difference between TGN-attn and TGN-2l. This stems from the fact that a node's neighborhood information is already introduced into its memory via messages; a single attention hop in TGN-attn therefore already accesses information from a broader neighborhood. This allows for far faster training than TGAT without a great loss in performance.
Conclusion
This work introduced TGNs as a general framework for working with continuous-time dynamic graphs, which are prevalent representations of real-world data but have not received much attention in research so far. Long-term information is captured in the form of node-wise memory, and the staleness problem is solved with temporal graph attention. The method achieves state-of-the-art results in both performance and efficiency, and the benefit of each building block is demonstrated in a comprehensive ablation study.
Review
There were some aspects I found missing in the paper. First, the experimental setup is not explained in the paper itself but only described as closely following a referenced work. I would have liked to see it in this paper as well, because some aspects, e.g. how training, validation, and test sets are split in a temporal graph dataset or how a fair comparison with static graph methods is done, are in my opinion essential for understanding the method and not immediately obvious to a reader new to the field. Second, when looking at the results across the different datasets, the performance improvement over previous methods always lies in the lower single digits, except for the Twitter dataset, where the boost is up to 13%. I would have liked the authors to acknowledge this and to explain why the method excels on the Twitter dataset.
Despite these minor points of critique, I believe this paper lays important groundwork in the young field of dynamic graph learning, which is itself part of the still-young field of graph learning in general. Instead of simply publishing a model that shows impressive performance (which it undoubtedly does), the authors structured their method so as to establish a common ground between existing approaches and allow further research to build on their framework, an important goal in a young field. The text is well structured: it first explains all modules in the simplified model, which is easier to understand, and only at a later stage presents the final model, which is a bit more complex but does not suffer from missing gradients or information leakage. The methodology itself is very interesting, and all building blocks are well motivated and evaluated, leaving the reader with good insight into why they are needed and why they work.
References
[1] Bronstein, M. M., Bruna, J., LeCun, Y., Szlam, A. and Vandergheynst, P., 2017. Geometric deep learning: going beyond euclidean data. IEEE Signal Processing Magazine, 34(4), pp.18-42.
[2] Rossi, E., Chamberlain, B., Frasca, F., Eynard, D., Monti, F. and Bronstein, M., 2020. Temporal graph networks for deep learning on dynamic graphs. arXiv preprint arXiv:2006.10637.
[3] Kazemi, S.M., Goel, R., Jain, K., Kobyzev, I., Sethi, A., Forsyth, P. and Poupart, P., 2020. Representation Learning for Dynamic Graphs: A Survey. Journal of Machine Learning Research, 21(70), pp.1-73.
[4] Hochreiter, S. and Schmidhuber, J., 1997. Long short-term memory. Neural computation, 9(8), pp.1735-1780.
[5] Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H. and Bengio, Y., 2014. Learning phrase representations using RNN encoder-decoder for statistical machine translation. EMNLP, pp.1724–1734.
[6] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L. and Polosukhin, I., 2017. Attention is all you need. NIPS, pp.5998–6008.
[7] Xu, D., Ruan, C., Korpeoglu, E., Kumar, S. and Achan, K., 2020. Inductive representation learning on temporal graphs. International Conference on Learning Representations.
[8] Kumar, S., Zhang, X. and Leskovec, J., 2019, July. Predicting dynamic embedding trajectory in temporal interaction networks. Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp.1269-1278.
[9] Trivedi, R., Farajtabar, M., Biswal, P. and Zha, H., 2019, May. Dyrep: Learning representations over dynamic graphs. International Conference on Learning Representations.
[10] Wikipedia edit history dump. https://meta.wikimedia.org/wiki/Data_dumps.