Skip to content

Commit 68634cc

Browse files
committed
Improvements to GRIT arguments, added new position encodings, fixed pickling
1 parent 48d5e85 commit 68634cc

File tree

7 files changed

+315
-31
lines changed

7 files changed

+315
-31
lines changed

src/graphnet/models/components/embedding.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,38 @@ def forward(self, data: Data) -> Data:
367367

368368
data.edge_index, data.edge_attr = out_idx, out_val
369369
return data
370+
371+
372+
class RWSELinearNodeEncoder(LightningModule):
    """Random walk structural node encoding.

    Projects precomputed random-walk structural encodings (``data.rwse``)
    through a linear layer and concatenates the result onto the node
    features ``data.x`` along the feature dimension.
    """

    def __init__(
        self,
        emb_dim: int,
        out_dim: int,
        use_bias: bool = False,
    ):
        """Construct `RWSELinearNodeEncoder`.

        Args:
            emb_dim: Embedding dimension.
            out_dim: Output dimension.
            use_bias: Apply bias to linear layer.
        """
        super().__init__()

        self.emb_dim = emb_dim
        self.out_dim = out_dim

        self.encoder = nn.Linear(emb_dim, out_dim, bias=use_bias)

    def forward(self, data: Data) -> Data:
        """Forward pass.

        Encodes ``data.rwse`` and appends it to ``data.x``; returns the
        same `Data` object with ``x`` widened by ``out_dim`` features.
        """
        rwse = data.rwse
        x = data.x

        rwse = self.encoder(rwse)

        # Concatenate along the feature axis so node count is unchanged.
        data.x = torch.cat((x, rwse), dim=1)

        return data

src/graphnet/models/components/layers.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
import torch.nn as nn
77
from torch.functional import Tensor
88
from torch_geometric.nn import EdgeConv
9-
from torch_geometric.nn.pool import knn_graph, global_add_pool
9+
from torch_geometric.nn.pool import (
10+
knn_graph,
11+
global_mean_pool,
12+
global_add_pool,
13+
)
1014
from torch_geometric.typing import Adj, PairTensor
1115
from torch_geometric.nn.conv import MessagePassing
1216
from torch_geometric.nn.inits import reset
@@ -893,14 +897,13 @@ def forward(self, data: Data) -> Data:
893897
x = self.fc1_x(x)
894898
if e_attn_out is not None:
895899
e = e_attn_out.flatten(1)
896-
# TODO: Make this a nn.Dropout in initialization -PW
897900
e = self.dropout2(e)
898901
e = self.fc1_e(e)
899902

900903
if self.residual:
901904
if self.rezero:
902905
x = x * self.alpha1_x
903-
x = x_attn_residual + x # residual connection
906+
x = x_attn_residual + x
904907

