
Commit 4d84205

[NPU] change ScatterAdd to EmbeddingDenseGrad in lookup_table NPU op (#33866)
* change ScatterAdd to EmbeddingDenseGrad in lookup_table NPU op
* EmbeddingDenseGrad only supports embedding_dim that is a multiple of 32
* fix shape error
1 parent 871edad commit 4d84205
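For context, the lookup_table_v2 grad kernel below now dispatches on the embedding width: widths that are a multiple of 32 go through the CANN EmbeddingDenseGrad operator, and everything else keeps the original ZerosLike + ScatterAdd path. A minimal dygraph sketch of how both branches could be exercised, assuming a Paddle build compiled with NPU support; the layer and sizes here are illustrative, not taken from the commit:

import numpy as np
import paddle

# Assumption: this Paddle build was compiled with NPU support
# (paddle.is_compiled_with_npu() returns True).
paddle.set_device("npu:0")

# 64 % 32 == 0 -> EmbeddingDenseGrad branch; 20 -> ZerosLike + ScatterAdd fallback
for dim in (64, 20):
    layer = paddle.nn.Embedding(num_embeddings=10, embedding_dim=dim)
    ids = paddle.to_tensor(
        np.random.randint(0, 10, size=(6, 8)).astype("int64"))
    out = layer(ids)          # forward: lookup_table_v2
    out.mean().backward()     # backward: the lookup_table_v2_grad NPU kernel below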

File tree

2 files changed: +56 −14 lines


paddle/fluid/operators/lookup_table_v2_op_npu.cc

Lines changed: 25 additions & 11 deletions
@@ -65,17 +65,31 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
 
-    const auto &runner_zeros =
-        NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t});
-    runner_zeros.Run(stream);
-
-    // NOTE(zhiqiu): It seems in cann 20.1, the first input and output
-    // can be different tensor, but in cann 20.2+, it does inplace operation.
-    // Thus, the first input and output should be same tensor.
-    const auto &runner_scatter =
-        NpuOpRunner("ScatterAdd", {*table_grad_t, *ids_t, *output_grad_t},
-                    {*table_grad_t}, {{"use_locking", true}});
-    runner_scatter.Run(stream);
+    int embedding_dim = table_grad_t->dims()[1];
+
+    if (embedding_dim % 32 == 0) {
+      // NOTE(pangyoki): The embedding_dim of Tensor used in
+      // EmbeddingDenseGrad must be an integer multiple of 32.
+      int num_weights = table_grad_t->dims()[0];
+      const auto &runner =
+          NpuOpRunner("EmbeddingDenseGrad", {*output_grad_t, *ids_t},
+                      {*table_grad_t}, {{"num_weights", num_weights},
+                                        {"padding_idx", -1},
+                                        {"scale_grad_by_freq", false}});
+      runner.Run(stream);
+    } else {
+      const auto &runner_zeros =
+          NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t});
+      runner_zeros.Run(stream);
+
+      // NOTE(zhiqiu): It seems in cann 20.1, the first input and output
+      // can be different tensor, but in cann 20.2+, it does inplace operation.
+      // Thus, the first input and output should be same tensor.
+      const auto &runner_scatter =
+          NpuOpRunner("ScatterAdd", {*table_grad_t, *ids_t, *output_grad_t},
+                      {*table_grad_t}, {{"use_locking", true}});
+      runner_scatter.Run(stream);
+    }
   }
 };
 }  // namespace operators

python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py

Lines changed: 31 additions & 3 deletions
@@ -35,14 +35,14 @@ def setUp(self):
         self.place = paddle.NPUPlace(0)
 
         self.init_dtype()
+        self.init_dim()
         np.random.seed(SEED)
         bsz = 6
         seqlen = 8
         vocab = 10
-        dim = 20
-        w = np.ones([vocab, dim]).astype(self.dtype)
+        w = np.ones([vocab, self.dim]).astype(self.dtype)
         x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32)
-        out = np.ones([bsz, seqlen, dim]).astype(self.dtype)
+        out = np.ones([bsz, seqlen, self.dim]).astype(self.dtype)
 
         self.inputs = {
             'W': OpTest.np_dtype_to_fluid_dtype(w),
@@ -62,6 +62,10 @@ def set_npu(self):
     def init_dtype(self):
         self.dtype = np.float32
 
+    def init_dim(self):
+        # embedding_dim is not multiple of 32
+        self.dim = 20
+
     def test_check_output(self):
         self.check_output_with_place(self.place, check_dygraph=False)
 
@@ -85,5 +89,29 @@ def set_npu(self):
         self.__class__.no_need_check_grad = True
 
 
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestLookupTableV2Dim32(TestLookupTableV2):
+    def init_dim(self):
+        # embedding_dim is multiple of 32
+        self.dim = 64
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestLookupTableV2Dim32FP16(TestLookupTableV2):
+    no_need_check_grad = True
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def init_dim(self):
+        self.dim = 64
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.__class__.no_need_check_grad = True
+
+
 if __name__ == '__main__':
     unittest.main()
