@@ -183,7 +183,11 @@ using SubFloat32 =
// SubFloat16: alias for xpu::ElementwiseCompute instantiated with float16,
// xpu::SubFunctor<float16> and the PRECISION(kFP16) tag.
// NOTE(review): the roles of the three template arguments (element type,
// functor, kernel precision) are inferred from their names — confirm against
// the ElementwiseCompute declaration.
183183using SubFloat16 = xpu::ElementwiseCompute<float16,
184184 xpu::SubFunctor<float16>,
185185 PRECISION (kFP16 )>;
186-
// SubInt32: alias for xpu::ElementwiseCompute instantiated with int,
// xpu::SubFunctor<int> and the PRECISION(kFloat) tag.
// The precision tag is kFloat even though the element type is int; the
// REGISTER_LITE_KERNEL call that uses this alias also passes kFloat, so the
// two agree.
// NOTE(review): presumably the tag is the kernel's registration precision and
// not the tensor dtype — confirm against the ElementwiseCompute declaration.
186+ using SubInt32 =
187+ xpu::ElementwiseCompute<int , xpu::SubFunctor<int >, PRECISION(kFloat )>;
// SubInt64: alias for xpu::ElementwiseCompute instantiated with int64_t,
// xpu::SubFunctor<int64_t> and the PRECISION(kFloat) tag.
// The precision tag is kFloat even though the element type is int64_t; the
// REGISTER_LITE_KERNEL call that uses this alias also passes kFloat, so the
// two agree.
// NOTE(review): presumably the tag is the kernel's registration precision and
// not the tensor dtype — confirm against the ElementwiseCompute declaration.
188+ using SubInt64 = xpu::ElementwiseCompute<int64_t ,
189+ xpu::SubFunctor<int64_t >,
190+ PRECISION (kFloat )>;
// MulFloat32: alias for xpu::ElementwiseCompute instantiated with float,
// xpu::MulFunctor<float> and the PRECISION(kFloat) tag.
// NOTE(review): the roles of the three template arguments (element type,
// functor, kernel precision) are inferred from their names — confirm against
// the ElementwiseCompute declaration.
187191using MulFloat32 =
188192 xpu::ElementwiseCompute<float , xpu::MulFunctor<float >, PRECISION(kFloat )>;
189193using MulFloat16 = xpu::ElementwiseCompute<float16,
@@ -273,6 +277,18 @@ REGISTER_LITE_KERNEL(
273277 .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kFP16 ))})
274278 .Finalize();
275279
// Registers SubInt32 as a kernel for `elementwise_sub` with the tags kXPU,
// kFloat and kNCHW, under the alias `int32`.
// Inputs X and Y and output Out are each bound to a tensor type built from
// TARGET(kXPU) and PRECISION(kInt32); only the kernel-level tag is kFloat.
// NOTE(review): the string literals below carry a leading space (" X", " Y",
// " Out"). This looks like residue from how the text was captured rather than
// the intended argument names — confirm against the real file before relying
// on them.
280+ REGISTER_LITE_KERNEL (elementwise_sub, kXPU , kFloat , kNCHW , SubInt32, int32)
281+ .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kInt32 ))})
282+ .BindInput(" Y" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kInt32 ))})
283+ .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kInt32 ))})
284+ .Finalize();
285+
// Registers SubInt64 as a kernel for `elementwise_sub` with the tags kXPU,
// kFloat and kNCHW, under the alias `int64`.
// Inputs X and Y and output Out are each bound to a tensor type built from
// TARGET(kXPU) and PRECISION(kInt64); only the kernel-level tag is kFloat.
// NOTE(review): the string literals below carry a leading space (" X", " Y",
// " Out"). This looks like residue from how the text was captured rather than
// the intended argument names — confirm against the real file before relying
// on them.
286+ REGISTER_LITE_KERNEL (elementwise_sub, kXPU , kFloat , kNCHW , SubInt64, int64)
287+ .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kInt64 ))})
288+ .BindInput(" Y" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kInt64 ))})
289+ .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kInt64 ))})
290+ .Finalize();
291+
276292REGISTER_LITE_KERNEL (elementwise_mul, kXPU , kFloat , kNCHW , MulFloat32, def)
277293 .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ))})
278294 .BindInput(" Y" , {LiteType::GetTensorTy (TARGET (kXPU ))})
0 commit comments