Graph neural networks at scale

Introduction

GNNs are unique because they leverage connectivity in addition to the typical \((X, y)\) pairs used for supervised learning. A GNN uses a graph \(g = (X,A,y)\) to perform a task (e.g. node classification), where \(X \in \mathbf{R}^{n\times d}\) are node features, \(A \in \mathbf{R}^{n \times n}\) is the adjacency matrix, and \(y \in \mathbf{R}^{n \times k}\) are labels. Often, the adjacency matrix is defined as \(A_{v,u} = \begin{cases}1 & u \xrightarrow{} v\\ 0 & \hbox{otherwise} \end{cases}\). However, it is not rare to see an adjacency matrix whose entries are float scalars or vectors.
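As a small illustration of the convention above, the following sketch builds \(A\) from a directed edge list. The toy edge list and the helper name `adjacency_from_edges` are made up for this example.

```python
import numpy as np

def adjacency_from_edges(n, edges):
    """Dense adjacency with A[v, u] = 1 when there is an edge u -> v."""
    A = np.zeros((n, n))
    for u, v in edges:
        A[v, u] = 1.0  # row = destination node, column = source node
    return A

# Hypothetical toy graph on 3 nodes: 0 -> 1, 0 -> 2, 2 -> 1
A = adjacency_from_edges(3, [(0, 1), (0, 2), (2, 1)])
in_degree = A.sum(axis=1)   # row sums count incoming edges
out_degree = A.sum(axis=0)  # column sums count outgoing edges
```

With this orientation, row \(v\) of \(A\) selects the nodes that send messages to \(v\), which is what the aggregation steps later in these notes rely on.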

Graph convolutional networks (Kipf and Welling, 2017)

At every layer of the GNN, we transform the node-level information into a hidden representation. We stack several such layers and then map the final hidden representation to the output required by the task.

\[H^l = \sigma(S H^{(l - 1)} W_l)\]

where \(S H^{(l - 1)}\) is the aggregation (in this case, an average, since \(S\) is a symmetrically normalized adjacency matrix) of the current representations of neighbor nodes.
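A minimal NumPy sketch of this layer is shown below. It assumes the normalization \(S = D^{-1/2}(A + I)D^{-1/2}\) with self-loops (as in Kipf and Welling), ReLU for \(\sigma\), and random weights; the toy graph and the function names are hypothetical.

```python
import numpy as np

def normalized_adjacency(A):
    """S = D^{-1/2} (A + I) D^{-1/2}, with D the degree matrix of A + I."""
    A_tilde = A + np.eye(A.shape[0])          # add self-loops (assumption)
    deg = A_tilde.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    return A_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(S, H, W):
    """One layer: aggregate neighbors with S, transform with W, apply ReLU."""
    return np.maximum(S @ H @ W, 0.0)

rng = np.random.default_rng(0)
A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)  # undirected toy graph
X = rng.normal(size=(3, 4))      # n = 3 nodes, d = 4 features
W1 = rng.normal(size=(4, 8))     # untrained weights, for shape checking only
W2 = rng.normal(size=(8, 2))

S = normalized_adjacency(A)
H1 = gcn_layer(S, X, W1)
H2 = gcn_layer(S, H1, W2)        # final hidden representation, shape (3, 2)
```

Note that both layers multiply by the full \(S\), so every node's representation is computed at every layer. This is the cost that the sampling methods later in these notes try to avoid.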

This is the typical process for all GNNs. The variants change what message to send, how to aggregate messages, and how to transform the aggregated message. These changes are usually motivated by additional knowledge about the process generating the graph.

This shared structure is formalized in:

Message Passing Neural Networks (Gilmer et al 2017)

For each incoming edge, generate a message to send based on the current representation of the endpoints and edge (if any)

\[m_{u\xrightarrow{} v}^{(l+1)} \overset{\Delta}{=} \phi(H_u^l, H_v^l, H_{u\xrightarrow{}v}^l)\]

Aggregate all incoming messages at a node to reduce to a single vector

\[\bar{m}_v^{(l+1)} \overset{\Delta}{=} \rho(\{m_{u\xrightarrow{}v}^{(l+1)} : u \in \mathcal{N}(v) \})\]

Apply a transformation to the aggregated messages to update the node’s representation

\[H_v^{(l+1)} \overset{\Delta}{=} \psi(H_v^l, \bar{m}_v^{(l+1)})\]
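The three steps can be sketched as one generic function that takes \(\phi\), \(\rho\), and \(\psi\) as arguments. This is only an illustration: edge features are left out, and the particular choices of `phi`, `rho`, and `psi` below are hypothetical.

```python
import numpy as np

def mpnn_layer(H, edges, phi, rho, psi):
    """One message passing step over directed edges (u, v), meaning u -> v."""
    inbox = {v: [] for v in range(len(H))}
    for u, v in edges:
        inbox[v].append(phi(H[u], H[v]))       # message m_{u -> v}
    H_new = []
    for v in range(len(H)):
        m_bar = rho(inbox[v], H[v].shape)      # aggregated message for v
        H_new.append(psi(H[v], m_bar))         # updated representation of v
    return np.stack(H_new)

# Hypothetical choices: send the source state, average, then add to own state.
phi = lambda h_u, h_v: h_u
rho = lambda msgs, shape: np.mean(msgs, axis=0) if msgs else np.zeros(shape)
psi = lambda h_v, m_bar: h_v + m_bar

H = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
edges = [(0, 1), (2, 1), (1, 0)]
H_next = mpnn_layer(H, edges, phi, rho, psi)
```

Node 2 has no incoming edges here, so its aggregated message is the zero vector and its representation is unchanged.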

GNN Scaling issues

This section will focus on scaling GNNs to large data sets. In general, GNNs are not very large models (in terms of # of parameters), especially compared to recent language models. The key bottleneck is not so much the size of the model but fitting the data into GPU memory. GPU memory pressure comes from the recursive structure of the message passing computation that happens at each layer.

What needs to be on the GPU when training a GNN?

  • Model parameters (\(W^1\), \(W^2\))

  • Input data (e.g., X, y)

  • Intermediate outputs (e.g. \(H^1\), \(H^2\))

  • Gradients

Typically, mini-batch training draws a small number of examples from the training dataset at every iteration before computing a model update. That reduces the amount of input data that has to be on the GPU at once.

Overview

Minibatches don’t give us an easy win for graphs. In order to compute the output for a given node, we have to collect the nodes in its neighborhood (e.g. its k-hop neighbors). In a dense graph, this can be a large portion of the whole graph. The “small world” property says that the distance between any two nodes in a graph grows only logarithmically with the number of nodes (e.g. if we have a graph with 10x the number of nodes, we only expect the distance between two nodes to increase by 1 hop). The implication is that the K-hop receptive field for a minibatch of nodes can increase the memory requirements on the GPU by several orders of magnitude (beyond what we’d need just for the target nodes). So in practice, because samples depend on each other, we don’t get the same benefit from minibatches in GNNs. This is called the receptive field problem.
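To get a feel for how quickly the receptive field grows, the sketch below counts the nodes needed for a minibatch of 32 targets on a synthetic random graph. The graph generator, its size, and its average degree are arbitrary assumptions, not measurements from a real dataset.

```python
import random

def random_graph(n, avg_degree, seed=0):
    """Undirected random graph as an adjacency dict (hypothetical toy generator)."""
    rng = random.Random(seed)
    nbrs = {v: set() for v in range(n)}
    for _ in range(n * avg_degree // 2):
        u, v = rng.randrange(n), rng.randrange(n)
        if u != v:
            nbrs[u].add(v)
            nbrs[v].add(u)
    return nbrs

def receptive_field(nbrs, targets, num_hops):
    """Nodes needed to compute num_hops-layer outputs for the target nodes."""
    field = set(targets)
    frontier = set(targets)
    sizes = [len(field)]
    for _ in range(num_hops):
        # Expand one hop and keep only nodes we have not seen yet.
        frontier = {u for v in frontier for u in nbrs[v]} - field
        field |= frontier
        sizes.append(len(field))
    return field, sizes

nbrs = random_graph(n=10_000, avg_degree=10)
_, sizes = receptive_field(nbrs, targets=range(32), num_hops=3)
print(sizes)  # receptive field size after 0, 1, 2, 3 hops
```

On a graph like this, a few hops are enough for the receptive field of a small minibatch to cover a large share of all nodes.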

Another way to get around GPU memory issues is distributed training across GPUs. Data parallelism partitions the graph and passes each partition to a separate GPU. The same model architecture is trained on each partition, and every once in a while the replicas are synced by, for instance, averaging the model parameters across the different model instances. This is not easy to implement, because if we partition the nodes, there will be edges between nodes on different partitions. So, what do we do if we have to collect a message from a neighbor that lives on another partition? We can query data from other GPUs, but that’s expensive. As a note, if the graph is small enough to fit in host CPU memory for each GPU, then data parallelism is a great way to speed up GNN training.

Another way is model and pipeline parallelism. This is often used when a model is really large (for instance, large language models).

An alternative is to take a subgraph of the graph in each batch and use that for training.

Solutions

Message Flow Graphs

GraphSage (Hamilton et al 2017)

Starting with the last layer, sample k neighbor nodes from the previous layer and use them to compute the representation of the target node. Repeat for each earlier layer, using the sampled nodes as the new targets.
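A rough sketch of this layer-by-layer sampling is given below. It is not DGL's implementation; the function name `sample_blocks`, the `fanouts` argument, and the toy graph are assumptions made for illustration.

```python
import random

def sample_blocks(nbrs, targets, fanouts, seed=0):
    """Sample message flow graphs from the last layer back to the input layer.

    nbrs[v] lists the in-neighbors of v; fanouts[i] is the number of
    neighbors sampled per node for layer i (first entry = input layer).
    Returns one edge list per layer, ordered from input layer to output layer.
    """
    rng = random.Random(seed)
    blocks = []
    dst_nodes = set(targets)
    for fanout in reversed(fanouts):
        edges = []
        for v in sorted(dst_nodes):
            k = min(fanout, len(nbrs[v]))
            for u in rng.sample(nbrs[v], k):   # sample without replacement
                edges.append((u, v))
        blocks.insert(0, edges)
        # The earlier layer must produce outputs for every sampled source
        # node, in addition to the destination nodes themselves.
        dst_nodes = dst_nodes | {u for u, _ in edges}
    return blocks

nbrs = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2]}
blocks = sample_blocks(nbrs, targets=[0], fanouts=[2, 2])
```

Each entry of `blocks` contains only the edges needed to compute the next layer's outputs, so the fan-out per node is capped even though the number of nodes can still grow from layer to layer.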

Subgraphs

ClusterGCN (Chiang et al. 2019)

Partition the nodes into K groups to find “dense subgraphs” (the paper and DGL use METIS). In each batch, ClusterGCN samples k of the K groups and induces a subgraph from their nodes. The intuition is that if we only pass messages between nodes within one partition, we may still get a pretty good approximation of the true output because the partition is a dense subset of the graph. But occasionally we want to reach across partitions and include information from other partitions, which is why several groups are combined. So it controls the number of nodes on the GPU but also gives us a sample that is representative of the nodes in the full graph.
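The batching step can be sketched as follows, assuming the partition has already been computed. The hard-coded `clusters` list stands in for the output of a partitioner such as METIS, and the helper names are hypothetical.

```python
import random

def induced_subgraph(edges, nodes):
    """Keep only the edges whose two endpoints are both in `nodes`."""
    nodes = set(nodes)
    return [(u, v) for u, v in edges if u in nodes and v in nodes]

def cluster_batches(clusters, edges, clusters_per_batch, seed=0):
    """Yield (nodes, edges) minibatches built from unions of random clusters."""
    rng = random.Random(seed)
    order = list(range(len(clusters)))
    rng.shuffle(order)
    for i in range(0, len(order), clusters_per_batch):
        chosen = order[i:i + clusters_per_batch]
        nodes = sorted(n for c in chosen for n in clusters[c])
        yield nodes, induced_subgraph(edges, nodes)

# Hypothetical precomputed partition; in practice this would come from METIS.
clusters = [[0, 1, 2], [3, 4], [5, 6, 7], [8, 9]]
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9)]
batches = list(cluster_batches(clusters, edges, clusters_per_batch=2))
```

Because the subgraph is induced on the union of the chosen clusters, edges that cross between two clusters in the same batch are kept, while edges to clusters outside the batch are dropped.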

GraphSAINT: Node Sampler (Zeng et al 2020)

Sample nodes with probability proportional to their out-degree and induce a subgraph from the sampled nodes.
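A minimal sketch of this sampler, following the description above rather than the reference implementation, could look like this. The `budget` parameter and the toy edge list are assumptions; sampling is done with replacement and duplicates are collapsed.

```python
import random

def node_sampler(edges, num_nodes, budget, seed=0):
    """Draw `budget` nodes with probability proportional to out-degree."""
    rng = random.Random(seed)
    out_deg = [0] * num_nodes
    for u, _ in edges:
        out_deg[u] += 1
    # Draws are with replacement, so the set can hold fewer than `budget` nodes.
    picked = set(rng.choices(range(num_nodes), weights=out_deg, k=budget))
    sub_edges = [(u, v) for u, v in edges if u in picked and v in picked]
    return sorted(picked), sub_edges

# Hypothetical toy graph; node 3 has out-degree 0 and can never be drawn.
edges = [(0, 1), (0, 2), (1, 2), (2, 0), (0, 3)]
nodes, sub_edges = node_sampler(edges, num_nodes=4, budget=3)
```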

GraphSAINT: Edge Sampler (Zeng et al 2020)

Instead of sampling nodes and then including any edges between that subset of nodes, we sample edges directly. Each edge is sampled with probability proportional to the sum of the reciprocal of the source node’s out-degree and the reciprocal of the destination node’s in-degree.

\[p(e) \propto \frac{1}{d_{out}(e_u)} + \frac{1}{d_{in}(e_v)}\]

Once we have a sample of edges, we induce a subgraph from the nodes that appear in at least one sampled edge.
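The edge sampler described above can be sketched as below. The function name, the `budget` parameter, and the toy graph are assumptions, and edges are drawn with replacement for simplicity.

```python
import random

def edge_sampler(edges, budget, seed=0):
    """Sample edges with p(e) proportional to 1/d_out(source) + 1/d_in(destination)."""
    rng = random.Random(seed)
    out_deg, in_deg = {}, {}
    for u, v in edges:
        out_deg[u] = out_deg.get(u, 0) + 1
        in_deg[v] = in_deg.get(v, 0) + 1
    weights = [1.0 / out_deg[u] + 1.0 / in_deg[v] for u, v in edges]
    sampled = rng.choices(edges, weights=weights, k=budget)  # with replacement
    # Induce the subgraph from every node touched by a sampled edge.
    nodes = {n for e in sampled for n in e}
    sub_edges = [(u, v) for u, v in edges if u in nodes and v in nodes]
    return sorted(nodes), sub_edges, weights

edges = [(0, 1), (0, 2), (0, 3), (4, 3)]   # hypothetical toy graph
nodes, sub_edges, weights = edge_sampler(edges, budget=2)
```

In this toy graph the edge (4, 3) gets the largest weight because its source has only one outgoing edge, while (0, 3) gets the smallest because both of its endpoints have higher degree.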

GraphSAINT: Random Walk Sampler (Zeng et al. 2020)

Randomly choose \(n\) root nodes (with replacement) and then start a length-\(k\) random walk from each root. Then, we induce a subgraph from the union of all visited nodes, including the roots.
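A short sketch of the random walk sampler follows. The function name and toy graph are hypothetical, and stopping a walk early at a node without outgoing edges is an assumption about how dead ends are handled.

```python
import random

def random_walk_sampler(nbrs, num_roots, walk_length, seed=0):
    """Induce a subgraph from the nodes visited by walks started at random roots.

    nbrs[u] lists the out-neighbors of u.
    """
    rng = random.Random(seed)
    roots = rng.choices(sorted(nbrs), k=num_roots)   # roots drawn with replacement
    visited = set(roots)
    for root in roots:
        cur = root
        for _ in range(walk_length):
            if not nbrs[cur]:
                break                                # dead end: stop this walk early
            cur = rng.choice(nbrs[cur])
            visited.add(cur)
    sub_edges = [(u, v) for u in sorted(visited) for v in nbrs[u] if v in visited]
    return sorted(visited), sub_edges

nbrs = {0: [1], 1: [2], 2: [0, 3], 3: []}            # hypothetical toy graph
nodes, sub_edges = random_walk_sampler(nbrs, num_roots=2, walk_length=3)
```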

ShaDowKHop (Zeng et al 2021)

Induce a subgraph from the K-hop neighborhood of sampled root nodes. E.g. sample 2 neighbors at layer 1 and 2 neighbors at layer 2, and induce a subgraph. This is really similar to GraphSage. The key difference is that in GraphSage, we use the message flow graphs to pass only the messages needed to compute the outputs of the target nodes. In ShaDowKHop, we keep the extraneous edges that might exist between two nodes that are the same distance from the target node, and we allow messages to be exchanged along them.
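The difference from GraphSage shows up already on a triangle. In the sketch below (hypothetical helper and toy graph), the sampled 1-hop neighborhood of node 0 is induced as a full subgraph, so the edge between its two neighbors is kept.

```python
import random

def shadow_khop(nbrs, root, fanouts, seed=0):
    """Sample a K-hop neighborhood of `root`, then keep every edge inside it."""
    rng = random.Random(seed)
    nodes, frontier = {root}, {root}
    for fanout in fanouts:
        nxt = set()
        for v in sorted(frontier):
            nxt.update(rng.sample(nbrs[v], min(fanout, len(nbrs[v]))))
        frontier = nxt - nodes
        nodes |= nxt
    # Induced subgraph: all edges (u -> v) with both endpoints sampled.
    edges = [(u, v) for v in sorted(nodes) for u in nbrs[v] if u in nodes]
    return sorted(nodes), edges

# Triangle 0-1-2 plus node 3 attached to 2 (undirected, stored in both directions).
nbrs = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
nodes, edges = shadow_khop(nbrs, root=0, fanouts=[2])
```

A one-layer message flow graph for target 0 would contain only the edges pointing into node 0. Here the edges (1, 2) and (2, 1) between the two neighbors are retained as well.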

Which to use?

Hard to predict. But if the graph is too large to train on in full, start with GraphSage and use as many neighbors as possible before you run out of GPU memory. Once you are refining hyperparameters, treat the minibatch sampler as a hyperparameter and run experiments to see which is best for your problem.

Recent solutions in literature

Some recent extensions have been proposed in the literature, namely

  • reduce the variance of sample approximations

  • maximize “embedding utilization” (use each node more than once if possible) using layer sampling

  • learn to select which neighbors to include in the approximation based on the loss (similar to graph attention networks but avoiding having to compute on full graph)

Overview

Batch training

Batch training is usually written mathematically as follows:

\(\mathcal{V}\): Node set

\(|\mathcal{V}|\): Number of nodes in the graph

\(Z_v^L\): Model output for node \(v\)

\(y_v\): Label for node \(v\)

\(l\): loss function

\[\mathcal{L}_{\hbox{Batch}} = \frac{1}{|\mathcal{V}|} \sum_{v\in \mathcal{V}} l(Z_v^L, y_v)\]

Minibatch training w/ full neighbors

\(\mathcal{B}\): Minibatch of randomly selected nodes from \(\mathcal{V}\)

\(|\mathcal{B}|\): Number of sampled nodes in minibatch

\[\mathcal{L}_{MBFN}(\mathcal{B}) = \frac{1}{|\mathcal{B}|} \sum_{v\in \mathcal{B}} l(Z_v^L, y_v)\]

MBFN is unbiased.

\[E_\mathcal{B}[\mathcal{L}_{MBFN}(\mathcal{B})] = \frac{1}{|\mathcal{B}|} \sum_{v \in \mathcal{B}} E_v [ l (Z_v^L, y_v)] = \frac{1}{|\mathcal{B}|} \sum_{v \in \mathcal{B}} \mathcal{L}_{Batch} = \mathcal{L}_{Batch}\]
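The unbiasedness claim can be checked numerically on a toy example by averaging the minibatch loss over every possible minibatch of a fixed size. The per-node loss values below are made up.

```python
from itertools import combinations

losses = [0.2, 1.5, 0.7, 0.1, 2.0]        # hypothetical per-node losses l(Z_v, y_v)
full_batch = sum(losses) / len(losses)    # full-batch loss over all nodes

batch_size = 2
# Every size-2 minibatch is equally likely under uniform sampling.
minibatch_losses = [sum(b) / batch_size for b in combinations(losses, batch_size)]
expected_minibatch = sum(minibatch_losses) / len(minibatch_losses)
```

The two quantities agree up to floating point error. This relies on the exact outputs \(Z_v^L\) being used for every sampled node, which is what "full neighbors" guarantees.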

Minibatch training w/ sampled neighbors

\(\hat{\mathcal{N}}_L(v)\): Recursively sampled neighbors from \(v\)’s \(L\)-hop neighborhood

\(\hat{Z}_v^L\): Approximate model output for \(v\) based on sampled neighbors

\[\mathcal{L}_{MBSN} (\mathcal{B}, \{ \hat{\mathcal{N}}_L(v) : v \in \mathcal{B} \}) = \frac{1}{|\mathcal{B}|} \sum_{v \in \mathcal{B}} l (\hat{Z}_v^L, y_v)\]

Variance reduced GCN (Chen et al 2018)

In general, MBSN loss for GCN training is biased because

\[E_{\hat{\mathcal{N}}_L(v)} [ l (\hat{Z}_v^L, y_v) | v ] \neq l (Z_v^L, y_v)\]

If the number of neighbors sampled at each hop from the target node is large, then the bias is small (intuition follows the continuous mapping theorem). A large sample, however, puts more memory pressure on the GPU. The goal is to reduce the bias by reducing the variance without sampling a larger number of nodes.

The way they do that is as follows:

Message to node \(v\) at layer \(l\) is:

\[(S H^{(l-1)})_v = \sum_{u \in \mathcal{N}(v)} S_{vu} H_u^{(l-1)}\]

With neighbor sampling, the message is approximated with

\[(S H^{(l-1)})_v \approx \frac{|\mathcal{N}(v)|}{|\mathcal{S}_l(v)|} \sum_{u \in \mathcal{S}_l(v)} S_{vu} \hat{H}_u^{(l-1)}\]

where \(\mathcal{S}_l(v)\) is the sample of neighbors of node \(v\) at layer \(l\). Note that the neighbors’ representations \(\hat{H}_u^{(l-1)}\) are themselves approximations, since they were also computed from sampled neighbors.

To reduce the variance, Chen et al. propose a control variate based on the historical embeddings of a node’s neighbors at the previous layer.

They noticed we can break each previous representation into a historical value and the difference between the historical and current value:

\[(S H^{(l-1)})_v = \sum_{u \in \mathcal{N}(v)} S_{vu} \Delta H_u^{(l-1)} + \sum_{u \in \mathcal{N}(v)} S_{vu} \hat{H}_u^{(l-1)}\]

This history value \(\hat{H}_u^{(l-1)}\) is assumed to be known: we keep a cache of the layer \(l-1\) representations for all nodes. After the model updates are pushed, we compute the difference \(\Delta H_u^{(l-1)} = H_u^{(l-1)} - \hat{H}_u^{(l-1)}\).

If we subsample just the deltas, then we sample vectors with a smaller overall norm, which means a smaller variance. There’s also a control variate effect: the deltas get smaller as training converges.

\[(S H^{(l-1)})_v \approx \frac{|\mathcal{N}(v)|}{|\mathcal{S}_l(v)|} \sum_{u \in \mathcal{S}_l(v)} S_{vu} \Delta \hat{H}_u^{(l-1)} + \sum_{u \in \mathcal{N}(v)} S_{vu} \hat{H}_u^{(l-1)}\]

This uses more CPU memory (from the cache) instead of GPU memory.
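The variance reduction can be seen in a small simulation for a single node and a single layer. This is a sketch under simplifying assumptions: uniform weights \(S_{vu}\), exact current representations for the sampled neighbors, and a synthetic history that differs from the current values by small noise.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nbrs, dim, k = 50, 8, 5

S_v = np.full(num_nbrs, 1.0 / num_nbrs)          # row of S for node v (toy: uniform)
H = rng.normal(size=(num_nbrs, dim))             # current neighbor representations
H_hist = H + 0.1 * rng.normal(size=H.shape)      # cached, slightly stale history
exact = S_v @ H                                  # exact message (S H)_v

def plain_estimate():
    """Neighbor sampling: rescale the sum over k sampled neighbors."""
    idx = rng.choice(num_nbrs, size=k, replace=False)
    return num_nbrs / k * (S_v[idx] @ H[idx])

def cv_estimate():
    """Control variate: sample only the deltas, use history for all neighbors."""
    idx = rng.choice(num_nbrs, size=k, replace=False)
    delta = H[idx] - H_hist[idx]
    return num_nbrs / k * (S_v[idx] @ delta) + S_v @ H_hist

plain = np.stack([plain_estimate() for _ in range(2000)])
cv = np.stack([cv_estimate() for _ in range(2000)])
plain_mse = np.mean((plain - exact) ** 2)
cv_mse = np.mean((cv - exact) ** 2)
```

Both estimators are centered on the exact message, but the control variate version has a much smaller error because only the small deltas are subject to sampling noise.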

Adaptive Sampling GCN (Huang et al 2018)

Even when we use GraphSage to subsample the number of neighbors at each layer, the number of intermediate nodes is still exponential in the number of layers (and we need large samples to reduce bias). Key goal: make the number of intermediate nodes linear in the number of layers by selecting a fixed set of nodes in the previous layer for approximating messages for all nodes in the current layer. So, instead of having a multiplicative effect at every layer that we add (causing exponential growth), we add the same number of nodes at every layer, independent of the number of nodes at the previous layer, giving us a linear dependence. So, how do we choose a good fixed set of nodes at each layer to help approximate the representation of the layer above?

\[\begin{split}\begin{align*} (S H^{(l-1)})_v &= \sum_{u \in \mathcal{N}(v)} S_{vu} H_u^{(l-1)}\\ &= |\mathcal{N}(v)| \sum_{u \in \mathcal{N}(v)} \frac{1}{|\mathcal{N}(v)|} S_{vu} H_u^{(l-1)}\\ & \hbox{$p$: Uniform distribution over neighbors}\\ &= |\mathcal{N}(v)| \sum_{u \in \mathcal{V}} p(u | v) S_{vu} H_u^{(l-1)}\\ &= |\mathcal{N}(v)| E_p [S_{vu} H_u^{(l-1)}]\\ & \hbox{$q_l$: Unspecified proposal distribution over all neighbors of node at layer $l$}\\ &= |\mathcal{N}(v)| E_{q_l} \left[\frac{p(u|v)}{q_l(u)} S_{vu} H_u^{(l-1)}\right]\\ \end{align*}\end{split}\]
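The last step of the derivation is an importance sampling identity: drawing \(u\) from a proposal \(q_l\) and reweighting by the ratio of \(p(u|v)\) to \(q_l(u)\) gives the same expectation as drawing \(u\) from \(p\). The sketch below checks this numerically with an arbitrary proposal; the array `values` stands in for \(S_{vu} H_u^{(l-1)}\), and all numbers are synthetic.

```python
import numpy as np

rng = np.random.default_rng(0)
n, dim = 20, 4
values = rng.normal(size=(n, dim))      # stands in for S_vu * H_u over candidates u
p = np.full(n, 1.0 / n)                 # uniform p(u | v) over the neighbors
q = rng.random(n) + 0.1
q /= q.sum()                            # arbitrary proposal with full support

exact = p @ values                      # E_p[S_vu H_u]
idx = rng.choice(n, size=50_000, p=q)   # draw u ~ q
weights = p[idx] / q[idx]               # importance weights p(u|v) / q(u)
estimate = (weights[:, None] * values[idx]).mean(axis=0)
```

Any proposal with full support gives an unbiased estimate. What changes with the choice of \(q_l\) is the variance, which is what the rest of this section is about.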

The proposal distribution that minimizes the variance for node \(v\):

\[q_l^{\star} (u) \propto p(u|v) \times ||H_u^{(l-1)}||_2^2\]

To minimize variance, we prefer to sample nodes \(u\) that have a large L2 norm in representation on the previous layer.

There are two problems with this:

    1. the optimal proposal distribution is specific to \(v\) (we want to sample from a common \(q\), not specific neighbors to node \(v\) as there could be duplicates)

    2. we need to compute the hidden representation at the previous layer for all neighbors of all nodes at the current layer (the thing we were trying to avoid!)

To sidestep these two issues, they introduce this proposal distribution:

\[q_l(u) \propto \sum_{v \in \hbox{layer } l} p(u|v) W x_u\]

where \(W\) is a learnable \(1 \times d\) matrix. So, we are trying to predict what the norm of the hidden representation at layer \(l-1\) is, given the known feature values for node \(u\).

They augment the learning objective by minimizing the minibatch with sampled neighbors objective (from before) plus a variance penalty:

\[\frac{1}{|\mathcal{B}|} \sum_{v \in \mathcal{B}} l (\hat{Z}_v^L, y_v) + \lambda Var(\hat{Z}_v^L)\]

The variance is only computed at the final layer \(L\), not at every layer. This is interesting because 1) we share/compact the nodes that we use to help approximate the last layer, and 2) we have a learnable distribution for choosing neighbors at every layer as we go down the stack.

Performance Adaptive Sampling (Yoon et al 2021)

Like AS-GCN, this method introduces a learnable distribution over neighbors. The key difference is PASS optimizes the distribution to directly improve performance instead of minimizing variance.

Each pre-activation hidden representation is

\[Z_v^l = \frac{1}{k} \sum_{u \in \mathcal{S}_l(v)} W_l^T H_u^{(l-1)}\]

where \(\mathcal{S}_l(v)\) is a set of neighbors sampled using a learned distribution. Specifically, the learned policy is a mixture of random sampling and “affinity” sampling. So this is a tradeoff of exploration and exploitation.

The proposal distribution at layer \(l\) over \(u\) given \(v\) is:

\[\begin{split}\begin{align*} q^l(u|v) &= a_s \times [q_a^l (u | v), q_r^l(u|v) ]^\intercal\\ q_a^l(u|v) &= (W_s H_u^0) \times (W_s H_v^0)\\ q_r^l (u|v) &= \frac{1}{|\mathcal{N}(v)|}\\ \end{align*}\end{split}\]

where \(a_s\) is an attention matrix. The affinity weight is the dot product of the two nodes’ original hidden representations after each is passed through the linear transformation \(W_s\). The sampling operation is non-differentiable, so they use REINFORCE to update the parameters of the sampling distribution, which does not require backpropagating gradients through the sampling step.
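To show the mechanics of a REINFORCE update, here is a generic score-function sketch for a categorical distribution over four candidate neighbors. It is not the PASS training procedure: the `reward` vector, the learning rate, and the softmax parameterization are hypothetical choices made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.zeros(4)                      # hypothetical sampler parameters
reward = np.array([0.0, 0.0, 1.0, 0.0])   # hypothetical reward: neighbor 2 helps the task

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

for _ in range(500):
    probs = softmax(logits)
    u = rng.choice(4, p=probs)            # non-differentiable sampling step
    grad_log_prob = -probs                # d log q(u) / d logits = onehot(u) - probs
    grad_log_prob[u] += 1.0
    logits += 0.5 * reward[u] * grad_log_prob   # REINFORCE ascent step

final_probs = softmax(logits)
```

The update only needs the reward of the sampled neighbor and the gradient of the log-probability of that sample, so nothing has to be differentiated through the draw itself. In this toy run the probability mass moves toward neighbor 2.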