Description
Hi~
I made some modifications that were necessary to run your code. The most important change is in `pointrend.py`:
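```python
import torch
import torch.nn.functional as F

# `sampling_points` is the point-selection helper provided by this repository.

def inference(self, x, res2, out):
    """Upsample the coarse prediction `out` step by step and refine the most
    uncertain points with the point-head MLP, using fine-grained res2 features."""
    B = x.shape[0]
    _, C_res2, H_res2, W_res2 = res2.shape

    while out.shape[-1] != x.shape[-1]:
        # Upsample the current prediction by 2x.
        out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)
        _, C_out, H_out, W_out = out.shape

        # Flat indices (over H_out * W_out) of the N most uncertain points.
        points = sampling_points(out, training=False, N=4048)

        # Coarse features: the current prediction at the selected points.
        coarse = torch.gather(out.view(B, C_out, -1), 2,
                              points.unsqueeze(1).expand(-1, C_out, -1))

        # Map each point from the (H_out, W_out) grid to the (H_res2, W_res2) grid:
        # the row is scaled by stride_y and the column by stride_x.
        stride_y = H_out // H_res2
        stride_x = W_out // W_res2
        points_row = points // W_out // stride_y
        points_col = points % W_out // stride_x
        res2_points = (points_row * W_res2 + points_col).long()

        # Fine-grained features: res2 sampled at the mapped locations.
        fine = torch.gather(res2.view(B, C_res2, -1), 2,
                            res2_points.unsqueeze(1).expand(-1, C_res2, -1))

        # Predict refined logits for the selected points and write them back.
        feature_representation = torch.cat([coarse, fine], dim=1)
        rend = self.mlp(feature_representation)
        out = out.view(B, C_out, -1).scatter_(2, points.unsqueeze(1).expand(-1, C_out, -1), rend)
        out = out.view(B, C_out, H_out, W_out)

    # Assumption: return the refined prediction; adapt to the format the caller expects.
    return out
```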
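In the loop, each flat index on the upsampled grid is split into a row (`points // W_out`) and a column (`points % W_out`) before being scaled down to the res2 grid. The rows must be divided by the vertical stride (`H_out // H_res2`) and the columns by the horizontal stride (`W_out // W_res2`); the two strides coincide when height and width are scaled by the same factor, so mixing them up only shows when they differ. Below is a minimal pure-Python sketch of that mapping; `map_to_coarse` and `map_to_coarse_swapped` are hypothetical helpers for illustration, not part of the repository.

```python
def map_to_coarse(points, w_out, w_res2, stride_y, stride_x):
    """Map flat indices on a fine (H_out, W_out) grid to flat indices on a
    coarser (H_res2, W_res2) grid, scaling rows by stride_y and columns by stride_x."""
    mapped = []
    for p in points:
        row = p // w_out  # y coordinate on the fine grid
        col = p % w_out   # x coordinate on the fine grid
        mapped.append((row // stride_y) * w_res2 + (col // stride_x))
    return mapped


def map_to_coarse_swapped(points, w_out, w_res2, stride_y, stride_x):
    """Same mapping with the two strides swapped, for comparison."""
    mapped = []
    for p in points:
        row = p // w_out
        col = p % w_out
        mapped.append((row // stride_x) * w_res2 + (col // stride_y))
    return mapped


# Fine grid 4 x 8, coarse grid 2 x 2, so stride_y = 2 and stride_x = 4.
pts = [0, 7, 13, 31]
print(map_to_coarse(pts, w_out=8, w_res2=2, stride_y=2, stride_x=4))          # [0, 1, 1, 3]
print(map_to_coarse_swapped(pts, w_out=8, w_res2=2, stride_y=2, stride_x=4))  # [0, 3, 2, 3]
```

With unequal strides the swapped version computes column indices (3 and 2 here) that exceed `W_res2 - 1`, so they spill into the wrong row of the flattened res2 map, and for larger grids they can run past its end.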
Could you grant me access rights so I can open a PR? Thanks a lot.
I am interested in implementing it in Mask R-CNN.
@zsef123