@@ -30,53 +30,67 @@ void RoiAlignCompute::Run() {
3030
3131 auto * in = param.X ;
3232 auto * rois = param.ROIs ;
33+ auto * rois_num = param.RoisNum ;
3334 auto * out = param.Out ;
3435 float spatial_scale = param.spatial_scale ;
3536 int pooled_height = param.pooled_height ;
3637 int pooled_width = param.pooled_width ;
3738 int sampling_ratio = param.sampling_ratio ;
39+ bool align = param.align ;
3840
3941 auto in_dims = in->dims ();
4042 int batch_size = in_dims[0 ];
4143 int channels = in_dims[1 ];
4244 int height = in_dims[2 ];
4345 int width = in_dims[3 ];
4446 auto rois_dims = rois->dims ();
45- int rois_num = rois_dims[0 ];
47+ int rois_total = rois_dims[0 ];
4648 auto out_dims = out->dims ();
47- if (rois_num == 0 ) {
49+ if (rois_total == 0 ) {
4850 return ;
4951 }
50- auto rois_lod = rois-> lod (). back () ;
52+ const int * xpu_lod = nullptr ;
5153 std::vector<int > cpu_lod_data;
5254 cpu_lod_data.resize (batch_size + 1 );
53- for (int i = 0 ; i < rois_lod.size (); i++) {
54- cpu_lod_data[i] = rois_lod[i];
55+ if (rois_num == nullptr ) {
56+ auto rois_lod = rois->lod ().back ();
57+ int n = rois_lod.size ();
58+ for (int i = 0 ; i < n; i++) {
59+ cpu_lod_data[i] = rois_lod[i];
60+ }
61+ } else {
62+ auto rois_lod = rois_num->data <int >();
63+ int n = rois_num->numel ();
64+ for (int i = 1 ; i <= n; i++) {
65+ cpu_lod_data[i] = cpu_lod_data[i - 1 ] + rois_lod[i - 1 ];
66+ }
5567 }
56-
68+ CHECK_EQ (cpu_lod_data[batch_size], rois_total);
5769 XPUScratchPadGuard xpu_lod_grad_ =
5870 TargetWrapperXPU::MallocScratchPad (cpu_lod_data.size () * sizeof (int ));
59- int * xpu_lod = reinterpret_cast <int *>(xpu_lod_grad_->addr_ );
60- XPU_CALL (xpu_memcpy (xpu_lod,
71+ XPU_CALL (xpu_memcpy (xpu_lod_grad_->addr_ ,
6172 cpu_lod_data.data (),
6273 sizeof (int ) * cpu_lod_data.size (),
6374 XPUMemcpyKind::XPU_HOST_TO_DEVICE));
75+ xpu_lod = reinterpret_cast <int *>(xpu_lod_grad_->addr_ );
6476
6577 int r = xdnn::roi_align<float , int >(ctx.GetRawContext (),
6678 in->data <float >(),
6779 out->mutable_data <float >(TARGET (kXPU )),
6880 rois->data <float >(),
69- static_cast < const int *>( xpu_lod) ,
81+ xpu_lod,
7082 batch_size,
7183 channels,
7284 height,
7385 width,
74- rois_num ,
86+ rois_total ,
7587 pooled_height,
7688 pooled_width,
7789 spatial_scale,
7890 sampling_ratio,
79- true );
91+ true ,
92+ align,
93+ false );
8094 CHECK_EQ (r, 0 );
8195}
8296
@@ -93,5 +107,7 @@ REGISTER_LITE_KERNEL(roi_align,
93107 def)
94108 .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ))})
95109 .BindInput(" ROIs" , {LiteType::GetTensorTy (TARGET (kXPU ))})
110+ .BindInput(" RoisNum" ,
111+ {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
96112 .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ))})
97113 .Finalize();
0 commit comments