Skip to content

Commit 6ba6b7e

Browse files
authored
change the npu code for roi align rotated (#3238)
1 parent a4a884d commit 6ba6b7e

1 file changed

Lines changed: 23 additions & 7 deletions

File tree

mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,26 @@ void roi_align_rotated_forward_npu(Tensor input, Tensor rois, Tensor output,
1111
int64_t aligned_height_64 = aligned_height;
1212
int64_t aligned_width_64 = aligned_width;
1313
int64_t sampling_ratio_64 = sampling_ratio;
14+
15+
at::Tensor input_trans = input.permute({0, 2, 3, 1}).contiguous();
16+
at::Tensor rois_trans = rois.permute({1, 0}).contiguous();
17+
at::Tensor output_trans = output.permute({0, 2, 3, 1}).contiguous();
18+
1419
OpCommand cmd;
1520
cmd.Name("RoiAlignRotated")
16-
.Input(input)
17-
.Input(rois)
18-
.Output(output)
21+
.Input(input_trans)
22+
.Input(rois_trans)
23+
.Output(output_trans)
1924
.Attr("pooled_h", aligned_height_64)
2025
.Attr("pooled_w", aligned_width_64)
2126
.Attr("spatial_scale", spatial_scale)
2227
.Attr("sampling_ratio", sampling_ratio_64)
2328
.Attr("aligned", aligned)
2429
.Attr("clockwise", clockwise)
2530
.Run();
31+
32+
output_trans = output_trans.permute({0, 3, 1, 2}).contiguous();
33+
output.copy_(output_trans);
2634
}
2735

2836
void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois,
@@ -33,16 +41,21 @@ void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois,
3341
int64_t aligned_height_64 = aligned_height;
3442
int64_t aligned_width_64 = aligned_width;
3543
int64_t sampling_ratio_64 = sampling_ratio;
44+
45+
at::Tensor top_grad_trans = top_grad.permute({0, 2, 3, 1}).contiguous();
46+
at::Tensor rois_trans = rois.permute({1, 0}).contiguous();
47+
at::Tensor bottom_grad_trans = bottom_grad.permute({0, 2, 3, 1}).contiguous();
48+
3649
c10::SmallVector<int64_t, 8> y_grad_shape;
37-
auto shape = bottom_grad.sizes();
50+
auto shape = bottom_grad_trans.sizes();
3851
for (uint64_t i = 0; i < shape.size(); i++) {
3952
y_grad_shape.emplace_back(shape[i]);
4053
}
4154
OpCommand cmd;
4255
cmd.Name("RoiAlignRotatedGrad")
43-
.Input(top_grad)
44-
.Input(rois)
45-
.Output(bottom_grad)
56+
.Input(top_grad_trans)
57+
.Input(rois_trans)
58+
.Output(bottom_grad_trans)
4659
.Attr("y_grad_shape", y_grad_shape)
4760
.Attr("pooled_h", aligned_width_64)
4861
.Attr("pooled_w", aligned_height_64)
@@ -51,6 +64,9 @@ void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois,
5164
.Attr("aligned", aligned)
5265
.Attr("clockwise", clockwise)
5366
.Run();
67+
68+
bottom_grad_trans = bottom_grad_trans.permute({0, 3, 1, 2}).contiguous();
69+
bottom_grad.copy_(bottom_grad_trans);
5470
}
5571

5672
void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,

0 commit comments

Comments
 (0)