Skip to content

Commit ad0dc8f

Browse files
authored
Merge pull request #3897 from Canpio/fix_warnings_in_lookup_op
Fix compile warnings in lookup_op
2 parents b64aac5 + fd0e1e8 commit ad0dc8f

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

paddle/operators/lookup_table_op.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ class LookupTableKernel : public framework::OpKernel {
3030
auto ids_t = context.Input<Tensor>("Ids"); // int tensor
3131
auto output_t = context.Output<Tensor>("Out"); // float tensor
3232

33-
size_t N = table_t->dims()[0];
34-
size_t D = table_t->dims()[1];
33+
int N = table_t->dims()[0];
34+
int D = table_t->dims()[1];
3535
auto ids = ids_t->data<int32_t>();
3636
auto table = table_t->data<T>();
3737
auto output = output_t->mutable_data<T>(context.GetPlace());
38-
for (size_t i = 0; i < product(ids_t->dims()); ++i) {
38+
for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
3939
PADDLE_ENFORCE_LT(ids[i], N);
4040
PADDLE_ENFORCE_GE(ids[i], 0);
4141
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
@@ -51,8 +51,8 @@ class LookupTableGradKernel : public framework::OpKernel {
5151
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
5252
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
5353

54-
size_t N = d_table_t->dims()[0];
55-
size_t D = d_table_t->dims()[1];
54+
int N = d_table_t->dims()[0];
55+
int D = d_table_t->dims()[1];
5656
auto ids = ids_t->data<int32_t>();
5757
const T* d_output = d_output_t->data<T>();
5858
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
@@ -61,10 +61,10 @@ class LookupTableGradKernel : public framework::OpKernel {
6161
t.device(context.GetEigenDevice<platform::CPUPlace>()) =
6262
t.constant(static_cast<T>(0));
6363

64-
for (size_t i = 0; i < product(ids_t->dims()); ++i) {
64+
for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
6565
PADDLE_ENFORCE_LT(ids[i], N);
6666
PADDLE_ENFORCE_GE(ids[i], 0);
67-
for (size_t j = 0; j < D; ++j) {
67+
for (int j = 0; j < D; ++j) {
6868
d_table[ids[i] * D + j] += d_output[i * D + j];
6969
}
7070
}

0 commit comments

Comments
 (0)