@@ -65,17 +65,31 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
6565 ctx.template device_context <paddle::platform::NPUDeviceContext>()
6666 .stream ();
6767
68- const auto &runner_zeros =
69- NpuOpRunner (" ZerosLike" , {*table_grad_t }, {*table_grad_t });
70- runner_zeros.Run (stream);
71-
72- // NOTE(zhiqiu): It seems in cann 20.1, the first input and output
73- // can be different tensor, but in cann 20.2+, it does inplace operation.
74- // Thus, the first input and output should be same tensor.
75- const auto &runner_scatter =
76- NpuOpRunner (" ScatterAdd" , {*table_grad_t , *ids_t , *output_grad_t },
77- {*table_grad_t }, {{" use_locking" , true }});
78- runner_scatter.Run (stream);
68+ int embedding_dim = table_grad_t ->dims ()[1 ];
69+
70+ if (embedding_dim % 32 == 0 ) {
71+ // NOTE(pangyoki): The embedding_dim of Tensor used in
72+ // EmbeddingDenseGrad must be an integer multiple of 32.
73+ int num_weights = table_grad_t ->dims ()[0 ];
74+ const auto &runner =
75+ NpuOpRunner (" EmbeddingDenseGrad" , {*output_grad_t , *ids_t },
76+ {*table_grad_t }, {{" num_weights" , num_weights},
77+ {" padding_idx" , -1 },
78+ {" scale_grad_by_freq" , false }});
79+ runner.Run (stream);
80+ } else {
81+ const auto &runner_zeros =
82+ NpuOpRunner (" ZerosLike" , {*table_grad_t }, {*table_grad_t });
83+ runner_zeros.Run (stream);
84+
85+ // NOTE(zhiqiu): It seems in cann 20.1, the first input and output
86+ // can be different tensor, but in cann 20.2+, it does inplace operation.
87+ // Thus, the first input and output should be same tensor.
88+ const auto &runner_scatter =
89+ NpuOpRunner (" ScatterAdd" , {*table_grad_t , *ids_t , *output_grad_t },
90+ {*table_grad_t }, {{" use_locking" , true }});
91+ runner_scatter.Run (stream);
92+ }
7993 }
8094};
8195} // namespace operators
0 commit comments