diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 3e06e5caed3179..ed0bcdebe417c8 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -105,9 +105,26 @@ class LookupTableCUDAKernel : public framework::OpKernel { auto *table = table_t->data(); auto *output = output_t->mutable_data(context.GetPlace()); +#ifdef PADDLE_WITH_HIP + dim3 threads(64, 4); +#else dim3 threads(128, 8); +#endif // PADDLE_WITH_HIP + dim3 grids(8, 1); +#ifdef PADDLE_WITH_HIP + if (padding_idx == -1) + LookupTable< + T, 64, 4, 8, + false><<>>( + output, table, ids, N, K, D, padding_idx); + else + LookupTable< + T, 64, 4, 8, + true><<>>( + output, table, ids, N, K, D, padding_idx); +#else if (padding_idx == -1) LookupTable< T, 128, 8, 8, @@ -118,6 +135,7 @@ class LookupTableCUDAKernel : public framework::OpKernel { T, 128, 8, 8, true><<>>( output, table, ids, N, K, D, padding_idx); +#endif // PADDLE_WITH_HIP } }; @@ -185,10 +203,21 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { auto t = framework::EigenVector::Flatten(*d_table_t); t.device(*dev_ctx.eigen_device()) = t.constant(static_cast(0)); +#ifdef PADDLE_WITH_HIP + dim3 threads(64, 4); +#else dim3 threads(128, 8); +#endif // PADDLE_WITH_HIP + dim3 grids(8, 1); + +#ifdef PADDLE_WITH_HIP + LookupTableGrad<<>>( + d_table, d_output, ids, N, K, D); +#else LookupTableGrad<<>>( d_table, d_output, ids, N, K, D); +#endif // PADDLE_WITH_HIP } } };