Closed
Description
Hello, I found a bug in the layers.conv module.
In almost all models derived from the `MessagePassingBase` class, the fused `message_and_aggregate` function is inconsistent with the separate `message` and `aggregate` functions.
Take the `RelationalGraphConv` module as an example. In `message`, if edge features exist, they are first transformed by `self.edge_linear` on every edge:
```python
def message(self, graph, input):
    node_in = graph.edge_list[:, 0]
    message = input[node_in]
    if self.edge_linear:
        # edge features are transformed per edge, before any aggregation
        message += self.edge_linear(graph.edge_feature.float())
    return message
```

However, in `message_and_aggregate`, whether `self.edge_linear` is applied to the edge features before or after the scatter-sum depends on whether the layer shrinks or expands the feature dimension. When `in_features <= out_features`, it is applied after aggregation, which is inconsistent with the `message` and `aggregate` path:
```python
if self.edge_linear:
    edge_input = graph.edge_feature.float()
    # shrinking layer: transform each edge first (same order as `message`)
    if self.edge_linear.in_features > self.edge_linear.out_features:
        edge_input = self.edge_linear(edge_input)
    edge_weight = edge_weight.unsqueeze(-1)
    edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0,
                              dim_size=graph.num_node * graph.num_relation)
    # expanding layer: transform the aggregated sum instead (different order)
    if self.edge_linear.in_features <= self.edge_linear.out_features:
        edge_update = self.edge_linear(edge_update)
```