Error in the layers.conv module #140

@mrzzmrzz

Hello, I found a bug in the layers.conv module.

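In almost all of the models derived from the MessagePassingBase class, the function message_and_aggregate is inconsistent with the separate message and aggregate functions: it does not always compute the same result as calling message followed by aggregate.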
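A direct way to expose this kind of mismatch is a test that runs both code paths on the same input and compares the results. The following is a framework-free sketch that assumes the intended contract is message_and_aggregate(graph, input) == aggregate(graph, message(graph, input)). ScaleConv, its buggy flag, fused_matches_unfused, and the dict-based graph are hypothetical illustrations, not TorchDrug APIs.

```python
from typing import Any, List, Sequence

Matrix = List[List[float]]


def max_abs_diff(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]) -> float:
    """Largest element-wise absolute difference between two equally shaped 2-D lists."""
    return max((abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb)), default=0.0)


def fused_matches_unfused(layer: Any, graph: Any, input: Matrix, atol: float = 1e-6) -> bool:
    """Assumed contract: message_and_aggregate == aggregate(message(...))."""
    reference = layer.aggregate(graph, layer.message(graph, input))
    fused = layer.message_and_aggregate(graph, input)
    same_shape = len(reference) == len(fused) and all(
        len(r) == len(f) for r, f in zip(reference, fused)
    )
    return same_shape and max_abs_diff(reference, fused) <= atol


class ScaleConv:
    """Hypothetical toy layer: every edge (src, dst) sends 2 * input[src] to dst."""

    def __init__(self, buggy: bool = False):
        self.buggy = buggy  # if True, the fused path drops the factor 2 on purpose

    def message(self, graph, input):
        return [[2.0 * v for v in input[src]] for src, _ in graph["edges"]]

    def aggregate(self, graph, message):
        dim = len(message[0]) if message else 0
        out = [[0.0] * dim for _ in range(graph["num_node"])]
        for (_, dst), msg in zip(graph["edges"], message):
            out[dst] = [o + m for o, m in zip(out[dst], msg)]  # sum messages per node
        return out

    def message_and_aggregate(self, graph, input):
        scale = 1.0 if self.buggy else 2.0
        dim = len(input[0]) if input else 0
        out = [[0.0] * dim for _ in range(graph["num_node"])]
        for src, dst in graph["edges"]:
            out[dst] = [o + scale * v for o, v in zip(out[dst], input[src])]
        return out


graph = {"num_node": 3, "edges": [(0, 1), (2, 1), (1, 0)]}
x = [[1.0, 0.5], [0.0, -1.0], [3.0, 2.0]]
print(fused_matches_unfused(ScaleConv(), graph, x))            # True
print(fused_matches_unfused(ScaleConv(buggy=True), graph, x))  # False
```

With a real layer, the same comparison would be made on tensors (for example with torch.allclose) rather than on nested lists.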

Take the RelationalGraphConv module as an example. In the message function, if self.edge_linear is set (i.e., edge features are used), each edge feature is first transformed by self.edge_linear and added to the message of that edge, before any aggregation takes place:

    def message(self, graph, input):
        node_in = graph.edge_list[:, 0]
        message = input[node_in]
        if self.edge_linear:
            message += self.edge_linear(graph.edge_feature.float())
        return message
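To make the ordering explicit, here is a minimal pure-Python sketch of the unfused path for a single relation: message as quoted, followed by a plain weighted-sum aggregation of the kind the scatter_add in message_and_aggregate performs. It assumes edge_linear behaves like an affine map W x + b (a linear layer with a bias). The names affine and reference_update and the list-based inputs are hypothetical, and the per-relation indexing and any normalization done by the real aggregate are left out.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def affine(weight: Sequence[Sequence[float]], bias: Sequence[float], x: Sequence[float]) -> Vector:
    """Hypothetical stand-in for edge_linear: returns W x + b."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def reference_update(
    num_node: int,
    edges: Sequence[Tuple[int, int]],
    edge_weight: Sequence[float],
    node_input: Sequence[Sequence[float]],
    edge_feature: Sequence[Sequence[float]],
    weight: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> List[Vector]:
    """message() followed by a weighted-sum aggregate(): edge_linear runs once per edge."""
    dim = len(bias)
    update = [[0.0] * dim for _ in range(num_node)]
    for (src, dst), w_e, feat in zip(edges, edge_weight, edge_feature):
        # message = input[node_in] + edge_linear(edge_feature), computed per edge
        message = [a + b for a, b in zip(node_input[src], affine(weight, bias, feat))]
        # aggregate: add the weighted message into its destination node
        update[dst] = [u + w_e * m for u, m in zip(update[dst], message)]
    return update


print(reference_update(
    num_node=2,
    edges=[(0, 1), (1, 1)],
    edge_weight=[0.5, 0.5],
    node_input=[[1.0], [3.0]],
    edge_feature=[[2.0], [4.0]],
    weight=[[1.0]],
    bias=[10.0],
))  # [[0.0], [15.0]]
```

The point to notice is that the bias b enters once per edge, scaled by that edge's weight.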

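However, in message_and_aggregate, the point at which edge_linear is applied depends on its feature dimensions: if in_features > out_features, it transforms each edge feature before the weighted scatter_add (as message does), but otherwise it transforms the already aggregated edge features. This is inconsistent with the message and aggregate functions shown above: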

    if self.edge_linear:
        edge_input = graph.edge_feature.float()
        if self.edge_linear.in_features > self.edge_linear.out_features:
            edge_input = self.edge_linear(edge_input)
        edge_weight = edge_weight.unsqueeze(-1)
        edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0,
                                  dim_size=graph.num_node * graph.num_relation)
        if self.edge_linear.in_features <= self.edge_linear.out_features:
            edge_update = self.edge_linear(edge_update)
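The difference between the two branches comes down to whether the weighted sum is taken before or after edge_linear. For an affine map W x + b, transforming each edge first gives W (sum_e w_e x_e) + (sum_e w_e) b, while transforming after the sum gives W (sum_e w_e x_e) + b. The W part commutes with the sum, so the in_features > out_features branch agrees with message and aggregate; the in_features <= out_features branch differs by (1 - sum_e w_e) b at each destination. It therefore matches only if the bias is zero or the incoming edge weights of every destination sum to exactly 1, and a destination with no incoming edges receives b instead of 0. The sketch below checks this on a toy example; it assumes edge_linear has a bias, and all function names are hypothetical, not TorchDrug code.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def affine(weight: Sequence[Sequence[float]], bias: Sequence[float], x: Sequence[float]) -> Vector:
    """Hypothetical stand-in for edge_linear: returns W x + b."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def scatter_weighted_sum(num_node: int, edges: Sequence[Tuple[int, int]],
                         edge_weight: Sequence[float],
                         values: Sequence[Sequence[float]]) -> List[Vector]:
    """out[dst] = sum over edges into dst of w_e * values[e] (like scatter_add on node_out)."""
    dim = len(values[0]) if values else 0
    out = [[0.0] * dim for _ in range(num_node)]
    for (_, dst), w_e, v in zip(edges, edge_weight, values):
        out[dst] = [o + w_e * vi for o, vi in zip(out[dst], v)]
    return out


def edge_term_linear_first(num_node, edges, edge_weight, edge_feature, weight, bias):
    """Edge part of message() + aggregate(): sum_e w_e * (W x_e + b)."""
    transformed = [affine(weight, bias, f) for f in edge_feature]
    return scatter_weighted_sum(num_node, edges, edge_weight, transformed)


def edge_term_linear_last(num_node, edges, edge_weight, edge_feature, weight, bias):
    """The in_features <= out_features branch: W * (sum_e w_e * x_e) + b."""
    summed = scatter_weighted_sum(num_node, edges, edge_weight, edge_feature)
    return [affine(weight, bias, s) for s in summed]


# Nodes 0 and 1 have incoming weights summing to 1; node 2 has no incoming edges.
args = dict(num_node=3, edges=[(0, 1), (2, 1), (1, 0)], edge_weight=[0.5, 0.5, 1.0],
            edge_feature=[[1.0], [3.0], [2.0]], weight=[[2.0]])
print(edge_term_linear_first(bias=[1.0], **args))  # [[5.0], [5.0], [0.0]]
print(edge_term_linear_last(bias=[1.0], **args))   # [[5.0], [5.0], [1.0]]
```

Since edge_weight in the quoted snippet is computed earlier in message_and_aggregate, whether its per-destination sums equal 1 in practice depends on normalization code that the snippet does not show.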
