From 7f4f495ea55201b8216297e57afc7df782da20d4 Mon Sep 17 00:00:00 2001 From: Meiyim Date: Tue, 9 Mar 2021 13:07:28 +0800 Subject: [PATCH 1/7] [npu] support npu kernel `table_lookup_v2` --- .../fluid/operators/lookup_table_v2_op_npu.cc | 116 +++++++++++++ .../operators/lookup_table_v2_op_npu_test.cc | 162 ++++++++++++++++++ 2 files changed, 278 insertions(+) create mode 100644 paddle/fluid/operators/lookup_table_v2_op_npu.cc create mode 100644 paddle/fluid/operators/lookup_table_v2_op_npu_test.cc diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc new file mode 100644 index 00000000000000..6bd7098d44313a --- /dev/null +++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc @@ -0,0 +1,116 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_ASCEND_CL +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class LookupTableV2NPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *ids_t = ctx.Input("Ids"); // int tensor + auto *output_t = ctx.Output("Out"); // float tensor + auto *table_t = ctx.Input("W"); + auto *table_var = ctx.InputVar("W"); + PADDLE_ENFORCE_EQ( + table_var->IsType(), true, + platform::errors::InvalidArgument("npu only accept LoDTensor")); + output_t->mutable_data(ctx.GetPlace()); + framework::NPUAttributeMap attr_input = {{"validate_indices", false}}; + + auto runner = + NpuOpRunner("Gather", {*table_t, *ids_t}, {*output_t}, attr_input); + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + +template +class LookupTableV2GradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *ids_t = ctx.Input("Ids"); + auto *output_grad_t = + ctx.Input(framework::GradVarName("Out")); + auto *table_t = ctx.Input("W"); + auto *table_grad_t = + ctx.Output(framework::GradVarName("W")); + /* + auto *table_var = ctx.InputVar("W"); + PADDLE_ENFORCE_EQ(table_var->IsType(), true, + platform::errors::InvalidArgument("npu only accept LoDTensor")); + */ + framework::NPUAttributeMap attr_input = {{"use_locking", true}}; + std::vector vec; + std::vector vec_int; + + TensorToVector(*table_t, ctx.device_context(), &vec); + /* + for (auto& v : vec){ + std::cout <<"table_t"<< v<() + .stream(); + runner.Run(stream); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + lookup_table_v2, + ops::LookupTableV2NPUKernel, + ops::LookupTableV2NPUKernel, + ops::LookupTableV2NPUKernel, + ops::LookupTableV2NPUKernel, + ops::LookupTableV2NPUKernel, + ops::LookupTableV2NPUKernel, + ops::LookupTableV2NPUKernel); + +REGISTER_OP_NPU_KERNEL(lookup_table_v2_grad, + ops::LookupTableV2GradNPUKernel, + ops::LookupTableV2GradNPUKernel, + ops::LookupTableV2GradNPUKernel, + ops::LookupTableV2GradNPUKernel, + ops::LookupTableV2GradNPUKernel); + +#endif diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc b/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc new file mode 100644 index 00000000000000..622f42bfe3348a --- /dev/null +++ b/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc @@ -0,0 +1,162 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef _WIN32 +#include +#endif + +#include +#include +#include +#include +#include // NOLINT +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/string/printf.h" + +namespace f = paddle::framework; +namespace p = paddle::platform; +namespace m = paddle::operators::math; + +USE_OP(lookup_table_v2); +USE_OP_DEVICE_KERNEL(lookup_table_v2, NPU); + +template +void Compare(f::Scope* scope, const p::DeviceContext& ctx) { + // init + auto ids = scope->Var("Ids"); + auto out = scope->Var("Out"); + auto w = scope->Var("W"); + + auto ids_t = ids->GetMutable(); + auto out_t = out->GetMutable(); + auto w_t = w->GetMutable(); + int bsz = 10; + int dim = 32; + int seqlen = 8; + int vocab_size = 100; + TensorFromVector(std::vector(bsz * seqlen, 3), ctx, ids_t); + std::vector val(vocab_size * dim, 10.); + TensorFromVector(val, ctx, w_t); + ids_t->Resize({bsz, seqlen}); + w_t->Resize({vocab_size, dim}); + out_t->Resize({bsz, seqlen, dim}); + ctx.Wait(); + + auto place = ctx.GetPlace(); + out_t->mutable_data(place); + f::AttributeMap attrs = {{}}; + auto op = f::OpRegistry::CreateOp("lookup_table_v2", + {{"W", {"W"}}, {"Ids", {"Ids"}}}, + {{"Out", {"Out"}}}, attrs); + op->Run(*scope, place); + std::vector out_v; + TensorToVector(*out_t, ctx, &out_v); + ctx.Wait(); + EXPECT_EQ(out_t->numel(), bsz * seqlen * dim); + /* + for (auto v: val){ + std::cout << "inp " << v << std::endl; + } + for (auto v: out_v){ + std::cout <<"res "<< v< +void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) { + // init + auto w = scope->Var("W"); + auto ids = scope->Var("Ids"); + auto out = scope->Var("DOut"); + auto dw = scope->Var("DW"); + + auto w_t = w->GetMutable(); + auto ids_t = ids->GetMutable(); + auto out_t = out->GetMutable(); + auto dw_t = dw->GetMutable(); + + int bsz = 2; + int dim = 2; + int seqlen = 2; + int vocab_size = 4; + + std::vector val_int(bsz * seqlen, 3); + std::vector val(vocab_size * dim, 0.); + std::vector val_out(bsz * seqlen * dim, 1.); + + TensorFromVector(val_int, ctx, ids_t); + TensorFromVector(val, ctx, w_t); + TensorFromVector(val, ctx, dw_t); + TensorFromVector(val_out, ctx, out_t); + + w_t->Resize({vocab_size, dim}); + ids_t->Resize({bsz, seqlen}); + out_t->Resize({bsz, seqlen, dim}); + dw_t->Resize({vocab_size, dim}); + + ctx.Wait(); + + auto place = ctx.GetPlace(); + out_t->mutable_data(place); + w_t->mutable_data(place); + dw_t->mutable_data(place); + f::AttributeMap attrs = {{}}; + auto op = f::OpRegistry::CreateOp( + "lookup_table_v2_grad", + {{"Ids", {"Ids"}}, {"W", {"W"}}, {"Out@GRAD", {"DOut"}}}, + {{"W@GRAD", {"DW"}}}, attrs); + op->Run(*scope, place); + ctx.Wait(); + std::vector w_v; + TensorToVector(*dw_t, ctx, &w_v); + ctx.Wait(); + EXPECT_EQ(dw_t->numel(), vocab_size * dim); + /* + for (auto v: val){ + std::cout << "val " << v << std::endl; + } + for (auto v: val_int){ + std::cout << "val int " << v << std::endl; + } + for (auto v: val_out){ + std::cout << "val out " << v << std::endl; + } + for (auto v: w_v){ + std::cout <<"grad res "<< v<(&scope, ctx); +} + +TEST(lookup_table_v2_grad, NPU_fp32) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + CompareGrad(&scope, ctx); +} From fff63a57b5cb486432394a016ef21a80b2a5861d Mon Sep 17 00:00:00 2001 From: Meiyim Date: Tue, 9 Mar 2021 14:16:33 +0800 Subject: [PATCH 2/7] clean up --- .../fluid/operators/lookup_table_v2_op_npu.cc | 46 ++++--------------- .../operators/lookup_table_v2_op_npu_test.cc | 2 + 2 files changed, 12 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc index 6bd7098d44313a..0b5ebe659e59a0 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc @@ -1,15 +1,16 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_ASCEND_CL #include #include #include @@ -53,34 +54,11 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel { auto *table_t = ctx.Input("W"); auto *table_grad_t = ctx.Output(framework::GradVarName("W")); - /* - auto *table_var = ctx.InputVar("W"); - PADDLE_ENFORCE_EQ(table_var->IsType(), true, - platform::errors::InvalidArgument("npu only accept LoDTensor")); - */ framework::NPUAttributeMap attr_input = {{"use_locking", true}}; std::vector vec; std::vector vec_int; TensorToVector(*table_t, ctx.device_context(), &vec); - /* - for (auto& v : vec){ - std::cout <<"table_t"<< v<, - ops::LookupTableV2NPUKernel, ops::LookupTableV2NPUKernel, ops::LookupTableV2NPUKernel, ops::LookupTableV2NPUKernel, - ops::LookupTableV2NPUKernel, - ops::LookupTableV2NPUKernel); + ops::LookupTableV2NPUKernel); -REGISTER_OP_NPU_KERNEL(lookup_table_v2_grad, - ops::LookupTableV2GradNPUKernel, - ops::LookupTableV2GradNPUKernel, - ops::LookupTableV2GradNPUKernel, - ops::LookupTableV2GradNPUKernel, - ops::LookupTableV2GradNPUKernel); - -#endif +REGISTER_OP_NPU_KERNEL( + lookup_table_v2_grad, ops::LookupTableV2GradNPUKernel, + ops::LookupTableV2GradNPUKernel, + ops::LookupTableV2GradNPUKernel, + ops::LookupTableV2GradNPUKernel, + ops::LookupTableV2GradNPUKernel); diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc b/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc index 622f42bfe3348a..eff3c201ac1840 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc @@ -2,7 +2,9 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. From b881071d4f3ed6ff3e3c2e2878a4ead72a3cd78e Mon Sep 17 00:00:00 2001 From: Meiyim Date: Wed, 10 Mar 2021 15:36:19 +0800 Subject: [PATCH 3/7] +python test --- .../npu/test_lookup_table_v2_op_npu.py | 125 ++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py diff --git a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py new file mode 100644 index 00000000000000..af872104abdd83 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py @@ -0,0 +1,125 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestLookupTableV2(OpTest): + def setUp(self): + self.set_npu() + self.op_type = "lookup_table_v2" + self.place = paddle.NPUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + bsz=2 + seqlen=2 + vocab=3 + dim=2 + w = np.ones([vocab, dim]).astype(self.dtype) + x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int64) + out = np.ones([bsz, seqlen, dim]).astype(self.dtype) + + self.inputs = {'W': OpTest.np_dtype_to_fluid_dtype(w), 'Ids': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = { + 'is_sparse': False, + 'is_distributed': False, + 'remote_prefetch':False, + 'padding_idx': -1 + } + self.outputs = {'Out': out} + + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + # TODO(ascendrc): Add grad test + # def test_check_grad(self): + # if self.dtype == np.float16: + # return + # self.check_grad(['X'], 'Out') + # + + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestLookupTableV2Net(unittest.TestCase): + def _test(self, run_npu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + bsz=3 + seqlen=2 + vocab=3 + dim=2 + + ids_np = np.random.randint(0, vocab, size=(bsz, seqlen)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + emb = paddle.nn.Embedding(vocab, dim) + ids = paddle.static.data(name="ids", shape=[bsz, seqlen], dtype='int64') + #res = paddle.static.nn.embedding(ids, (vocab, dim), param_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1.))) + res = emb(ids) + loss = res.sum() + + if run_npu: + place = paddle.NPUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + for epoch in range(1): + loss_res, w = exe.run( + main_prog, + feed={"ids": ids_np}, + fetch_list=[loss, emb.weight]) + if epoch % 10 == 0: + print(w) + print("Epoch {} | Loss: {}".format(epoch, loss)) + + return loss_res + + def test_npu(self): + cpu_loss = self._test(False) + npu_loss = self._test(True) + self.assertTrue(np.allclose(npu_loss, cpu_loss)) + + +if __name__ == '__main__': + unittest.main() + From df03165531bb87780cd393efcaccab47af149b14 Mon Sep 17 00:00:00 2001 From: Meiyim Date: Fri, 12 Mar 2021 14:12:09 +0800 Subject: [PATCH 4/7] +cmake --- paddle/fluid/operators/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 17234edb116e3e..f5c6ccdcd5cdbe 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -155,6 +155,9 @@ cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS}) if (WITH_PYTHON) cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind) endif() +if (WITH_ASCEND_CL) + cc_test(lookup_table_v2_op_npu_test SRCS lookup_table_v2_op_npu_test.cc DEPS op_registry lookup_table_v2_op scope device_context enforce executor compare_op) +endif() set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") add_subdirectory(benchmark) From fac5d03f0770e709b7c6c96c609f738a28e9aeb3 Mon Sep 17 00:00:00 2001 From: Meiyim Date: Fri, 12 Mar 2021 14:32:32 +0800 Subject: [PATCH 5/7] clean up --- .../operators/lookup_table_v2_op_npu_test.cc | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc b/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc index eff3c201ac1840..f37915834bd756 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc @@ -70,14 +70,6 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) { TensorToVector(*out_t, ctx, &out_v); ctx.Wait(); EXPECT_EQ(out_t->numel(), bsz * seqlen * dim); - /* - for (auto v: val){ - std::cout << "inp " << v << std::endl; - } - for (auto v: out_v){ - std::cout <<"res "<< v<numel(), vocab_size * dim); - /* - for (auto v: val){ - std::cout << "val " << v << std::endl; - } - for (auto v: val_int){ - std::cout << "val int " << v << std::endl; - } - for (auto v: val_out){ - std::cout << "val out " << v << std::endl; - } - for (auto v: w_v){ - std::cout <<"grad res "<< v< Date: Fri, 12 Mar 2021 16:18:47 +0800 Subject: [PATCH 6/7] remove int8 kernel + python unitest for fp16 --- .../fluid/operators/lookup_table_v2_op_npu.cc | 14 ++----------- .../npu/test_lookup_table_v2_op_npu.py | 20 ++++++++++++++++++- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc index 0b5ebe659e59a0..e7cc048ed3ce4b 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc @@ -55,10 +55,6 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel { auto *table_grad_t = ctx.Output(framework::GradVarName("W")); framework::NPUAttributeMap attr_input = {{"use_locking", true}}; - std::vector vec; - std::vector vec_int; - - TensorToVector(*table_t, ctx.device_context(), &vec); auto runner = NpuOpRunner("ScatterAdd", {*table_t, *ids_t, *output_grad_t}, {*table_grad_t}, attr_input); @@ -77,14 +73,8 @@ REGISTER_OP_NPU_KERNEL( lookup_table_v2, ops::LookupTableV2NPUKernel, ops::LookupTableV2NPUKernel, - ops::LookupTableV2NPUKernel, - ops::LookupTableV2NPUKernel, - ops::LookupTableV2NPUKernel); + paddle::platform::float16>); REGISTER_OP_NPU_KERNEL( lookup_table_v2_grad, ops::LookupTableV2GradNPUKernel, - ops::LookupTableV2GradNPUKernel, - ops::LookupTableV2GradNPUKernel, - ops::LookupTableV2GradNPUKernel, - ops::LookupTableV2GradNPUKernel); + ops::LookupTableV2GradNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py index af872104abdd83..6da70ba72ca3b0 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py @@ -67,8 +67,25 @@ def test_check_output(self): # if self.dtype == np.float16: # return # self.check_grad(['X'], 'Out') - # +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestLookupTableV2FP16(TestLookupTableV2): + no_need_check_grad = True + def init_dtype(self): + self.dtype = np.float16 + +#@unittest.skipIf(not paddle.is_compiled_with_npu(), +# "core is not compiled with NPU") +#class TestLookupTableV2Int8(TestLookupTableV2): +# def init_dtype(self): +# self.dtype = np.int8 +# +#@unittest.skipIf(not paddle.is_compiled_with_npu(), +# "core is not compiled with NPU") +#class TestLookupTableV2UInt8(TestLookupTableV2): +# def init_dtype(self): +# self.dtype = np.uint8 @unittest.skipIf(not paddle.is_compiled_with_npu(), @@ -120,6 +137,7 @@ def test_npu(self): self.assertTrue(np.allclose(npu_loss, cpu_loss)) + if __name__ == '__main__': unittest.main() From dfb633b5ed02591d166c9a3439e682c7b3a2dfd7 Mon Sep 17 00:00:00 2001 From: Meiyim Date: Mon, 15 Mar 2021 11:21:55 +0800 Subject: [PATCH 7/7] clean up --- .../fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py index 6da70ba72ca3b0..99016e5d620c83 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py @@ -108,7 +108,6 @@ def _test(self, run_npu=True): with paddle.static.program_guard(main_prog, startup_prog): emb = paddle.nn.Embedding(vocab, dim) ids = paddle.static.data(name="ids", shape=[bsz, seqlen], dtype='int64') - #res = paddle.static.nn.embedding(ids, (vocab, dim), param_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1.))) res = emb(ids) loss = res.sum()