Skip to content

Commit 01a44c0

Browse files
zytx121irexycxizihadoop-basecvlzhangzz
authored
Add roi_align_rotated op for onnxruntime (#277)
* init * add doc * add * Update test_ops.py * fix bug * fix pose demo and windows build (#307) * add postprocessing_masks gpu version (#276) * add postprocessing_masks gpu version * default device cpu * pre-commit fix Co-authored-by: hadoop-basecv <[email protected]> * fixed a bug causes text-recognizer to fail when (non-NULL) empty bboxes list is passed (#310) * [Fix] include missing <type_traits> for formatter.h (#313) * fix formatter * relax GCC version requirement * fix lint * Update onnxruntime.md * fix lint Co-authored-by: Chen Xin <[email protected]> Co-authored-by: Shengxi Li <[email protected]> Co-authored-by: hadoop-basecv <[email protected]> Co-authored-by: lzhangzz <[email protected]>
1 parent aa536ec commit 01a44c0

File tree

7 files changed

+437
-1
lines changed

7 files changed

+437
-1
lines changed
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2+
// Modified from
3+
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlignRotated
4+
#include "roi_align_rotated.h"
5+
6+
#include "ort_utils.h"
7+
8+
namespace mmdeploy {
9+
// implementation taken from Caffe2
10+
struct PreCalc {
11+
int pos1;
12+
int pos2;
13+
int pos3;
14+
int pos4;
15+
float w1;
16+
float w2;
17+
float w3;
18+
float w4;
19+
};
20+
21+
void pre_calc_for_bilinear_interpolate(const int height, const int width, const int pooled_height,
22+
const int pooled_width, const int iy_upper,
23+
const int ix_upper, float roi_start_h, float roi_start_w,
24+
float bin_size_h, float bin_size_w, int roi_bin_grid_h,
25+
int roi_bin_grid_w, float roi_center_h, float roi_center_w,
26+
float cos_theta, float sin_theta,
27+
std::vector<PreCalc> &pre_calc) {
28+
int pre_calc_index = 0;
29+
for (int ph = 0; ph < pooled_height; ph++) {
30+
for (int pw = 0; pw < pooled_width; pw++) {
31+
for (int iy = 0; iy < iy_upper; iy++) {
32+
const float yy = roi_start_h + ph * bin_size_h +
33+
static_cast<float>(iy + .5f) * bin_size_h /
34+
static_cast<float>(roi_bin_grid_h); // e.g., 0.5, 1.5
35+
for (int ix = 0; ix < ix_upper; ix++) {
36+
const float xx =
37+
roi_start_w + pw * bin_size_w +
38+
static_cast<float>(ix + .5f) * bin_size_w / static_cast<float>(roi_bin_grid_w);
39+
40+
// Rotate by theta around the center and translate
41+
// In image space, (y, x) is the order for Right Handed System,
42+
// and this is essentially multiplying the point by a rotation matrix
43+
// to rotate it counterclockwise through angle theta.
44+
float y = yy * cos_theta - xx * sin_theta + roi_center_h;
45+
float x = yy * sin_theta + xx * cos_theta + roi_center_w;
46+
// deal with: inverse elements are out of feature map boundary
47+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
48+
// empty
49+
PreCalc pc;
50+
pc.pos1 = 0;
51+
pc.pos2 = 0;
52+
pc.pos3 = 0;
53+
pc.pos4 = 0;
54+
pc.w1 = 0;
55+
pc.w2 = 0;
56+
pc.w3 = 0;
57+
pc.w4 = 0;
58+
pre_calc[pre_calc_index] = pc;
59+
pre_calc_index += 1;
60+
continue;
61+
}
62+
63+
if (y < 0) {
64+
y = 0;
65+
}
66+
if (x < 0) {
67+
x = 0;
68+
}
69+
70+
int y_low = (int)y;
71+
int x_low = (int)x;
72+
int y_high;
73+
int x_high;
74+
75+
if (y_low >= height - 1) {
76+
y_high = y_low = height - 1;
77+
y = (float)y_low;
78+
} else {
79+
y_high = y_low + 1;
80+
}
81+
82+
if (x_low >= width - 1) {
83+
x_high = x_low = width - 1;
84+
x = (float)x_low;
85+
} else {
86+
x_high = x_low + 1;
87+
}
88+
89+
float ly = y - y_low;
90+
float lx = x - x_low;
91+
float hy = 1. - ly, hx = 1. - lx;
92+
float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
93+
94+
// save weights and indices
95+
PreCalc pc;
96+
pc.pos1 = y_low * width + x_low;
97+
pc.pos2 = y_low * width + x_high;
98+
pc.pos3 = y_high * width + x_low;
99+
pc.pos4 = y_high * width + x_high;
100+
pc.w1 = w1;
101+
pc.w2 = w2;
102+
pc.w3 = w3;
103+
pc.w4 = w4;
104+
pre_calc[pre_calc_index] = pc;
105+
106+
pre_calc_index += 1;
107+
}
108+
}
109+
}
110+
}
111+
}
112+
113+
void ROIAlignRotatedForwardCPU(const int nthreads, const float *input, const float *rois,
114+
float *output, const float &spatial_scale, const int aligned,
115+
const int clockwise, const int channels, const int height,
116+
const int width, const int pooled_height, const int pooled_width,
117+
const int sampling_ratio) {
118+
int n_rois = nthreads / channels / pooled_width / pooled_height;
119+
// (n, c, ph, pw) is an element in the pooled output
120+
// can be parallelized using omp
121+
// #pragma omp parallel for num_threads(32)
122+
for (int n = 0; n < n_rois; n++) {
123+
int index_n = n * channels * pooled_width * pooled_height;
124+
125+
const float *current_roi = rois + n * 6;
126+
int roi_batch_ind = current_roi[0];
127+
128+
// Do not use rounding; this implementation detail is critical
129+
float offset = aligned ? (float)0.5 : (float)0.0;
130+
float roi_center_w = current_roi[1] * spatial_scale - offset;
131+
float roi_center_h = current_roi[2] * spatial_scale - offset;
132+
float roi_width = current_roi[3] * spatial_scale;
133+
float roi_height = current_roi[4] * spatial_scale;
134+
// float theta = current_roi[5] * M_PI / 180.0;
135+
float theta = current_roi[5]; // Radian angle by default
136+
if (clockwise) {
137+
theta = -theta;
138+
}
139+
float cos_theta = cos(theta);
140+
float sin_theta = sin(theta);
141+
if (!aligned) { // for backward-compatibility only
142+
roi_width = std::max(roi_width, (float)1.);
143+
roi_height = std::max(roi_height, (float)1.);
144+
}
145+
146+
float bin_size_h = static_cast<float>(roi_height) / static_cast<float>(pooled_height);
147+
float bin_size_w = static_cast<float>(roi_width) / static_cast<float>(pooled_width);
148+
149+
// We use roi_bin_grid to sample the grid and mimic integral
150+
int roi_bin_grid_h =
151+
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
152+
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
153+
154+
// We do average (integral) pooling inside a bin
155+
const float count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
156+
157+
// we want to precalculate indices and weights shared by all channels,
158+
// this is the key point of optimization
159+
std::vector<PreCalc> pre_calc(roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
160+
161+
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
162+
// Appropriate translation needs to be applied after.
163+
float roi_start_h = -roi_height / 2.0;
164+
float roi_start_w = -roi_width / 2.0;
165+
166+
pre_calc_for_bilinear_interpolate(height, width, pooled_height, pooled_width, roi_bin_grid_h,
167+
roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h,
168+
bin_size_w, roi_bin_grid_h, roi_bin_grid_w, roi_center_h,
169+
roi_center_w, cos_theta, sin_theta, pre_calc);
170+
171+
for (int c = 0; c < channels; c++) {
172+
int index_n_c = index_n + c * pooled_width * pooled_height;
173+
const float *offset_input = input + (roi_batch_ind * channels + c) * height * width;
174+
int pre_calc_index = 0;
175+
176+
for (int ph = 0; ph < pooled_height; ph++) {
177+
for (int pw = 0; pw < pooled_width; pw++) {
178+
int index = index_n_c + ph * pooled_width + pw;
179+
180+
float output_val = 0.;
181+
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
182+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
183+
PreCalc pc = pre_calc[pre_calc_index];
184+
output_val += pc.w1 * offset_input[pc.pos1] + pc.w2 * offset_input[pc.pos2] +
185+
pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
186+
187+
pre_calc_index += 1;
188+
}
189+
}
190+
output_val /= count;
191+
192+
output[index] = output_val;
193+
} // for pw
194+
} // for ph
195+
} // for c
196+
} // for n
197+
}
198+
199+
void MMCVRoIAlignRotatedKernel::Compute(OrtKernelContext *context) {
200+
// Setup inputs
201+
const OrtValue *input_X = ort_.KernelContext_GetInput(context, 0);
202+
const float *X_data = reinterpret_cast<const float *>(ort_.GetTensorData<float>(input_X));
203+
const OrtValue *input_rois = ort_.KernelContext_GetInput(context, 1);
204+
const float *rois =
205+
reinterpret_cast<const float *>(ort_.GetTensorData<const float *>(input_rois));
206+
207+
// Setup output
208+
OrtTensorDimensions out_dimensions(ort_, input_X);
209+
OrtTensorDimensions roi_dimensions(ort_, input_rois);
210+
211+
int batch_size = out_dimensions.data()[0];
212+
int input_channels = out_dimensions.data()[1];
213+
int input_height = out_dimensions.data()[2];
214+
int input_width = out_dimensions.data()[3];
215+
216+
out_dimensions.data()[0] = roi_dimensions.data()[0];
217+
out_dimensions.data()[2] = aligned_height_;
218+
out_dimensions.data()[3] = aligned_width_;
219+
220+
OrtValue *output =
221+
ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size());
222+
float *out = ort_.GetTensorMutableData<float>(output);
223+
OrtTensorTypeAndShapeInfo *output_info = ort_.GetTensorTypeAndShape(output);
224+
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
225+
226+
// TODO: forward here
227+
int output_size = out_dimensions.data()[0];
228+
for (auto i = 1; i < out_dimensions.size(); ++i) {
229+
output_size *= out_dimensions.data()[i];
230+
}
231+
ROIAlignRotatedForwardCPU(output_size, X_data, rois, out, spatial_scale_, aligned_, clockwise_,
232+
input_channels, input_height, input_width, aligned_height_,
233+
aligned_width_, sampling_ratio_);
234+
}
235+
236+
REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVRoIAlignRotatedCustomOp);
237+
} // namespace mmdeploy
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright (c) OpenMMLab. All rights reserved
2+
#ifndef ONNXRUNTIME_ROI_ALIGN_ROTATED_H
3+
#define ONNXRUNTIME_ROI_ALIGN_ROTATED_H
4+
5+
#include <assert.h>
6+
#include <onnxruntime_cxx_api.h>
7+
8+
#include <cmath>
9+
#include <mutex>
10+
#include <string>
11+
#include <vector>
12+
13+
namespace mmdeploy {
14+
struct MMCVRoIAlignRotatedKernel {
15+
public:
16+
MMCVRoIAlignRotatedKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info) : ort_(ort) {
17+
aligned_height_ = ort_.KernelInfoGetAttribute<int64_t>(info, "output_height");
18+
aligned_width_ = ort_.KernelInfoGetAttribute<int64_t>(info, "output_width");
19+
sampling_ratio_ = ort_.KernelInfoGetAttribute<int64_t>(info, "sampling_ratio");
20+
spatial_scale_ = ort_.KernelInfoGetAttribute<float>(info, "spatial_scale");
21+
aligned_ = ort_.KernelInfoGetAttribute<int64_t>(info, "aligned");
22+
clockwise_ = ort_.KernelInfoGetAttribute<int64_t>(info, "clockwise");
23+
}
24+
25+
void Compute(OrtKernelContext* context);
26+
27+
private:
28+
Ort::CustomOpApi ort_;
29+
int aligned_height_;
30+
int aligned_width_;
31+
float spatial_scale_;
32+
int sampling_ratio_;
33+
int aligned_;
34+
int clockwise_;
35+
};
36+
37+
struct MMCVRoIAlignRotatedCustomOp
38+
: Ort::CustomOpBase<MMCVRoIAlignRotatedCustomOp, MMCVRoIAlignRotatedKernel> {
39+
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
40+
return new MMCVRoIAlignRotatedKernel(api, info);
41+
}
42+
const char* GetName() const { return "MMCVRoIAlignRotated"; }
43+
44+
size_t GetInputTypeCount() const { return 2; }
45+
ONNXTensorElementDataType GetInputType(size_t) const {
46+
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
47+
}
48+
49+
size_t GetOutputTypeCount() const { return 1; }
50+
ONNXTensorElementDataType GetOutputType(size_t) const {
51+
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
52+
}
53+
54+
// force cpu
55+
const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; }
56+
};
57+
} // namespace mmdeploy
58+
59+
#endif // ONNXRUNTIME_ROI_ALIGN_ROTATED_H

