Necessary modification  #6

@LiamLYJ

Hi,
I made some modifications that were necessary to get your code to run. The most important one is in pintrend.py:

def inference(self, x, res2, out):

    B = x.shape[0]
    _, C_res2, H_res2, W_res2 = res2.shape

    while out.shape[-1] != x.shape[-1]:
        # N = out.shape[-2] * out.shape[-1]
        out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)

        _, C_out, H_out, W_out = out.shape
        points = sampling_points(out, training=False, N=4048)
        coarse = torch.gather(out.view(B, C_out, -1), 2,
                              points.unsqueeze(1).expand(-1, C_out, -1))

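        # Map each flat (row-major) index into `out` to the matching cell in res2.
        # Rows use the vertical stride, columns use the horizontal stride.
        stride_y = H_out // H_res2
        stride_x = W_out // W_res2
        points_index_y = (points // W_out) // stride_y  # row in res2
        points_index_x = (points % W_out) // stride_x   # column in res2
        res2_points = (points_index_y * W_res2 + points_index_x).long()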

        fine = torch.gather(res2.view(B, C_res2, -1), 2,
                            res2_points.unsqueeze(1).expand(-1, C_res2, -1))
        feature_representation = torch.cat([coarse, fine], dim=1)

        rend = self.mlp(feature_representation)

        out = out.view(B, C_out, -1).scatter_(2, points.unsqueeze(1).expand(-1, C_out, -1), rend)
        out = out.view(B, C_out, H_out, W_out)
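
The index arithmetic in the loop maps a flat position in `out` to the cell of `res2` that covers it. Here is a minimal pure-Python sketch of that mapping, assuming row-major flattening and an output size that is an integer multiple of the res2 size. `map_flat_index` is a hypothetical helper for illustration, not something in the repository.

```python
def map_flat_index(p, H_out, W_out, H_res2, W_res2):
    """Map a flat row-major index into an (H_out, W_out) grid to the flat
    index of the covering cell in an (H_res2, W_res2) grid.

    Hypothetical helper; assumes H_out and W_out are integer multiples of
    H_res2 and W_res2 respectively.
    """
    stride_y = H_out // H_res2  # output rows per res2 row
    stride_x = W_out // W_res2  # output columns per res2 column
    row, col = divmod(p, W_out)  # recover 2-D position from the flat index
    return (row // stride_y) * W_res2 + (col // stride_x)


# Non-square example: a 4 x 8 output grid over a 2 x 2 feature grid,
# so stride_y = 2 and stride_x = 4.
H_out, W_out, H_res2, W_res2 = 4, 8, 2, 2
mapped = [map_flat_index(p, H_out, W_out, H_res2, W_res2)
          for p in range(H_out * W_out)]
```

The non-square grid is deliberate: when the two strides differ, the row index has to be divided by the vertical stride and the column index by the horizontal one. The sketch also assumes the output grid is at least as large as the res2 grid, since otherwise the integer stride is 0 and the division fails.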

Could you grant me access rights so I can open a PR? Thanks a lot.
I am interested in implementing it in Mask R-CNN.
@zsef123
