Skip to content

Commit 8c4b97c

Browse files
OleNetoyjxer
authored andcommitted
[NPU] add Increment op (PaddlePaddle#31563)
* add increment * fix * update test increment op inplace * update increment op * increment b = 2 Co-authored-by: oyjxer <[email protected]>
1 parent ee43974 commit 8c4b97c

File tree

3 files changed

+278
-0
lines changed

3 files changed

+278
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
16+
#include "paddle/fluid/operators/increment_op.h"
17+
#include "paddle/fluid/platform/float16.h"
18+
#include "paddle/fluid/operators/npu_op_runner.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
class OpDesc;
23+
class Variable;
24+
} // namespace framework
25+
namespace imperative {
26+
class OpBase;
27+
} // namespace imperative
28+
} // namespace paddle
29+
30+
namespace paddle {
31+
namespace operators {
32+
33+
34+
template <typename DeviceContext, typename T>
35+
class IncrementalNPUKernel : public framework::OpKernel<T> {
36+
public:
37+
void Compute(const framework::ExecutionContext& context) const override {
38+
auto* x_tensor = context.Input<framework::Tensor>("X");
39+
auto* out_tensor = context.Output<framework::Tensor>("Out");
40+
float step = context.Attr<float>("step");
41+
out_tensor->mutable_data<T>(context.GetPlace());
42+
43+
Tensor step_tensor(x_tensor->type());
44+
std::vector<T> step_vec;
45+
step_vec.push_back(static_cast<T>(step));
46+
framework::TensorFromVector(
47+
step_vec,
48+
context.device_context(),
49+
&step_tensor);
50+
51+
auto runner = NpuOpRunner("Add",
52+
{*x_tensor, step_tensor},
53+
{*out_tensor},
54+
{});
55+
56+
auto stream =
57+
context.template device_context<paddle::platform::NPUDeviceContext>()
58+
.stream();
59+
runner.Run(stream);
60+
}
61+
};
62+
63+
} // namespace operators
64+
} // namespace paddle
65+
66+
67+
namespace plat = paddle::platform;
68+
namespace ops = paddle::operators;
69+
70+
REGISTER_OP_NPU_KERNEL(
71+
increment,
72+
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, float>,
73+
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, double>,
74+
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, int>,
75+
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
76+
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, plat::float16>)
77+
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#ifndef _WIN32
16+
#include <unistd.h>
17+
#endif
18+
19+
#include <string>
20+
#include <thread> // NOLINT
21+
#include <vector>
22+
23+
#include "gtest/gtest.h"
24+
#include "paddle/fluid/framework/op_registry.h"
25+
#include "paddle/fluid/framework/operator.h"
26+
#include "paddle/fluid/framework/program_desc.h"
27+
#include "paddle/fluid/operators/dropout_op.h"
28+
#include "paddle/fluid/operators/math/math_function.h"
29+
#include "paddle/fluid/string/printf.h"
30+
31+
namespace f = paddle::framework;
32+
namespace p = paddle::platform;
33+
namespace m = paddle::operators::math;
34+
35+
USE_OP(increment);
36+
USE_OP_DEVICE_KERNEL(increment, NPU);
37+
38+
template <typename T>
39+
void Compare(f::Scope* scope, const p::DeviceContext& ctx,
40+
std::string op_type) {
41+
// init
42+
auto x = scope->Var("X");
43+
auto tensor_x = x->GetMutable<f::LoDTensor>();
44+
45+
std::vector<T> init;
46+
init.push_back(static_cast<T>(1.0));
47+
48+
TensorFromVector(init, ctx, tensor_x);
49+
tensor_x->Resize({1});
50+
51+
ctx.Wait();
52+
53+
auto place = ctx.GetPlace();
54+
auto out = scope->Var("Out");
55+
auto tensor_out = out->GetMutable<f::LoDTensor>();
56+
57+
f::AttributeMap attr_input = { {"step", static_cast<float>(2.0)} };
58+
auto op = f::OpRegistry::CreateOp("increment", {{"X", {"X"}}},
59+
{{"Out", {"Out"}}},
60+
attr_input);
61+
62+
op->Run(*scope, place);
63+
64+
std::vector<T> out_vec;
65+
TensorToVector(*tensor_out, ctx, &out_vec);
66+
67+
ctx.Wait();
68+
69+
EXPECT_EQ((uint32_t)out_vec.size(), (uint32_t)1);
70+
EXPECT_EQ(out_vec[0], static_cast<T>(3.0));
71+
}
72+
73+
74+
TEST(increment, NPU_fp32) {
75+
f::Scope scope;
76+
p::NPUDeviceContext ctx(p::NPUPlace(0));
77+
Compare<float>(&scope, ctx, "increment");
78+
}
79+
80+
TEST(increment, NPU_fp64) {
81+
f::Scope scope;
82+
p::NPUDeviceContext ctx(p::NPUPlace(0));
83+
Compare<float>(&scope, ctx, "increment");
84+
}
85+
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import numpy as np
18+
import unittest
19+
import sys
20+
sys.path.append("..")
21+
from op_test import OpTest
22+
import paddle
23+
import paddle.fluid as fluid
24+
from paddle.fluid import core
25+
26+
paddle.enable_static()
27+
SEED = 2021
28+
29+
NPUPlace = 5
30+
31+
32+
@unittest.skipIf(not paddle.is_compiled_with_npu(),
33+
"core is not compiled with NPU")
34+
class TestIncrement(OpTest):
35+
def setUp(self):
36+
self.set_npu()
37+
self.place = paddle.NPUPlace(NPUPlace)
38+
self.op_type = "increment"
39+
self.init_dtype()
40+
41+
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)), }
42+
43+
self.attrs = {"Step": 1}
44+
self.outputs = {'Out': np.array([2])}
45+
46+
def set_npu(self):
47+
self.__class__.use_npu = True
48+
self.__class__.no_need_check_grad = True
49+
50+
def init_dtype(self):
51+
self.dtype = np.int64
52+
53+
def test_check_output(self):
54+
self.check_output_with_place(self.place, check_dygraph=False)
55+
56+
57+
@unittest.skipIf(not paddle.is_compiled_with_npu(),
58+
"core is not compiled with NPU")
59+
class TestIncrementFP16(OpTest):
60+
def setUp(self):
61+
self.set_npu()
62+
self.place = paddle.NPUPlace(NPUPlace)
63+
self.op_type = "increment"
64+
self.init_dtype()
65+
66+
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)), }
67+
self.pre_input_id = id(self.inputs['X'])
68+
69+
self.attrs = {"Step": 1}
70+
self.outputs = {'Out': np.array([2])}
71+
72+
def set_npu(self):
73+
self.__class__.use_npu = True
74+
75+
def init_dtype(self):
76+
self.dtype = np.float16
77+
78+
def test_check_output(self):
79+
self.check_output_with_place(self.place, check_dygraph=False)
80+
81+
82+
@unittest.skipIf(not paddle.is_compiled_with_npu(),
83+
"core is not compiled with NPU")
84+
class TestIncrementInplace(unittest.TestCase):
85+
def test_npu(self):
86+
main_prog = paddle.static.Program()
87+
startup_prog = paddle.static.Program()
88+
main_prog.random_seed = SEED
89+
startup_prog.random_seed = SEED
90+
np.random.seed(SEED)
91+
92+
a_np = np.array([1]).astype('float32')
93+
94+
with paddle.static.program_guard(main_prog, startup_prog):
95+
a = paddle.static.data(name="a", shape=[1], dtype='float32')
96+
b = fluid.layers.increment(a)
97+
98+
place = paddle.NPUPlace(NPUPlace)
99+
100+
exe = paddle.static.Executor(place)
101+
exe.run(startup_prog)
102+
103+
b_value = exe.run(
104+
main_prog,
105+
feed={"a": a_np,},
106+
fetch_list=[b])
107+
108+
print('input a id is : {}'.format(id(a)))
109+
print('input b id is : {}'.format(id(b)))
110+
111+
self.assertEqual(id(a), id(b))
112+
self.assertEqual(b_value[0], 2)
113+
114+
115+
if __name__ == '__main__':
116+
unittest.main()

0 commit comments

Comments
 (0)