
Commit 3dd992e

[NPU] Support npu op expand (#31405)
* [npu] support npu kernel for `expand`
1 parent 444c285 commit 3dd992e
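
For reference, the `expand` op tiles each dimension of its input by the matching entry of `expand_times`, the same computation as `np.tile`; a minimal numpy sketch (not part of the commit):

import numpy as np

# expand semantics: out.shape[i] == x.shape[i] * expand_times[i]
x = np.random.randn(3, 1, 7).astype("float32")
out = np.tile(x, [1, 10, 1])
print(out.shape)  # (3, 10, 7)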

File tree: 5 files changed, +305 −3 lines

- paddle/fluid/operators/CMakeLists.txt
- paddle/fluid/operators/expand_op.h
- paddle/fluid/operators/expand_op_npu.cc (new)
- paddle/fluid/operators/expand_op_npu_test.cc (new)
- python/paddle/fluid/tests/unittests/npu/test_expand_op_npu.py (new)

paddle/fluid/operators/CMakeLists.txt

Lines changed: 2 additions & 3 deletions
@@ -156,12 +156,11 @@ cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS})
 if (WITH_PYTHON)
   cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
 endif()
-if (WITH_ASCEND_CL)
-  cc_test(lookup_table_v2_op_npu_test SRCS lookup_table_v2_op_npu_test.cc DEPS op_registry lookup_table_v2_op scope device_context enforce executor compare_op)
-endif()

 if (WITH_ASCEND_CL)
   cc_test(range_op_npu_test SRCS range_op_npu_test.cc DEPS op_registry range_op scope device_context enforce executor)
+  cc_test(lookup_table_v2_op_npu_test SRCS lookup_table_v2_op_npu_test.cc DEPS op_registry lookup_table_v2_op scope device_context enforce executor compare_op)
+  cc_test(expand_op_npu_test SRCS expand_op_npu_test.cc DEPS op_registry expand_op scope device_context enforce executor compare_op)
 endif()

 set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")

paddle/fluid/operators/expand_op.h

Lines changed: 6 additions & 0 deletions
@@ -56,6 +56,12 @@ inline std::vector<int> get_expand_times(
     TensorCopySync(*expand_tensor, platform::CPUPlace(), &cpu_expand_tensor);
     expand_data = cpu_expand_tensor.data<int>();
   }
+#ifdef PADDLE_WITH_ASCEND_CL
+  if (platform::is_npu_place(expand_tensor->place())) {
+    TensorCopySync(*expand_tensor, platform::CPUPlace(), &cpu_expand_tensor);
+    expand_data = cpu_expand_tensor.data<int>();
+  }
+#endif
 #ifdef PADDLE_WITH_XPU
   if (platform::is_xpu_place(expand_tensor->place())) {
     TensorCopySync(*expand_tensor, platform::CPUPlace(), &cpu_expand_tensor);
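
This guarded block matters when `expand_times` arrives through the runtime `ExpandTimes` tensor input rather than as an attribute: if that tensor lives on the NPU, its values must be copied back to the host before `get_expand_times` can read them. A hedged sketch of how such a graph is built with the fluid 1.x API used elsewhere in this commit (exact layer signatures may vary by Paddle version):

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
x = fluid.data(name="x", shape=[3, 1, 7], dtype="float32")
# expand_times as a tensor in the graph; on an NPU this is the tensor
# that get_expand_times copies back to the CPU before reading.
times = fluid.layers.assign(np.array([1, 10, 1], dtype="int32"))
out = fluid.layers.expand(x, expand_times=times)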
paddle/fluid/operators/expand_op_npu.cc (new file)

Lines changed: 82 additions & 0 deletions
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_ASCEND_CL
#include <iostream>
#include <memory>
#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/expand_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class ExpandNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto rank = context.Input<Tensor>("X")->dims().size();
    PADDLE_ENFORCE_GE(
        rank, 1,
        platform::errors::InvalidArgument(
            "The number of dimensions of the input 'x' for Op(expand) "
            "must be greater than or equal to 1, but the value received is %d.",
            rank));
    PADDLE_ENFORCE_LE(
        rank, MAX_RANK_SUPPORTED,
        platform::errors::InvalidArgument(
            "The number of dimensions of the input 'x' for Op(expand) "
            "must be less than or equal to %d, but the value received is %d.",
            MAX_RANK_SUPPORTED, rank));
    // Dispatch to the rank-templated Expand<Rank>() below.
    switch (rank) { REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) }
  }

 protected:
  template <int Rank>
  void Expand(const framework::ExecutionContext& context) const {
    auto* in0 = context.Input<framework::LoDTensor>("X");
    auto in_dims = in0->dims();
    auto expand_times = get_expand_times(context);
    PADDLE_ENFORCE_EQ(
        static_cast<size_t>(in_dims.size()), expand_times.size(),
        platform::errors::InvalidArgument(
            "The number of elements (%d) of 'expand_times' for "
            "Op(expand) must be equal to the number "
            "of dimensions (%d) of the input.",
            expand_times.size(), static_cast<size_t>(in_dims.size())));
    auto* out0 = context.Output<framework::LoDTensor>("Out");
    framework::DDim out_dims(in_dims);
    for (size_t i = 0; i < expand_times.size(); ++i) {
      out_dims[i] *= expand_times[i];
    }
    out0->Resize(out_dims);
    out0->mutable_data<T>(context.device_context().GetPlace());
    // Map Paddle's expand to the Ascend "TileD" operator; expand_times
    // becomes its "multiples" attribute.
    auto runner =
        NpuOpRunner("TileD", {*in0}, {*out0}, {{"multiples", expand_times}});
    auto stream =
        context.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    runner.Run(stream);
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
    expand, ops::ExpandNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::ExpandNPUKernel<paddle::platform::NPUDeviceContext,
                         paddle::platform::float16>);

#endif
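
Since the kernel simply forwards to Ascend's `TileD`, end-to-end usage looks like any other static-graph op. A sketch of running it on an NPU, assuming a build with `WITH_ASCEND_CL` and an attached Ascend device:

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="x", shape=[3, 1, 7], dtype="float32")
    out = fluid.layers.expand(x, expand_times=[1, 10, 1])

place = paddle.NPUPlace(0)  # dispatches to ExpandNPUKernel
exe = paddle.static.Executor(place)
exe.run(startup)
res = exe.run(main,
              feed={"x": np.ones((3, 1, 7), "float32")},
              fetch_list=[out])[0]
print(res.shape)  # (3, 10, 7)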
paddle/fluid/operators/expand_op_npu_test.cc (new file)

Lines changed: 74 additions & 0 deletions
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef _WIN32
#include <unistd.h>
#endif

#include <iostream>
#include <string>
#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"

namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;

USE_OP(expand);
USE_OP_DEVICE_KERNEL(expand, NPU);

template <typename T>
void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
  // Initialize the input, the ExpandTimes tensor, and the output in the scope.
  auto in = scope->Var("X");
  auto expand_times = scope->Var("ExpandTimes");
  auto out = scope->Var("Out");
  auto in_t = in->GetMutable<f::LoDTensor>();
  auto out_t = out->GetMutable<f::LoDTensor>();
  auto expand_times_t = expand_times->GetMutable<f::LoDTensor>();

  auto place = ctx.GetPlace();
  TensorFromVector(std::vector<T>(3 * 1 * 7, 1), ctx, in_t);
  TensorFromVector(std::vector<int>({1, 10, 1}), ctx, expand_times_t);

  in_t->Resize(f::make_ddim({3, 1, 7}));
  expand_times_t->Resize(f::make_ddim({3}));
  out_t->Resize(f::make_ddim({3, 10, 7}));
  out_t->mutable_data<T>(place);

  // Run expand with expand_times supplied via the ExpandTimes input.
  f::AttributeMap attrs = {{}};
  auto op = f::OpRegistry::CreateOp(
      "expand", {{"X", {"X"}}, {"ExpandTimes", {"ExpandTimes"}}},
      {{"Out", {"Out"}}}, attrs);
  op->Run(*scope, place);
  ctx.Wait();

  // A (3, 1, 7) input tiled by (1, 10, 1) must produce a (3, 10, 7) output.
  auto out_dim = out_t->dims();
  EXPECT_EQ(out_dim.at(0), 3);
  EXPECT_EQ(out_dim.at(1), 10);
  EXPECT_EQ(out_dim.at(2), 7);
}

TEST(expand, NPU_fp32) {
  f::Scope scope;
  p::NPUDeviceContext ctx(p::NPUPlace(0));
  Compare<float>(&scope, ctx);
}
python/paddle/fluid/tests/unittests/npu/test_expand_op_npu.py (new file)

Lines changed: 141 additions & 0 deletions
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestExpand(OpTest):
    def setUp(self):
        self.set_npu()
        self.op_type = "expand"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.randn(3, 1, 7).astype(self.dtype)
        out = np.tile(x, [1, 10, 1])

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.attrs = {'expand_times': [1, 10, 1]}
        self.outputs = {'Out': out}

    def set_npu(self):
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place, check_dygraph=False)

    # TODO(ascendrc): Add grad test
    # def test_check_grad(self):
    #     if self.dtype == np.float16:
    #         return
    #     self.check_grad(['X'], 'Out')


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestExpandV2(TestExpand):
    # Same as TestExpand, but expand_times is fed through the
    # ExpandTimes tensor input instead of the attribute.
    def setUp(self):
        self.set_npu()
        self.op_type = "expand"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.randn(3, 1, 7).astype(self.dtype)
        out = np.tile(x, [1, 10, 1])
        expand_times = np.array([1, 10, 1]).astype(np.int32)

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(x),
            'ExpandTimes': OpTest.np_dtype_to_fluid_dtype(expand_times)
        }
        self.attrs = {}
        self.outputs = {'Out': out}


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestExpandFp16(TestExpand):
    no_need_check_grad = True

    def init_dtype(self):
        self.dtype = np.float16


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestExpandNet(unittest.TestCase):
    def _test(self, run_npu=True):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        a_np = np.random.random(size=(32, 1)).astype('float32')
        label_np = np.random.randint(2, size=(32, 1)).astype('int64')

        with paddle.static.program_guard(main_prog, startup_prog):
            a = paddle.static.data(name="a", shape=[32, 1], dtype='float32')
            label = paddle.static.data(
                name="label", shape=[32, 1], dtype='int64')

            res = paddle.fluid.layers.expand(a, [1, 32])
            loss = res.sum()
            sgd = fluid.optimizer.SGD(learning_rate=0.01)
            sgd.minimize(loss)

        if run_npu:
            place = paddle.NPUPlace(0)
        else:
            place = paddle.CPUPlace()

        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        for epoch in range(100):
            loss_res = exe.run(main_prog,
                               feed={"a": a_np,
                                     "label": label_np},
                               fetch_list=[loss])
            if epoch % 10 == 0:
                print("Epoch {} | Loss: {}".format(epoch, loss_res))

        return loss_res

    def test_npu(self):
        cpu_loss = self._test(False)
        npu_loss = self._test(True)

        self.assertTrue(np.allclose(npu_loss, cpu_loss))


if __name__ == '__main__':
    unittest.main()
