@@ -46,16 +46,89 @@ void GatherCompute<DataType, IndexType>::Run() {
4646 axis += x_dims.size ();
4747 }
4848
49- int r = xdnn::gather<DataType, IndexType>(
50- ctx. GetRawContext (),
51- x ->template data <DataType >(),
52- index-> template data <IndexType>(),
53- out-> template mutable_data <DataType>( TARGET ( kXPU )),
54- x_dims,
55- index-> numel (),
56- axis );
49+ if (param. X -> precision () == PrecisionType:: kInt64 &&
50+ param. Index -> precision () == PrecisionType:: kInt64 ) {
51+ auto * p_index = param. Index ->template data <int64_t >();
52+ int size = param. Index -> dims (). production ();
53+ XPUScratchPadGuard indices_xpu_guard_ =
54+ TargetWrapperXPU::MallocScratchPad (size * sizeof ( int ));
55+ int * indices_int32_device =
56+ reinterpret_cast < int *>(indices_xpu_guard_-> addr_ );
5757
58- CHECK_EQ (r, 0 );
58+ int r0 = xdnn::cast_v2<int64_t , int32_t >(
59+ ctx.GetRawContext (), p_index, indices_int32_device, index->numel ());
60+ CHECK_EQ (r0, 0 );
61+
62+ int r1 = xdnn::gather<int64_t , int32_t >(
63+ ctx.GetRawContext (),
64+ x->template data <int64_t >(),
65+ indices_int32_device,
66+ out->template mutable_data <int64_t >(TARGET (kXPU )),
67+ x_dims,
68+ index->numel (),
69+ axis);
70+ CHECK_EQ (r1, 0 );
71+ } else if (param.X ->precision () == PrecisionType::kInt64 &&
72+ param.Index ->precision () == PrecisionType::kInt32 ) {
73+ int r = xdnn::gather<int64_t , int32_t >(
74+ ctx.GetRawContext (),
75+ x->template data <int64_t >(),
76+ index->template data <int32_t >(),
77+ out->template mutable_data <int64_t >(TARGET (kXPU )),
78+ x_dims,
79+ index->numel (),
80+ axis);
81+ CHECK_EQ (r, 0 );
82+ } else if (param.X ->precision () == PrecisionType::kInt32 &&
83+ param.Index ->precision () == PrecisionType::kInt32 ) {
84+ int r = xdnn::gather<int32_t , int32_t >(
85+ ctx.GetRawContext (),
86+ x->template data <int32_t >(),
87+ index->template data <int32_t >(),
88+ out->template mutable_data <int32_t >(TARGET (kXPU )),
89+ x_dims,
90+ index->numel (),
91+ axis);
92+ CHECK_EQ (r, 0 );
93+ } else if (param.X ->precision () == PrecisionType::kInt32 &&
94+ param.Index ->precision () == PrecisionType::kInt64 ) {
95+ int r = xdnn::gather<int32_t , int64_t >(
96+ ctx.GetRawContext (),
97+ x->template data <int32_t >(),
98+ index->template data <int64_t >(),
99+ out->template mutable_data <int32_t >(TARGET (kXPU )),
100+ x_dims,
101+ index->numel (),
102+ axis);
103+ CHECK_EQ (r, 0 );
104+ } else if (param.X ->precision () == PrecisionType::kFloat &&
105+ param.Index ->precision () == PrecisionType::kInt32 ) {
106+ int r = xdnn::gather<float , int32_t >(
107+ ctx.GetRawContext (),
108+ x->template data <float >(),
109+ index->template data <int32_t >(),
110+ out->template mutable_data <float >(TARGET (kXPU )),
111+ x_dims,
112+ index->numel (),
113+ axis);
114+ CHECK_EQ (r, 0 );
115+ } else if (param.X ->precision () == PrecisionType::kFloat &&
116+ param.Index ->precision () == PrecisionType::kInt64 ) {
117+ int r = xdnn::gather<float , int64_t >(
118+ ctx.GetRawContext (),
119+ x->template data <float >(),
120+ index->template data <int64_t >(),
121+ out->template mutable_data <float >(TARGET (kXPU )),
122+ x_dims,
123+ index->numel (),
124+ axis);
125+ CHECK_EQ (r, 0 );
126+ } else {
127+ LOG (FATAL) << " Unsupported gather op with x dtype: "
128+ << lite_api::PrecisionToStr (param.X ->precision ())
129+ << " and index dtype: "
130+ << lite_api::PrecisionToStr (param.Index ->precision ());
131+ }
59132}
60133
61134} // namespace xpu
@@ -107,3 +180,12 @@ REGISTER_LITE_KERNEL(
107180 {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
108181 .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kInt64 ))})
109182 .Finalize();
// Registers a kernel for the `gather` op on the kXPU target with layout kNCHW,
// using the kernel type GatherXPUInt64Int64 (declared outside this view) under
// the alias gather_i64_i64.
// Tensor bindings declared below:
//   X     - kXPU tensor, precision kInt64
//   Index - kXPU tensor, precision kInt64
//   Axis  - kHost tensor, precision kInt32
//   Out   - kXPU tensor, precision kInt64
// NOTE(review): the kernel-level precision argument is kFloat even though every
// data tensor bound here is kInt64 — presumably it only serves as part of the
// registration key; confirm against the other gather registrations in this file.
183+ REGISTER_LITE_KERNEL (
184+ gather, kXPU , kFloat , kNCHW , GatherXPUInt64Int64, gather_i64_i64)
185+ .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kInt64 ))})
186+ .BindInput(" Index" ,
187+ {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kInt64 ))})
188+ .BindInput(" Axis" ,
189+ {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
190+ .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kInt64 ))})
191+ .Finalize();
0 commit comments