@@ -11,18 +11,26 @@ void roi_align_rotated_forward_npu(Tensor input, Tensor rois, Tensor output,
1111 int64_t aligned_height_64 = aligned_height;
1212 int64_t aligned_width_64 = aligned_width;
1313 int64_t sampling_ratio_64 = sampling_ratio;
14+
15+ at::Tensor input_trans = input.permute ({0 , 2 , 3 , 1 }).contiguous ();
16+ at::Tensor rois_trans = rois.permute ({1 , 0 }).contiguous ();
17+ at::Tensor output_trans = output.permute ({0 , 2 , 3 , 1 }).contiguous ();
18+
1419 OpCommand cmd;
1520 cmd.Name (" RoiAlignRotated" )
16- .Input (input )
17- .Input (rois )
18- .Output (output )
21+ .Input (input_trans )
22+ .Input (rois_trans )
23+ .Output (output_trans )
1924 .Attr (" pooled_h" , aligned_height_64)
2025 .Attr (" pooled_w" , aligned_width_64)
2126 .Attr (" spatial_scale" , spatial_scale)
2227 .Attr (" sampling_ratio" , sampling_ratio_64)
2328 .Attr (" aligned" , aligned)
2429 .Attr (" clockwise" , clockwise)
2530 .Run ();
31+
32+ output_trans = output_trans.permute ({0 , 3 , 1 , 2 }).contiguous ();
33+ output.copy_ (output_trans);
2634}
2735
2836void roi_align_rotated_backward_npu (Tensor top_grad, Tensor rois,
@@ -33,16 +41,21 @@ void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois,
3341 int64_t aligned_height_64 = aligned_height;
3442 int64_t aligned_width_64 = aligned_width;
3543 int64_t sampling_ratio_64 = sampling_ratio;
44+
45+ at::Tensor top_grad_trans = top_grad.permute ({0 , 2 , 3 , 1 }).contiguous ();
46+ at::Tensor rois_trans = rois.permute ({1 , 0 }).contiguous ();
47+ at::Tensor bottom_grad_trans = bottom_grad.permute ({0 , 2 , 3 , 1 }).contiguous ();
48+
3649 c10::SmallVector<int64_t , 8 > y_grad_shape;
37- auto shape = bottom_grad .sizes ();
50+ auto shape = bottom_grad_trans .sizes ();
3851 for (uint64_t i = 0 ; i < shape.size (); i++) {
3952 y_grad_shape.emplace_back (shape[i]);
4053 }
4154 OpCommand cmd;
4255 cmd.Name (" RoiAlignRotatedGrad" )
43- .Input (top_grad )
44- .Input (rois )
45- .Output (bottom_grad )
56+ .Input (top_grad_trans )
57+ .Input (rois_trans )
58+ .Output (bottom_grad_trans )
4659 .Attr (" y_grad_shape" , y_grad_shape)
4760 .Attr (" pooled_h" , aligned_width_64)
4861 .Attr (" pooled_w" , aligned_height_64)
@@ -51,6 +64,9 @@ void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois,
5164 .Attr (" aligned" , aligned)
5265 .Attr (" clockwise" , clockwise)
5366 .Run ();
67+
68+ bottom_grad_trans = bottom_grad_trans.permute ({0 , 3 , 1 , 2 }).contiguous ();
69+ bottom_grad.copy_ (bottom_grad_trans);
5470}
5571
5672void roi_align_rotated_forward_impl (Tensor input, Tensor rois, Tensor output,
0 commit comments