Skip to content

Commit 5af987a

Browse files
committed
fix gather of xpu
1 parent 1df7e0e commit 5af987a

2 files changed

Lines changed: 93 additions & 9 deletions

File tree

lite/kernels/xpu/gather_compute.cc

Lines changed: 91 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -46,16 +46,89 @@ void GatherCompute<DataType, IndexType>::Run() {
46 46
axis += x_dims.size();
47 47
}
48 48

49-
int r = xdnn::gather<DataType, IndexType>(
50-
ctx.GetRawContext(),
51-
x->template data<DataType>(),
52-
index->template data<IndexType>(),
53-
out->template mutable_data<DataType>(TARGET(kXPU)),
54-
x_dims,
55-
index->numel(),
56-
axis);
49+
if (param.X->precision() == PrecisionType::kInt64 &&
50+
param.Index->precision() == PrecisionType::kInt64) {
51+
auto* p_index = param.Index->template data<int64_t>();
52+
int size = param.Index->dims().production();
53+
XPUScratchPadGuard indices_xpu_guard_ =
54+
TargetWrapperXPU::MallocScratchPad(size * sizeof(int));
55+
int* indices_int32_device =
56+
reinterpret_cast<int*>(indices_xpu_guard_->addr_);
57 57

58-
CHECK_EQ(r, 0);
58+
int r0 = xdnn::cast_v2<int64_t, int32_t>(
59+
ctx.GetRawContext(), p_index, indices_int32_device, index->numel());
60+
CHECK_EQ(r0, 0);
61+
62+
int r1 = xdnn::gather<int64_t, int32_t>(
63+
ctx.GetRawContext(),
64+
x->template data<int64_t>(),
65+
indices_int32_device,
66+
out->template mutable_data<int64_t>(TARGET(kXPU)),
67+
x_dims,
68+
index->numel(),
69+
axis);
70+
CHECK_EQ(r1, 0);
71+
} else if (param.X->precision() == PrecisionType::kInt64 &&
72+
param.Index->precision() == PrecisionType::kInt32) {
73+
int r = xdnn::gather<int64_t, int32_t>(
74+
ctx.GetRawContext(),
75+
x->template data<int64_t>(),
76+
index->template data<int32_t>(),
77+
out->template mutable_data<int64_t>(TARGET(kXPU)),
78+
x_dims,
79+
index->numel(),
80+
axis);
81+
CHECK_EQ(r, 0);
82+
} else if (param.X->precision() == PrecisionType::kInt32 &&
83+
param.Index->precision() == PrecisionType::kInt32) {
84+
int r = xdnn::gather<int32_t, int32_t>(
85+
ctx.GetRawContext(),
86+
x->template data<int32_t>(),
87+
index->template data<int32_t>(),
88+
out->template mutable_data<int32_t>(TARGET(kXPU)),
89+
x_dims,
90+
index->numel(),
91+
axis);
92+
CHECK_EQ(r, 0);
93+
} else if (param.X->precision() == PrecisionType::kInt32 &&
94+
param.Index->precision() == PrecisionType::kInt64) {
95+
int r = xdnn::gather<int32_t, int64_t>(
96+
ctx.GetRawContext(),
97+
x->template data<int32_t>(),
98+
index->template data<int64_t>(),
99+
out->template mutable_data<int32_t>(TARGET(kXPU)),
100+
x_dims,
101+
index->numel(),
102+
axis);
103+
CHECK_EQ(r, 0);
104+
} else if (param.X->precision() == PrecisionType::kFloat &&
105+
param.Index->precision() == PrecisionType::kInt32) {
106+
int r = xdnn::gather<float, int32_t>(
107+
ctx.GetRawContext(),
108+
x->template data<float>(),
109+
index->template data<int32_t>(),
110+
out->template mutable_data<float>(TARGET(kXPU)),
111+
x_dims,
112+
index->numel(),
113+
axis);
114+
CHECK_EQ(r, 0);
115+
} else if (param.X->precision() == PrecisionType::kFloat &&
116+
param.Index->precision() == PrecisionType::kInt64) {
117+
int r = xdnn::gather<float, int64_t>(
118+
ctx.GetRawContext(),
119+
x->template data<float>(),
120+
index->template data<int64_t>(),
121+
out->template mutable_data<float>(TARGET(kXPU)),
122+
x_dims,
123+
index->numel(),
124+
axis);
125+
CHECK_EQ(r, 0);
126+
} else {
127+
LOG(FATAL) << "Unsupported gather op with x dtype: "
128+
<< lite_api::PrecisionToStr(param.X->precision())
129+
<< " and index dtype: "
130+
<< lite_api::PrecisionToStr(param.Index->precision());
131+
}
59 132
}
60 133

61 134
} // namespace xpu
@@ -107,3 +180,12 @@ REGISTER_LITE_KERNEL(
107 180
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
108 181
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
109 182
.Finalize();
183+
REGISTER_LITE_KERNEL(
184+
gather, kXPU, kFloat, kNCHW, GatherXPUInt64Int64, gather_i64_i64)
185+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
186+
.BindInput("Index",
187+
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
188+
.BindInput("Axis",
189+
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
190+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
191+
.Finalize();

lite/kernels/xpu/gather_compute.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,3 +46,5 @@ typedef paddle::lite::kernels::xpu::GatherCompute<float, int64_t>
46 46
GatherXPUFloatInt64;
47 47
typedef paddle::lite::kernels::xpu::GatherCompute<int64_t, int32_t>
48 48
GatherXPUInt64Int32;
49+
typedef paddle::lite::kernels::xpu::GatherCompute<int64_t, int64_t>
50+
GatherXPUInt64Int64;

0 commit comments

Comments
 (0)