docs/en/backends/onnxruntime.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ make -j$(nproc)
6060
| [grid_sampler](../ops/onnxruntime.md#grid_sampler) | Y | N | master |
6161
| [MMCVModulatedDeformConv2d](../ops/onnxruntime.md#mmcvmodulateddeformconv2d) | Y | N | master |
6262
| [NMSRotated](../ops/onnxruntime.md#nmsrotated) | Y | N | master |
63+
| [RoIAlignRotated](../ops/onnxruntime.md#roialignrotated) | Y | N | master |
6364

6465
### How to add a new custom op
6566

docs/en/ops/onnxruntime.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@
2121
- [Inputs](#inputs-2)
2222
- [Outputs](#outputs-2)
2323
- [Type Constraints](#type-constraints-2)
24+
- [RoIAlignRotated](#roialignrotated)
25+
- [Description](#description-3)
26+
- [Parameters](#parameters-3)
27+
- [Inputs](#inputs-3)
28+
- [Outputs](#outputs-3)
29+
- [Type Constraints](#type-constraints-3)
2430

2531
<!-- TOC -->
2632

@@ -132,3 +138,41 @@ Non Max Suppression for rotated bboxes.
132138
#### Type Constraints
133139

134140
- T:tensor(float32, Linear)
141+
142+
143+
### RoIAlignRotated
144+
145+
#### Description
146+
147+
Perform RoIAlignRotated on output feature, used in bbox_head of most two-stage rotated object detectors.
148+
149+
#### Parameters
150+
151+
| Type | Parameter | Description |
152+
| ------- | ---------------- | ------------------------------------------------------------------------------------------------------------- |
153+
| `int` | `output_height` | height of output roi |
154+
| `int` | `output_width` | width of output roi |
155+
| `float` | `spatial_scale` | used to scale the input boxes |
156+
| `int` | `sampling_ratio` | number of input samples to take for each output sample. `0` means to take samples densely for current models. |
157+
| `int` | `aligned` | If `aligned=0`, use the legacy implementation in MMDetection. Else, align the results more perfectly. |
158+
| `int` | `clockwise` | If True, the angle in each proposal follows a clockwise fashion in image space, otherwise, the angle is counterclockwise. Default: False. |
159+
160+
#### Inputs
161+
162+
<dl>
163+
<dt><tt>input</tt>: T</dt>
164+
<dd>Input feature map; 4D tensor of shape (N, C, H, W), where N is the batch size, C is the numbers of channels, H and W are the height and width of the data.</dd>
165+
<dt><tt>rois</tt>: T</dt>
166+
<dd>RoIs (Regions of Interest) to pool over; 2-D tensor of shape (num_rois, 6) given as [[batch_index, cx, cy, w, h, theta], ...]. The RoIs' coordinates are the coordinate system of input.</dd>
167+
</dl>
168+
169+
#### Outputs
170+
171+
<dl>
172+
<dt><tt>feat</tt>: T</dt>
173+
<dd>RoI pooled output, 4-D tensor of shape (num_rois, C, output_height, output_width). The r-th batch element feat[r-1] is a pooled feature map corresponding to the r-th RoI RoIs[r-1].<dd>
174+
</dl>
175+
176+
#### Type Constraints
177+
178+
- T:tensor(float32)

mmdeploy/mmcv/ops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from .nms import * # noqa: F401,F403
55
from .nms_rotated import * # noqa: F401,F403
66
from .roi_align import roi_align_default
7+
from .roi_align_rotated import roi_align_rotated_default
78

89
__all__ = [
910
'roi_align_default', 'modulated_deform_conv_default',
10-
'deform_conv_openvino'
11+
'deform_conv_openvino', 'roi_align_rotated_default'
1112
]

0 commit comments

Comments
 (0)