@@ -70,10 +70,10 @@ void LookupTableCompute<T_W, T_IDS>::Run() {
7070 memcpy (dout + i * row_width, table_data, row_width * sizeof (float ));
7171 }
7272#else
73- auto table_data = w->template data <float >();
73+ auto table_data = w->template data <T_W >();
7474 memcpy (dout + i * row_width,
7575 table_data + ids_int * row_width,
76- row_width * sizeof (float ));
76+ row_width * sizeof (T_W ));
7777#endif
7878 }
7979 }
@@ -87,6 +87,8 @@ void LookupTableCompute<T_W, T_IDS>::Run() {
8787
8888using LookupTableFloatInt64 =
8989 paddle::lite::kernels::arm::LookupTableCompute<float , int64_t >;
90+ using LookupTableInt8Int64 =
91+ paddle::lite::kernels::arm::LookupTableCompute<int8_t , int64_t >;
9092using LookupTableFloatInt32 =
9193 paddle::lite::kernels::arm::LookupTableCompute<float , int32_t >;
9294
@@ -105,6 +107,14 @@ REGISTER_LITE_KERNEL(
105107 .BindPaddleOpVersion(" lookup_table_v2" , 1 )
106108 .Finalize();
107109
110+ REGISTER_LITE_KERNEL (
111+ lookup_table_v2, kARM , kAny , kNCHW , LookupTableInt8Int64, int8_int64)
112+ .BindInput(" W" , {LiteType::GetTensorTy (TARGET (kARM ), PRECISION (kInt8 ))})
113+ .BindInput(" Ids" , {LiteType::GetTensorTy (TARGET (kARM ), PRECISION (kInt64 ))})
114+ .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kARM ), PRECISION (kInt8 ))})
115+ .BindPaddleOpVersion(" lookup_table_v2" , 1 )
116+ .Finalize();
117+
108118REGISTER_LITE_KERNEL (
109119 lookup_table, kARM , kAny , kNCHW , LookupTableFloatInt32, float_int32)
110120 .BindInput(" W" , {LiteType::GetTensorTy (TARGET (kARM ))})
0 commit comments