Skip to content

Commit 608b813

Browse files
authored
[XPU] Add lod support to roi_align op. (#9274)
1 parent d978321 commit 608b813

2 files changed

Lines changed: 29 additions & 11 deletions

File tree

lite/kernels/xpu/roi_align_compute.cc

Lines changed: 27 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -30,53 +30,67 @@ void RoiAlignCompute::Run() {
3030

3131
auto* in = param.X;
3232
auto* rois = param.ROIs;
33+
auto* rois_num = param.RoisNum;
3334
auto* out = param.Out;
3435
float spatial_scale = param.spatial_scale;
3536
int pooled_height = param.pooled_height;
3637
int pooled_width = param.pooled_width;
3738
int sampling_ratio = param.sampling_ratio;
39+
bool align = param.align;
3840

3941
auto in_dims = in->dims();
4042
int batch_size = in_dims[0];
4143
int channels = in_dims[1];
4244
int height = in_dims[2];
4345
int width = in_dims[3];
4446
auto rois_dims = rois->dims();
45-
int rois_num = rois_dims[0];
47+
int rois_total = rois_dims[0];
4648
auto out_dims = out->dims();
47-
if (rois_num == 0) {
49+
if (rois_total == 0) {
4850
return;
4951
}
50-
auto rois_lod = rois->lod().back();
52+
const int* xpu_lod = nullptr;
5153
std::vector<int> cpu_lod_data;
5254
cpu_lod_data.resize(batch_size + 1);
53-
for (int i = 0; i < rois_lod.size(); i++) {
54-
cpu_lod_data[i] = rois_lod[i];
55+
if (rois_num == nullptr) {
56+
auto rois_lod = rois->lod().back();
57+
int n = rois_lod.size();
58+
for (int i = 0; i < n; i++) {
59+
cpu_lod_data[i] = rois_lod[i];
60+
}
61+
} else {
62+
auto rois_lod = rois_num->data<int>();
63+
int n = rois_num->numel();
64+
for (int i = 1; i <= n; i++) {
65+
cpu_lod_data[i] = cpu_lod_data[i - 1] + rois_lod[i - 1];
66+
}
5567
}
56-
68+
CHECK_EQ(cpu_lod_data[batch_size], rois_total);
5769
XPUScratchPadGuard xpu_lod_grad_ =
5870
TargetWrapperXPU::MallocScratchPad(cpu_lod_data.size() * sizeof(int));
59-
int* xpu_lod = reinterpret_cast<int*>(xpu_lod_grad_->addr_);
60-
XPU_CALL(xpu_memcpy(xpu_lod,
71+
XPU_CALL(xpu_memcpy(xpu_lod_grad_->addr_,
6172
cpu_lod_data.data(),
6273
sizeof(int) * cpu_lod_data.size(),
6374
XPUMemcpyKind::XPU_HOST_TO_DEVICE));
75+
xpu_lod = reinterpret_cast<int*>(xpu_lod_grad_->addr_);
6476

6577
int r = xdnn::roi_align<float, int>(ctx.GetRawContext(),
6678
in->data<float>(),
6779
out->mutable_data<float>(TARGET(kXPU)),
6880
rois->data<float>(),
69-
static_cast<const int*>(xpu_lod),
81+
xpu_lod,
7082
batch_size,
7183
channels,
7284
height,
7385
width,
74-
rois_num,
86+
rois_total,
7587
pooled_height,
7688
pooled_width,
7789
spatial_scale,
7890
sampling_ratio,
79-
true);
91+
true,
92+
align,
93+
false);
8094
CHECK_EQ(r, 0);
8195
}
8296

@@ -93,5 +107,7 @@ REGISTER_LITE_KERNEL(roi_align,
93107
def)
94108
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
95109
.BindInput("ROIs", {LiteType::GetTensorTy(TARGET(kXPU))})
110+
.BindInput("RoisNum",
111+
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
96112
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
97113
.Finalize();

lite/tests/kernels/roi_align_compute_test.cc

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -374,6 +374,8 @@ TEST(RoiAlign, precision) {
374374
#else
375375
return;
376376
#endif
377+
#elif defined(LITE_WITH_XPU)
378+
place = TARGET(kXPU);
377379
#elif defined(LITE_WITH_X86) || defined(LITE_WITH_ARM)
378380
place = TARGET(kHost);
379381
#else

0 commit comments

Comments
 (0)