@@ -39,14 +39,40 @@ class XPUROIAlignOpKernel : public framework::OpKernel<T> {
3939 int width = in_dims[3 ];
4040 int rois_num = rois->dims ()[0 ];
4141 const T* input_data = in->data <T>();
42- auto rois_lod = rois->lod ().back ();
43- int rois_batch_size = rois_lod.size () - 1 ;
44- PADDLE_ENFORCE_EQ (
45- rois_batch_size, batch_size,
46- platform::errors::InvalidArgument (
47- " The rois_batch_size and imgs batch_size of roi_align_xpu OP must "
48- " be the same. But received rois_batch_size %d , batch_size %d" ,
49- rois_batch_size, batch_size));
42+
43+ framework::Tensor _roi_batch_list;
44+ _roi_batch_list.Resize ({rois_num});
45+ int * rois_lod = _roi_batch_list.mutable_data <int >(ctx.GetPlace ());
46+ int rois_batch_size = 1 ;
47+ if (ctx.HasInput (" RoisNum" )) {
48+ auto * rois_num_t = ctx.Input <framework::Tensor>(" RoisNum" );
49+ rois_batch_size = rois_num_t ->numel ();
50+ PADDLE_ENFORCE_EQ (
51+ rois_batch_size, batch_size,
52+ platform::errors::InvalidArgument (
53+ " The batch size of rois and the batch size of images "
54+ " must be the same. But received the batch size of rois is %d, "
55+ " and the batch size of images is %d" ,
56+ rois_batch_size, batch_size));
57+ auto * rois_num_data = rois_num_t ->data <int >();
58+ rois_lod[0 ] = 0 ;
59+ for (int n = 0 ; n < rois_batch_size; ++n) {
60+ rois_lod[n + 1 ] = rois_lod[n] + rois_num_data[n];
61+ }
62+ } else {
63+ auto _rois_lod = rois->lod ().back ();
64+ rois_batch_size = _rois_lod.size () - 1 ;
65+ for (int n = 0 ; n < _rois_lod.size (); ++n) {
66+ rois_lod[n] = _rois_lod[n];
67+ }
68+ PADDLE_ENFORCE_EQ (
69+ rois_batch_size, batch_size,
70+ platform::errors::InvalidArgument (
71+ " The rois_batch_size and imgs batch_size of roi_align_xpu OP "
72+ " must "
73+ " be the same. But received rois_batch_size %d , batch_size %d" ,
74+ rois_batch_size, batch_size));
75+ }
5076 int rois_num_with_lod = rois_lod[rois_batch_size];
5177 PADDLE_ENFORCE_EQ (
5278 rois_num, rois_num_with_lod,
0 commit comments