@@ -28,6 +28,12 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
2828 auto *ids_t = ctx.Input <framework::LoDTensor>(" Ids" ); // int tensor
2929 auto *output_t = ctx.Output <framework::LoDTensor>(" Out" ); // float tensor
3030 auto *table_t = ctx.Input <framework::LoDTensor>(" W" );
31+
32+ // It seems cann 20.1 accepts int64, but cann 20.2+ not.
33+ PADDLE_ENFORCE_EQ (ids_t ->type (), framework::proto::VarType::INT32,
34+ platform::errors::Unimplemented (
35+ " The index of LookupTableV2 should be int32." ));
36+
3137 auto *table_var = ctx.InputVar (" W" );
3238 PADDLE_ENFORCE_EQ (
3339 table_var->IsType <framework::LoDTensor>(), true ,
@@ -49,28 +55,26 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
4955 public:
5056 void Compute (const framework::ExecutionContext &ctx) const override {
5157 auto *ids_t = ctx.Input <framework::LoDTensor>(" Ids" );
58+
5259 auto *output_grad_t =
5360 ctx.Input <framework::LoDTensor>(framework::GradVarName (" Out" ));
5461 auto *table_grad_t =
5562 ctx.Output <framework::LoDTensor>(framework::GradVarName (" W" ));
56- table_grad_t ->mutable_data <T>(ctx.GetPlace ());
63+ auto *p = table_grad_t ->mutable_data <T>(ctx.GetPlace ());
5764
5865 auto stream =
5966 ctx.template device_context <paddle::platform::NPUDeviceContext>()
6067 .stream ();
6168
62- // step2: ZerosLike x in device
63- Tensor zeroslike_w (table_grad_t ->type ());
64- zeroslike_w.Resize (table_grad_t ->dims ());
65- auto p = zeroslike_w.mutable_data <T>(ctx.GetPlace ());
66-
6769 platform::NPUMemsetAsync (static_cast <void *>(p), 0 ,
68- zeroslike_w. numel () * sizeof (T), stream);
70+ table_grad_t -> numel () * sizeof (T), stream);
6971
70- table_grad_t ->mutable_data <T>(ctx.GetPlace ());
72+ // NOTE(zhiqiu): It seems in cann 20.1, the first input and output
73+ // can be different tensor, but in cann 20.2+, it does inplace operation.
74+ // Thus, the first input and output should be same tensor.
7175 auto runner_scatter =
72- NpuOpRunner (" ScatterAdd" , {zeroslike_w , *ids_t , *output_grad_t },
73- {*table_grad_t }, {});
76+ NpuOpRunner (" ScatterAdd" , {* table_grad_t , *ids_t , *output_grad_t },
77+ {*table_grad_t }, {{ " use_locking " , true } });
7478 runner_scatter.Run (stream);
7579 }
7680};
0 commit comments