99"""
1010
1111import torch .nn as nn
12-
1312from torch import Tensor
1413from torch_geometric .data import Data
1514
2423 RRWPLinearNodeEncoder ,
2524 LinearNodeEncoder ,
2625 LinearEdgeEncoder ,
26+ RWSELinearNodeEncoder ,
2727)
2828
2929
@@ -38,6 +38,7 @@ def __init__(
3838 self ,
3939 nb_inputs : int ,
4040 hidden_dim : int ,
41+ nb_outputs : int = 1 ,
4142 ksteps : int = 21 ,
4243 n_layers : int = 10 ,
4344 n_heads : int = 8 ,
@@ -56,13 +57,15 @@ def __init__(
5657 enable_edge_transform : bool = True ,
5758 pred_head_layers : int = 2 ,
5859 pred_head_activation : nn .Module = nn .ReLU ,
60+ pred_head_pooling : str = "mean" ,
61+ position_encoding : str = "NoPE" ,
5962 ):
6063 """Construct `GRIT` model.
6164
6265 Args:
6366 nb_inputs: Number of inputs.
6467 hidden_dim: Size of hidden dimension.
65- dim_out : Size of output dimension.
68+ nb_outputs : Size of output dimension.
6669 ksteps: Number of random walk steps.
6770 n_layers: Number of GRIT layers.
6871 n_heads: Number of heads in MHA.
@@ -82,20 +85,36 @@ def __init__(
8285 enable_edge_transform: Apply transformation to edges.
8386 pred_head_layers: Number of layers in the prediction head.
8487 pred_head_activation: Prediction head activation function.
88+ pred_head_pooling: Pooling function to use for the prediction head,
89+ either "mean" (default) or "add".
90+ position_encoding: Method of position encoding.
8591 """
86- super ().__init__ (nb_inputs , hidden_dim // 2 ** pred_head_layers )
87-
88- self .node_encoder = LinearNodeEncoder (nb_inputs , hidden_dim )
89- self .edge_encoder = LinearEdgeEncoder (hidden_dim )
90-
91- self .rrwp_abs_encoder = RRWPLinearNodeEncoder (ksteps , hidden_dim )
92- self .rrwp_rel_encoder = RRWPLinearEdgeEncoder (
93- ksteps ,
94- hidden_dim ,
95- pad_to_full_graph = pad_to_full_graph ,
96- add_node_attr_as_self_loop = add_node_attr_as_self_loop ,
97- fill_value = fill_value ,
98- )
92+ super ().__init__ (nb_inputs , nb_outputs )
93+ self .position_encoding = position_encoding .lower ()
94+ if self .position_encoding == "nope" :
95+ encoders = [
96+ LinearNodeEncoder (nb_inputs , hidden_dim ),
97+ LinearEdgeEncoder (hidden_dim ),
98+ ]
99+ elif self .position_encoding == "rrwp" :
100+ encoders = [
101+ LinearNodeEncoder (nb_inputs , hidden_dim ),
102+ LinearEdgeEncoder (hidden_dim ),
103+ RRWPLinearNodeEncoder (ksteps , hidden_dim ),
104+ RRWPLinearEdgeEncoder (
105+ ksteps ,
106+ hidden_dim ,
107+ pad_to_full_graph = pad_to_full_graph ,
108+ add_node_attr_as_self_loop = add_node_attr_as_self_loop ,
109+ fill_value = fill_value ,
110+ ),
111+ ]
112+ elif self .position_encoding == "rwse" :
113+ encoders = [
114+ LinearNodeEncoder (nb_inputs , hidden_dim - (ksteps - 1 )),
115+ RWSELinearNodeEncoder (ksteps - 1 , hidden_dim ),
116+ ]
117+ self .encoders = nn .ModuleList (encoders )
99118
100119 layers = []
101120 for _ in range (n_layers ):
@@ -120,19 +139,16 @@ def __init__(
120139 self .layers = nn .ModuleList (layers )
121140 self .head = SANGraphHead (
122141 dim_in = hidden_dim ,
142+ dim_out = nb_outputs ,
123143 L = pred_head_layers ,
124144 activation = pred_head_activation ,
145+ pooling = pred_head_pooling ,
125146 )
126147
127148 def forward (self , x : Data ) -> Tensor :
128149 """Forward pass."""
129- # Apply linear layers to node/edge features
130- x = self .node_encoder (x )
131- x = self .edge_encoder (x )
132-
133- # Encode with RRWP
134- x = self .rrwp_abs_encoder (x )
135- x = self .rrwp_rel_encoder (x )
150+ for encoder in self .encoders :
151+ x = encoder (x )
136152
137153 # Apply GRIT layers
138154 for layer in self .layers :
0 commit comments