Skip to content

Commit 2ba3617

Browse files
committed
add rois_num for roi_align xpu OP, test=develop
1 parent 9f45e75 commit 2ba3617

File tree

2 files changed

+58
-8
lines changed

2 files changed

+58
-8
lines changed

paddle/fluid/operators/roi_align_op_xpu.cc

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,40 @@ class XPUROIAlignOpKernel : public framework::OpKernel<T> {
3939
int width = in_dims[3];
4040
int rois_num = rois->dims()[0];
4141
const T* input_data = in->data<T>();
42-
auto rois_lod = rois->lod().back();
43-
int rois_batch_size = rois_lod.size() - 1;
44-
PADDLE_ENFORCE_EQ(
45-
rois_batch_size, batch_size,
46-
platform::errors::InvalidArgument(
47-
"The rois_batch_size and imgs batch_size of roi_align_xpu OP must "
48-
"be the same. But received rois_batch_size %d , batch_size %d",
49-
rois_batch_size, batch_size));
42+
43+
framework::Tensor _roi_batch_list;
44+
_roi_batch_list.Resize({rois_num});
45+
int* rois_lod = _roi_batch_list.mutable_data<int>(ctx.GetPlace());
46+
int rois_batch_size = 1;
47+
if (ctx.HasInput("RoisNum")) {
48+
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
49+
rois_batch_size = rois_num_t->numel();
50+
PADDLE_ENFORCE_EQ(
51+
rois_batch_size, batch_size,
52+
platform::errors::InvalidArgument(
53+
"The batch size of rois and the batch size of images "
54+
" must be the same. But received the batch size of rois is %d, "
55+
"and the batch size of images is %d",
56+
rois_batch_size, batch_size));
57+
auto* rois_num_data = rois_num_t->data<int>();
58+
rois_lod[0] = 0;
59+
for (int n = 0; n < rois_batch_size; ++n) {
60+
rois_lod[n + 1] = rois_lod[n] + rois_num_data[n];
61+
}
62+
} else {
63+
auto _rois_lod = rois->lod().back();
64+
rois_batch_size = _rois_lod.size() - 1;
65+
for (int n = 0; n < _rois_lod.size(); ++n) {
66+
rois_lod[n] = _rois_lod[n];
67+
}
68+
PADDLE_ENFORCE_EQ(
69+
rois_batch_size, batch_size,
70+
platform::errors::InvalidArgument(
71+
"The rois_batch_size and imgs batch_size of roi_align_xpu OP "
72+
"must "
73+
"be the same. But received rois_batch_size %d , batch_size %d",
74+
rois_batch_size, batch_size));
75+
}
5076
int rois_num_with_lod = rois_lod[rois_batch_size];
5177
PADDLE_ENFORCE_EQ(
5278
rois_num, rois_num_with_lod,

python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,5 +179,29 @@ def test_check_output(self):
179179
self.check_output_with_place(place)
180180

181181

182+
class TestROIAlignInLodOp(TestROIAlignOp):
183+
def set_data(self):
184+
self.init_test_case()
185+
self.make_rois()
186+
self.calc_roi_align()
187+
188+
seq_len = self.rois_lod[0]
189+
190+
self.inputs = {
191+
'X': self.x,
192+
'ROIs': (self.rois[:, 1:5], self.rois_lod),
193+
'RoisNum': np.asarray(seq_len).astype('int32')
194+
}
195+
196+
self.attrs = {
197+
'spatial_scale': self.spatial_scale,
198+
'pooled_height': self.pooled_height,
199+
'pooled_width': self.pooled_width,
200+
'sampling_ratio': self.sampling_ratio
201+
}
202+
203+
self.outputs = {'Out': self.out_data}
204+
205+
182206
if __name__ == '__main__':
183207
unittest.main()

0 commit comments

Comments
 (0)