@@ -185,34 +185,67 @@ class GPUPSROIPoolOpKernel : public framework::OpKernel<T> {
185185
186186 int rois_num = rois->dims ()[0 ];
187187 if (rois_num == 0 ) return ;
188-
189- auto rois_lod = rois->lod ().back ();
190- int rois_batch_size = rois_lod.size () - 1 ;
191- PADDLE_ENFORCE_EQ (rois_batch_size, batch_size,
192- platform::errors::InvalidArgument (
193- " The batch size of input(ROIs) and input(X) must be "
194- " the same but received batch size of input(ROIs) and "
195- " input(X) is %d and %d respectively." ,
196- rois_batch_size, batch_size));
197- int rois_num_with_lod = rois_lod[rois_batch_size];
198- PADDLE_ENFORCE_EQ (rois_num, rois_num_with_lod,
199- platform::errors::InvalidArgument (
200- " The number of rois from input(ROIs) and its LOD "
201- " must be the same. Received rois %d of input(ROIs) "
202- " but the number of rois %d from its LOD is %d" ,
203- rois_num, rois_num_with_lod));
204-
205- // set rois batch id
188+ int rois_batch_size;
206189 framework::Tensor rois_batch_id_list;
207190 rois_batch_id_list.Resize ({rois_num});
208191 int * rois_batch_id_data =
209192 rois_batch_id_list.mutable_data <int >(platform::CPUPlace ());
210- for (int n = 0 ; n < rois_batch_size; ++n) {
211- for (size_t i = rois_lod[n]; i < rois_lod[n + 1 ]; ++i) {
212- rois_batch_id_data[i] = n;
193+
194+ if (ctx.HasInput (" RoisNum" )) {
195+ auto * rois_num_t = ctx.Input <Tensor>(" RoisNum" );
196+ rois_batch_size = rois_num_t ->numel ();
197+ auto * rois_num_data = rois_num_t ->data <int >();
198+ PADDLE_ENFORCE_EQ (
199+ rois_batch_size, batch_size,
200+ platform::errors::InvalidArgument (
201+ " The batch size of input(ROIs) and input(X) must be "
202+ " the same but received batch size of input(ROIs) and "
203+ " input(X) is %d and %d respectively." ,
204+ rois_batch_size, batch_size));
205+ std::vector<int > rois_num_list (rois_batch_size);
206+ memory::Copy (platform::CPUPlace (), rois_num_list.data (),
207+ BOOST_GET_CONST (platform::CUDAPlace, ctx.GetPlace ()),
208+ rois_num_data, sizeof (int ) * rois_batch_size, 0 );
209+ int rois_num_count = 0 ;
210+ for (int i = 0 ; i < rois_batch_size; ++i) {
211+ rois_num_count += rois_num_list[i];
212+ }
213+ PADDLE_ENFORCE_EQ (
214+ rois_num_count, rois_num,
215+ platform::errors::InvalidArgument (
216+ " the rois_num from input and RoisNum must be the same" ));
217+ int start = 0 ;
218+ for (int n = 0 ; n < rois_batch_size; ++n) {
219+ for (int i = start; i < start + rois_num_list[n]; ++i) {
220+ rois_batch_id_data[i] = n;
221+ }
222+ start += rois_num_list[n];
223+ }
224+ } else {
225+ auto rois_lod = rois->lod ().back ();
226+ rois_batch_size = rois_lod.size () - 1 ;
227+ PADDLE_ENFORCE_EQ (
228+ rois_batch_size, batch_size,
229+ platform::errors::InvalidArgument (
230+ " The batch size of input(ROIs) and input(X) must be "
231+ " the same but received batch size of input(ROIs) and "
232+ " input(X) is %d and %d respectively." ,
233+ rois_batch_size, batch_size));
234+ int rois_num_with_lod = rois_lod[rois_batch_size];
235+ PADDLE_ENFORCE_EQ (rois_num, rois_num_with_lod,
236+ platform::errors::InvalidArgument (
237+ " The number of rois from input(ROIs) and its LOD "
238+ " must be the same. Received rois %d of input(ROIs) "
239+ " but the number of rois %d from its LOD is %d" ,
240+ rois_num, rois_num_with_lod));
241+
242+ // set rois batch id
243+ for (int n = 0 ; n < rois_batch_size; ++n) {
244+ for (size_t i = rois_lod[n]; i < rois_lod[n + 1 ]; ++i) {
245+ rois_batch_id_data[i] = n;
246+ }
213247 }
214248 }
215-
216249 framework::Tensor rois_batch_id_list_gpu;
217250 framework::TensorCopy (rois_batch_id_list, ctx.GetPlace (),
218251 ctx.device_context (), &rois_batch_id_list_gpu);
@@ -257,14 +290,30 @@ class GPUPSROIPoolGradOpKernel : public framework::OpKernel<T> {
257290 rois_batch_id_list.Resize ({rois_num});
258291 int * rois_batch_id_data =
259292 rois_batch_id_list.mutable_data <int >(platform::CPUPlace ());
260- auto rois_lod = rois->lod ().back ();
261- int rois_batch_size = rois_lod.size () - 1 ;
262- for (int n = 0 ; n < rois_batch_size; ++n) {
263- for (size_t i = rois_lod[n]; i < rois_lod[n + 1 ]; ++i) {
264- rois_batch_id_data[i] = n;
293+ int rois_batch_size;
294+ if (ctx.HasInput (" RoisNum" )) {
295+ auto * rois_num_t = ctx.Input <Tensor>(" RoisNum" );
296+ rois_batch_size = rois_num_t ->numel ();
297+ std::vector<int > rois_num_list (rois_batch_size);
298+ memory::Copy (platform::CPUPlace (), rois_num_list.data (),
299+ BOOST_GET_CONST (platform::CUDAPlace, ctx.GetPlace ()),
300+ rois_num_t ->data <int >(), sizeof (int ) * rois_batch_size, 0 );
301+ int start = 0 ;
302+ for (int n = 0 ; n < rois_batch_size; ++n) {
303+ for (int i = start; i < start + rois_num_list[n]; ++i) {
304+ rois_batch_id_data[i] = n;
305+ }
306+ start += rois_num_list[n];
307+ }
308+ } else {
309+ auto rois_lod = rois->lod ().back ();
310+ rois_batch_size = rois_lod.size () - 1 ;
311+ for (int n = 0 ; n < rois_batch_size; ++n) {
312+ for (size_t i = rois_lod[n]; i < rois_lod[n + 1 ]; ++i) {
313+ rois_batch_id_data[i] = n;
314+ }
265315 }
266316 }
267-
268317 framework::Tensor rois_batch_id_list_gpu;
269318 framework::TensorCopy (rois_batch_id_list, ctx.GetPlace (),
270319 ctx.device_context (), &rois_batch_id_list_gpu);
0 commit comments