905908
if e is not None:
906909
if self.rezero:
@@ -946,34 +949,46 @@ class SANGraphHead(LightningModule):
946949
def __init__(
    self,
    dim_in: int,
    dim_out: int = 1,
    L: int = 2,
    activation: nn.Module = nn.ReLU,
    pooling: str = "mean",
):
    """Construct `SANGraphHead`.

    Args:
        dim_in: Input dimension.
        dim_out: Output dimension.
        L: Number of hidden layers.
        activation: Activation function.
        pooling: Pooling method, either "mean" or "add".

    Raises:
        RuntimeError: If `pooling` is not "mean" or "add".
        ValueError: If `L` halvings of `dim_in` fall below `dim_out`.
    """
    super().__init__()
    if pooling == "mean":
        self.pooling_fun = global_mean_pool
    elif pooling == "add":
        self.pooling_fun = global_add_pool
    else:
        raise RuntimeError("Currently supports only 'add' or 'mean'.")

    # Each hidden layer halves the width: dim_in -> dim_in/2 -> ...
    fc_layers = [
        nn.Linear(dim_in // 2**n, dim_in // 2 ** (n + 1), bias=True)
        for n in range(L)
    ]
    # Explicit raise instead of `assert`: asserts are stripped under -O,
    # which would silently allow an invalid configuration.
    if dim_in // 2**L < dim_out:
        raise ValueError("Too much dim reduction!")
    # Final projection from the last hidden width down to dim_out.
    fc_layers.append(nn.Linear(dim_in // 2**L, dim_out, bias=True))
    self.fc_layers = nn.ModuleList(fc_layers)
    self.L = L
    self.activation = activation()
    self.dim_out = dim_out
970984

971985
def forward(self, data: Data) -> Tensor:
    """Forward Pass.

    Pools node features into a graph embedding, applies the `L` hidden
    linear+activation layers, then the final linear projection to
    `dim_out`.
    """
    graph_emb = self.pooling_fun(data.x, data.batch)
    for i in range(self.L):
        graph_emb = self.fc_layers[i](graph_emb)
        graph_emb = self.activation(graph_emb)
    # Final linear layer projects to dim_out (no activation applied).
    graph_emb = self.fc_layers[self.L](graph_emb)
    return graph_emb

src/graphnet/models/gnn/grit.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
"""
1010

1111
import torch.nn as nn
12-
1312
from torch import Tensor
1413
from torch_geometric.data import Data
1514

@@ -24,6 +23,7 @@
2423
RRWPLinearNodeEncoder,
2524
LinearNodeEncoder,
2625
LinearEdgeEncoder,
26+
RWSELinearNodeEncoder,
2727
)
2828

2929

@@ -38,6 +38,7 @@ def __init__(
3838
self,
3939
nb_inputs: int,
4040
hidden_dim: int,
41+
nb_outputs: int = 1,
4142
ksteps: int = 21,
4243
n_layers: int = 10,
4344
n_heads: int = 8,
@@ -56,13 +57,15 @@ def __init__(
5657
enable_edge_transform: bool = True,
5758
pred_head_layers: int = 2,
5859
pred_head_activation: nn.Module = nn.ReLU,
60+
pred_head_pooling: str = "mean",
61+
position_encoding: str = "NoPE",
5962
):
6063
"""Construct `GRIT` model.
6164
6265
Args:
6366
nb_inputs: Number of inputs.
6467
hidden_dim: Size of hidden dimension.
65-
dim_out: Size of output dimension.
68+
nb_outputs: Size of output dimension.
6669
ksteps: Number of random walk steps.
6770
n_layers: Number of GRIT layers.
6871
n_heads: Number of heads in MHA.
@@ -82,20 +85,36 @@ def __init__(
8285
enable_edge_transform: Apply transformation to edges.
8386
pred_head_layers: Number of layers in the prediction head.
8487
pred_head_activation: Prediction head activation function.
88+
pred_head_pooling: Pooling function to use for the prediction head,
89+
either "mean" (default) or "add".
90+
position_encoding: Method of position encoding.
8591
"""
86-
super().__init__(nb_inputs, hidden_dim // 2**pred_head_layers)
87-
88-
self.node_encoder = LinearNodeEncoder(nb_inputs, hidden_dim)
89-
self.edge_encoder = LinearEdgeEncoder(hidden_dim)
90-
91-
self.rrwp_abs_encoder = RRWPLinearNodeEncoder(ksteps, hidden_dim)
92-
self.rrwp_rel_encoder = RRWPLinearEdgeEncoder(
93-
ksteps,
94-
hidden_dim,
95-
pad_to_full_graph=pad_to_full_graph,
96-
add_node_attr_as_self_loop=add_node_attr_as_self_loop,
97-
fill_value=fill_value,
98-
)
92+
super().__init__(nb_inputs, nb_outputs)
93+
self.position_encoding = position_encoding.lower()
94+
if self.position_encoding == "nope":
95+
encoders = [
96+
LinearNodeEncoder(nb_inputs, hidden_dim),
97+
LinearEdgeEncoder(hidden_dim),
98+
]
99+
elif self.position_encoding == "rrwp":
100+
encoders = [
101+
LinearNodeEncoder(nb_inputs, hidden_dim),
102+
LinearEdgeEncoder(hidden_dim),
103+
RRWPLinearNodeEncoder(ksteps, hidden_dim),
104+
RRWPLinearEdgeEncoder(
105+
ksteps,
106+
hidden_dim,
107+
pad_to_full_graph=pad_to_full_graph,
108+
add_node_attr_as_self_loop=add_node_attr_as_self_loop,
109+
fill_value=fill_value,
110+
),
111+
]
112+
elif self.position_encoding == "rwse":
113+
encoders = [
114+
LinearNodeEncoder(nb_inputs, hidden_dim - (ksteps - 1)),
115+
RWSELinearNodeEncoder(ksteps - 1, hidden_dim),
116+
]
117+
self.encoders = nn.ModuleList(encoders)
99118

100119
layers = []
101120
for _ in range(n_layers):
@@ -120,19 +139,16 @@ def __init__(
120139
self.layers = nn.ModuleList(layers)
121140
self.head = SANGraphHead(
122141
dim_in=hidden_dim,
142+
dim_out=nb_outputs,
123143
L=pred_head_layers,
124144
activation=pred_head_activation,
145+
pooling=pred_head_pooling,
125146
)
126147

127148
def forward(self, x: Data) -> Tensor:
128149
"""Forward pass."""
129-
# Apply linear layers to node/edge features
130-
x = self.node_encoder(x)
131-
x = self.edge_encoder(x)
132-
133-
# Encode with RRWP
134-
x = self.rrwp_abs_encoder(x)
135-
x = self.rrwp_rel_encoder(x)
150+
for encoder in self.encoders:
151+
x = encoder(x)
136152

137153
# Apply GRIT layers
138154
for layer in self.layers:

src/graphnet/models/graphs/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,10 @@
66
"""
77

88
from .graph_definition import GraphDefinition
9-
from .graphs import KNNGraph, EdgelessGraph, KNNGraphRRWP
9+
from .graphs import (
10+
KNNGraph,
11+
EdgelessGraph,
12+
KNNGraphRRWP,
13+
KNNGraphRWSE,
14+
KNNGraphNoPE,
15+
)

src/graphnet/models/graphs/edges/edges.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import torch
77
from torch_geometric.nn import knn_graph, radius_graph
88
from torch_geometric.data import Data
9+
from torch_geometric.utils import to_undirected
10+
from torch_geometric.utils.num_nodes import maybe_num_nodes
911

1012
from graphnet.models.utils import calculate_distance_matrix
1113
from graphnet.models import Model
@@ -111,6 +113,16 @@ def __init__(
111113
def _construct_edges(self, graph: Data) -> Data:
112114
"""Define K-NN edges."""
113115
graph = super()._construct_edges(graph)
116+
117+
if graph.edge_index.numel() == 0: # Check if edge_index is empty
118+
num_nodes = graph.num_nodes
119+
self_loops = torch.arange(num_nodes).repeat(2, 1)
120+
graph.edge_index = self_loops
121+
122+
graph.num_nodes = maybe_num_nodes(graph.edge_index)
123+
graph.edge_index = to_undirected(
124+
graph.edge_index, num_nodes=graph.num_nodes
125+
)
114126
position_data = graph.x[:, self._columns]
115127

116128
src, tgt = graph.edge_index

0 commit comments

Comments
 (0)