We will walk through the implementation of the GNN model in this section!

Helper class

We first define a helper class for the Multi-Layer Perceptron (MLP). This class builds an MLP with a given width and depth. Because MLPs are used in several places in the GNN, this helper class keeps the code cleaner.
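
For reference, a minimal sketch of such a helper is shown below. The signature used here (input size, hidden width, depth, and output size, with ReLU activations) is an illustrative assumption, not necessarily the exact interface of the course code.

    import torch.nn as nn

    class MLP(nn.Sequential):
        # A plain MLP: `depth` hidden layers of size `width` with ReLU activations.
        # (Sketch only; the parameter names here are illustrative.)
        def __init__(self, input_size, width, depth, output_size):
            layers = []
            in_size = input_size
            for _ in range(depth):
                layers += [nn.Linear(in_size, width), nn.ReLU()]
                in_size = width
            layers.append(nn.Linear(in_size, output_size))
            super().__init__(*layers)

For example, MLP(16, 128, 2, 64) builds a network with two hidden layers of width 128 that maps 16-dimensional inputs to 64-dimensional outputs.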

GNN layers

In the following code block, we implement one type of GNN layer, the InteractionNetwork (IN), which was proposed in the paper Interaction Networks for Learning about Objects, Relations and Physics.

For a graph \(G\), let the feature of node \(i\) be \(v_i\), and the feature of edge \((i, j)\) be \(e_{i, j}\). There are three stages for IN to generate new features of nodes and edges.

  1. Message generation. If there is an edge pointing from node \(i\) to node \(j\), node \(i\) sends a message to node \(j\). The message carries the information of the edge and its two nodes, so it is generated by the following equation \(\mathrm{Msg}_{i,j} = \mathrm{MLP}(v_i, v_j, e_{i,j})\).
  2. Message aggregation. In this stage, each node of the graph aggregates all the messages it receives into a fixed-size representation. In the IN, aggregation means summing all the messages up, i.e., \(\mathrm{Agg}_i=\sum_{(j,i)\in G}\mathrm{Msg}_{j,i}\).
  3. Update. Finally, we update the features of the nodes and edges with the results of the previous stages. For each edge, the new feature is simply the sum of its old feature and the corresponding message, i.e., \(e'_{i,j}=e_{i,j}+\mathrm{Msg}_{i,j}\). For each node, the new feature is determined by its old feature and the aggregated messages, i.e., \(v'_i=v_i+\mathrm{MLP}(v_i, \mathrm{Agg}_i)\).

In PyG, GNN layers are implemented as subclasses of MessagePassing. We need to override three critical functions to implement our InteractionNetwork GNN layer; each function corresponds to one stage of the GNN layer described above. A sketch of the complete layer follows this list.

  1. message() -> message generation

    This function controls how a message is generated on each edge of the graph. It takes three arguments: (1) x_i, the features of the target (receiving) nodes; (2) x_j, the features of the source (sending) nodes; and (3) edge_feature, the features of the edges themselves. In the IN, we simply concatenate all of these features and generate the messages with an MLP.

  2. aggregate() -> message aggregation

    This function aggregates the messages for the nodes. It takes two arguments: (1) inputs, the messages; and (2) index, which records the target node of each message and thus encodes the graph structure. We hand the task of message aggregation over to the function torch_scatter.scatter and specify through the argument reduce that we want to sum the messages. Because we also want to retain the messages themselves to update the edge features, we return both the messages and the aggregated messages.

  3. forward() -> update

    This function puts everything together: x holds the node features, edge_index encodes the graph structure, and edge_feature holds the edge features. The function MessagePassing.propagate invokes message and aggregate for us. We then update the node and edge features and return them.
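
Putting the three functions together, a minimal sketch of the layer is shown below. It assumes the MLP helper sketched earlier and a single hidden size shared by node and edge features; the constructor arguments (hidden_size, layers) and the exact MLP widths are illustrative assumptions rather than the actual course code.

    import torch
    from torch_geometric.nn import MessagePassing
    from torch_scatter import scatter

    class InteractionNetwork(MessagePassing):
        # One IN layer: produces updated node and edge features for the graph.
        def __init__(self, hidden_size, layers):
            super().__init__()
            # MLP that maps (x_i, x_j, e_ij) to a message, which also serves as the edge update.
            self.lin_edge = MLP(3 * hidden_size, hidden_size, layers, hidden_size)
            # MLP that maps (x_i, aggregated messages) to a node update.
            self.lin_node = MLP(2 * hidden_size, hidden_size, layers, hidden_size)

        def message(self, x_i, x_j, edge_feature):
            # Stage 1: one message per edge from the concatenated node and edge features.
            return self.lin_edge(torch.cat([x_i, x_j, edge_feature], dim=-1))

        def aggregate(self, inputs, index, ptr=None, dim_size=None):
            # Stage 2: sum the incoming messages of each node.
            out = scatter(inputs, index, dim=0, dim_size=dim_size, reduce='sum')
            # Also return the raw messages so forward() can update the edge features.
            return inputs, out

        def forward(self, x, edge_index, edge_feature):
            # propagate() invokes message() and aggregate() for us.
            edge_out, aggr = self.propagate(edge_index, x=x, edge_feature=edge_feature)
            # Stage 3: residual updates of the edge and node features.
            node_out = x + self.lin_node(torch.cat([x, aggr], dim=-1))
            edge_out = edge_feature + edge_out
            return node_out, edge_out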

The GNN

Now it's time to stack GNN layers into a GNN. Besides the GNN layers, the GNN contains pre-processing and post-processing blocks. Before the GNN layers, the input features are transformed by MLPs so that the expressiveness of the GNN is improved without adding more GNN layers. After the GNN layers, the final outputs (the accelerations of particles, in our case) are extracted from the features produced by the GNN layers to meet the requirements of the task.
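
A minimal sketch of such a model, built from the MLP and InteractionNetwork classes sketched above, is shown below. The class name and constructor arguments (node_in, edge_in, dim, n_mp_layers, mlp_layers, dim_out) are illustrative assumptions; in particular, dim_out would be the dimension of the predicted acceleration.

    import torch.nn as nn

    class GNN(nn.Module):
        # Pre-processing MLPs -> a stack of GNN layers -> a post-processing MLP.
        def __init__(self, node_in, edge_in, dim, n_mp_layers, mlp_layers, dim_out):
            super().__init__()
            # Pre-processing: project raw node and edge features to the hidden dimension.
            self.node_encoder = MLP(node_in, dim, mlp_layers, dim)
            self.edge_encoder = MLP(edge_in, dim, mlp_layers, dim)
            # The stack of GNN layers.
            self.layers = nn.ModuleList(
                [InteractionNetwork(dim, mlp_layers) for _ in range(n_mp_layers)]
            )
            # Post-processing: extract the final outputs (accelerations) from node features.
            self.decoder = MLP(dim, dim, mlp_layers, dim_out)

        def forward(self, x, edge_index, edge_feature):
            node_feature = self.node_encoder(x)
            edge_feature = self.edge_encoder(edge_feature)
            for layer in self.layers:
                node_feature, edge_feature = layer(node_feature, edge_index, edge_feature)
            return self.decoder(node_feature)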

 