@@ -30,12 +30,12 @@ class LookupTableKernel : public framework::OpKernel {
3030 auto ids_t = context.Input <Tensor>(" Ids" ); // int tensor
3131 auto output_t = context.Output <Tensor>(" Out" ); // float tensor
3232
33- size_t N = table_t ->dims ()[0 ];
34- size_t D = table_t ->dims ()[1 ];
33+ int N = table_t ->dims ()[0 ];
34+ int D = table_t ->dims ()[1 ];
3535 auto ids = ids_t ->data <int32_t >();
3636 auto table = table_t ->data <T>();
3737 auto output = output_t ->mutable_data <T>(context.GetPlace ());
38- for (size_t i = 0 ; i < product (ids_t ->dims ()); ++i) {
38+ for (ssize_t i = 0 ; i < product (ids_t ->dims ()); ++i) {
3939 PADDLE_ENFORCE_LT (ids[i], N);
4040 PADDLE_ENFORCE_GE (ids[i], 0 );
4141 memcpy (output + i * D, table + ids[i] * D, D * sizeof (T));
@@ -51,8 +51,8 @@ class LookupTableGradKernel : public framework::OpKernel {
5151 auto d_output_t = context.Input <Tensor>(framework::GradVarName (" Out" ));
5252 auto d_table_t = context.Output <Tensor>(framework::GradVarName (" W" ));
5353
54- size_t N = d_table_t ->dims ()[0 ];
55- size_t D = d_table_t ->dims ()[1 ];
54+ int N = d_table_t ->dims ()[0 ];
55+ int D = d_table_t ->dims ()[1 ];
5656 auto ids = ids_t ->data <int32_t >();
5757 const T* d_output = d_output_t ->data <T>();
5858 T* d_table = d_table_t ->mutable_data <T>(context.GetPlace ());
@@ -61,10 +61,10 @@ class LookupTableGradKernel : public framework::OpKernel {
6161 t.device (context.GetEigenDevice <platform::CPUPlace>()) =
6262 t.constant (static_cast <T>(0 ));
6363
64- for (size_t i = 0 ; i < product (ids_t ->dims ()); ++i) {
64+ for (ssize_t i = 0 ; i < product (ids_t ->dims ()); ++i) {
6565 PADDLE_ENFORCE_LT (ids[i], N);
6666 PADDLE_ENFORCE_GE (ids[i], 0 );
67- for (size_t j = 0 ; j < D; ++j) {
67+ for (int j = 0 ; j < D; ++j) {
6868 d_table[ids[i] * D + j] += d_output[i * D + j];
6969 }
7070 }
0 commit comments