GNN Model
We will walk through the implementation of the GNN model in this section!
Helper class
We first define a class for the Multi-Layer Perceptron (MLP). This class builds an MLP given its width and depth. Because MLPs are used in several places in the GNN, this helper class keeps the code cleaner.
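Below is a minimal sketch of such a helper, assuming the class is named MLP and is configured by an input size, a hidden width, an output size, and the number of hidden layers; the exact name and signature used in the tutorial's code may differ.

```python
import torch.nn as nn

class MLP(nn.Sequential):
    """An MLP defined by its width (hidden size) and depth (number of hidden layers)."""
    def __init__(self, in_size, hidden_size, out_size, num_hidden_layers=2):
        layers = []
        sizes = [in_size] + [hidden_size] * num_hidden_layers + [out_size]
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1]))
            if i < len(sizes) - 2:  # no activation after the output layer
                layers.append(nn.ReLU())
        super().__init__(*layers)
```

For example, MLP(16, 128, 128) maps 16-dimensional inputs to 128-dimensional outputs through two hidden layers of width 128.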
GNN layers
In the following code block, we implement one type of GNN layer named InteractionNetwork (IN), which was proposed in the paper Interaction Networks for Learning about Objects, Relations and Physics.
For a graph \(G\), let the feature of node \(i\) be \(v_i\), and the feature of edge \((i, j)\) be \(e_{i, j}\). There are three stages for IN to generate new features of nodes and edges.
- Message generation. If there is an edge pointing from node \(i\) to node \(j\), node \(i\) sends a message to node \(j\). The message carries the information of the edge and its two nodes, so it is generated by the following equation \(\mathrm{Msg}_{i,j} = \mathrm{MLP}(v_i, v_j, e_{i,j})\).
- Message aggregation. In this stage, each node of the graph aggregates all the messages that it received into a fixed-size representation. In the IN, aggregation means summing all the messages up, i.e., \(\mathrm{Agg}_i=\sum_{(j,i)\in G}\mathrm{Msg}_{j,i}\).
- Update. Finally, we update the features of nodes and edges with the results of the previous stages. For each edge, its new feature is simply the sum of its old feature and the corresponding message, i.e., \(e'_{i,j}=e_{i,j}+\mathrm{Msg}_{i,j}\). For each node, the new feature is determined by its old feature and the aggregated message, i.e., \(v'_i=v_i+\mathrm{MLP}(v_i, \mathrm{Agg}_i)\).
In PyG, GNN layers are implemented as subclasses of MessagePassing. We need to override three critical functions to implement our InteractionNetwork GNN layer; each function corresponds to one stage of the GNN layer.
- message() -> message generation. This function controls how a message is generated on each edge of the graph. It takes three arguments: (1) x_i, features of the target (receiving) nodes; (2) x_j, features of the source nodes; and (3) edge_feature, features of the edges themselves. In the IN, we simply concatenate all these features and generate the messages with an MLP.
- aggregate() -> message aggregation. This function aggregates messages for nodes. It depends on two arguments: (1) inputs, the messages; and (2) index, the graph structure. We hand the task of message aggregation over to the function torch_scatter.scatter and specify, via the argument reduce, that we want to sum the messages up. Because we also want to retain the messages themselves to update the edge features, we return both the messages and the aggregated messages.
- forward() -> update. This function puts everything together: x is the node features, edge_index is the graph structure, and edge_feature is the edge features. The function MessagePassing.propagate invokes message and aggregate for us. Then, we update the node features and edge features and return them, as sketched in the code after this list.
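Here is a hedged sketch of one such layer, assuming the MLP helper above and a single hidden size shared by node and edge features; the class name, the MLP attribute names, and the hyperparameters are illustrative assumptions rather than the tutorial's exact code.

```python
import torch
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter

class InteractionNetwork(MessagePassing):
    """One IN layer: message generation, sum aggregation, and residual updates."""
    def __init__(self, hidden_size):
        super().__init__()
        # MLPs built with the helper class defined earlier (assumed signature).
        self.lin_edge = MLP(hidden_size * 3, hidden_size, hidden_size)
        self.lin_node = MLP(hidden_size * 2, hidden_size, hidden_size)

    def message(self, x_i, x_j, edge_feature):
        # Message generation: concatenate the features of the two endpoint
        # nodes and the edge, then transform them with an MLP.
        return self.lin_edge(torch.cat([x_i, x_j, edge_feature], dim=-1))

    def aggregate(self, inputs, index, dim_size=None):
        # Message aggregation: sum the messages received by each node.
        aggregated = scatter(inputs, index, dim=self.node_dim,
                             dim_size=dim_size, reduce='sum')
        # Also return the raw messages so forward() can update edge features.
        return inputs, aggregated

    def forward(self, x, edge_index, edge_feature):
        # propagate() invokes message() and aggregate() for us.
        msg, aggr = self.propagate(edge_index, x=x, edge_feature=edge_feature)
        # Residual updates of node and edge features.
        node_out = x + self.lin_node(torch.cat([x, aggr], dim=-1))
        edge_out = edge_feature + msg
        return node_out, edge_out
```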
The GNN
Now it's time to stack GNN layers into a GNN. Besides the GNN layers, the GNN contains pre-processing and post-processing blocks. Before the GNN layers, the input features are transformed by MLPs, which improves the expressiveness of the GNN without adding more GNN layers. After the GNN layers, the final outputs (the accelerations of the particles, in our case) are extracted from the features produced by the GNN layers to match the requirements of the task.
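A possible way to assemble the full model from the MLP helper and the InteractionNetwork layer above is sketched below; the class name LearnedSimulator, the hidden size, the number of layers, and the 3-dimensional acceleration output are assumptions for illustration.

```python
from torch import nn

class LearnedSimulator(nn.Module):
    """Encoder MLPs -> a stack of InteractionNetwork layers -> a decoder MLP."""
    def __init__(self, node_in, edge_in, hidden_size=128, num_layers=10, out_size=3):
        super().__init__()
        # Pre-processing: lift raw node/edge features into the hidden space.
        self.node_encoder = MLP(node_in, hidden_size, hidden_size)
        self.edge_encoder = MLP(edge_in, hidden_size, hidden_size)
        # Stacked message-passing layers.
        self.layers = nn.ModuleList(
            [InteractionNetwork(hidden_size) for _ in range(num_layers)]
        )
        # Post-processing: read out the predicted acceleration of each particle.
        self.decoder = MLP(hidden_size, hidden_size, out_size)

    def forward(self, x, edge_index, edge_feature):
        x = self.node_encoder(x)
        e = self.edge_encoder(edge_feature)
        for layer in self.layers:
            x, e = layer(x, edge_index, e)
        return self.decoder(x)
```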