Commit 479009e

zhiqiu authored and frankwhzhang committed
[NPU] add npu kernel for truncated_gaussian_random op (PaddlePaddle#31654)
* init
* add todo
* add npu kernel for truncated_gaussian_random
* add sync
* fix concat_grad
* fix typo
1 parent 1464a59 commit 479009e
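For context, `truncated_gaussian_random` is the op that `paddle.nn.initializer.TruncatedNormal` lowers to in static-graph mode, which is how the unit test below exercises the new kernel. A minimal sketch of that path, assuming a Paddle build compiled with NPU support (otherwise swap in `paddle.CPUPlace()`):

# Hedged sketch: reaches truncated_gaussian_random through the TruncatedNormal
# initializer, mirroring the unit test added in this commit. Assumes Paddle
# was built with NPU support.
import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    weight_attr = paddle.framework.ParamAttr(
        name="linear_weight",
        initializer=paddle.nn.initializer.TruncatedNormal(mean=0.0, std=2.0))
    linear = paddle.nn.Linear(2, 2, weight_attr=weight_attr, bias_attr=False)

exe = paddle.static.Executor(paddle.NPUPlace(0))
w, = exe.run(startup_prog, fetch_list=["linear_weight"])
print(w)  # 2x2 weights drawn from N(0, 2^2), truncated to [-4, 4]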

File tree

3 files changed: +186 -5 lines changed

paddle/fluid/operators/concat_op_npu.cc
paddle/fluid/operators/truncated_gaussian_random_op_npu.cc
python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py

paddle/fluid/operators/concat_op_npu.cc

Lines changed: 2 additions & 5 deletions
@@ -80,7 +80,6 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
     axis = ComputeAxis(static_cast<int64_t>(axis),
                        static_cast<int64_t>(ins[0]->dims().size()));

-    std::vector<int> sizes;
     int offset = 0;
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
@@ -91,7 +90,6 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
       if (out_var_names[j] != framework::kEmptyVarName &&
           outs[j]->numel() != 0UL) {
         outs[j]->mutable_data<T>(ctx.GetPlace());
-        sizes.push_back(outs[j]->dims()[axis]);
         std::vector<int> offsets;
         std::vector<int> sizes;
         for (int dim = 0; dim < ins[j]->dims().size(); ++dim) {
@@ -103,9 +101,8 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
           sizes.push_back(ins[j]->dims()[dim]);
         }
       }
-      auto runner =
-          NpuOpRunner("SliceD", {*out_grad}, {*outs[j]},
-                      {{"offsets", offset}, {"size", ins[j]->dims()[axis]}});
+      auto runner = NpuOpRunner("SliceD", {*out_grad}, {*outs[j]},
+                                {{"offsets", offsets}, {"size", sizes}});
       runner.Run(stream);
     }
     if (ins[j]->numel() != 0UL) {
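The `concat_op_npu.cc` change is a bug fix in the concat gradient: the old code reused a single `sizes` vector across loop iterations and passed a scalar `offset` plus only the concat-axis extent to SliceD, while the fixed code builds fresh per-dimension `offsets` and `sizes` vectors for every input. A hedged NumPy sketch of what the corrected SliceD call computes (the helper name, shapes, and axis are illustrative, not taken from the diff):

# Hedged NumPy sketch: slice the upstream gradient of concat back into
# per-input pieces, using per-dimension offsets/sizes as the fixed code does.
import numpy as np

def concat_grad(out_grad, input_shapes, axis):
    grads, offset = [], 0
    for shape in input_shapes:
        # Offset only along the concat axis; full extent elsewhere.
        offsets = [offset if d == axis else 0 for d in range(len(shape))]
        sizes = list(shape)
        index = tuple(slice(o, o + s) for o, s in zip(offsets, sizes))
        grads.append(out_grad[index])
        offset += shape[axis]
    return grads

# Gradient of concat([2x2, 2x4], axis=1) is a 2x6 tensor split back apart.
out_grad = np.arange(12.0).reshape(2, 6)
ga, gb = concat_grad(out_grad, [(2, 2), (2, 4)], axis=1)
assert ga.shape == (2, 2) and gb.shape == (2, 4)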
paddle/fluid/operators/truncated_gaussian_random_op_npu.cc

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/truncated_gaussian_random_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class TruncatedGaussianRandomNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // TODO(zhiqiu): support dynamic shape and call ParameterizedTruncatedNormal
    std::vector<int> shape = ctx.Attr<std::vector<int>>("shape");
    Tensor shape_tensor(framework::proto::VarType::INT32);
    shape_tensor.mutable_data<int32_t>({static_cast<int>(shape.size())},
                                       ctx.GetPlace());
    TensorFromVector(shape, ctx.device_context(), &shape_tensor);
    float mean = ctx.Attr<float>("mean");
    Tensor mean_tensor(framework::proto::VarType::FP32);
    mean_tensor.mutable_data<float>({1}, ctx.GetPlace());
    TensorFromVector(std::vector<float>{mean}, ctx.device_context(),
                     &mean_tensor);

    float std = ctx.Attr<float>("std");
    Tensor std_tensor(framework::proto::VarType::FP32);
    std_tensor.mutable_data<float>({1}, ctx.GetPlace());
    TensorFromVector(std::vector<float>{std}, ctx.device_context(),
                     &std_tensor);

    int32_t seed_var = ctx.Attr<int32_t>("seed");

    Tensor min_tensor(framework::proto::VarType::FP32);
    min_tensor.mutable_data<float>({1}, ctx.GetPlace());
    float min_value = mean - std * 2.0;
    TensorFromVector(std::vector<float>{min_value}, ctx.device_context(),
                     &min_tensor);

    Tensor max_tensor(framework::proto::VarType::FP32);
    max_tensor.mutable_data<float>({1}, ctx.GetPlace());
    float max_value = mean + std * 2.0;
    TensorFromVector(std::vector<float>{max_value}, ctx.device_context(),
                     &max_tensor);

    auto* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    auto runner = NpuOpRunner(
        "ParameterizedTruncatedNormal",
        {shape_tensor, mean_tensor, std_tensor, min_tensor, max_tensor}, {*out},
        {{"seed", seed_var}});
    runner.Run(stream);
  }
};

// NOTE(zhiqiu): actually, this is the cpu version of the kernel, and we need
// to make the above npu version work in the future.
template <typename T>
class NPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    float mean = context.Attr<float>("mean");
    float std = context.Attr<float>("std");
    auto* tensor = context.Output<framework::Tensor>("Out");
    tensor->mutable_data<T>(context.GetPlace());

    Tensor cpu_tensor(tensor->type());
    cpu_tensor.Resize(tensor->dims());
    T* cpu_data = cpu_tensor.mutable_data<T>(platform::CPUPlace());
    std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(),
                                           1.0);
    TruncatedNormal<T> truncated_normal(mean, std);
    int64_t size = tensor->numel();

    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
    auto engine = framework::GetCPURandomEngine(seed);
    for (int64_t i = 0; i < size; ++i) {
      cpu_data[i] = truncated_normal(dist(*engine));
    }
    framework::TensorCopy(
        cpu_tensor, context.GetPlace(),
        context.template device_context<platform::DeviceContext>(), tensor);
    context.template device_context<paddle::platform::NPUDeviceContext>()
        .Wait();
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(truncated_gaussian_random,
                       ops::NPUTruncatedGaussianRandomKernel<float>);
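Note that the kernel actually registered is the CPU fallback, `NPUTruncatedGaussianRandomKernel`: it draws uniform samples from a seeded CPU engine, maps them through the `TruncatedNormal` functor declared in `truncated_gaussian_random_op.h` (an inverse-CDF transform truncated at two standard deviations, matching the `mean ± std * 2.0` bounds the NPU path computes above), then copies the result to the device and waits. A hedged Python sketch of that transform, with `scipy.special.erfinv` standing in for the functor's inverse error function:

# Hedged sketch of the inverse-CDF transform behind the CPU fallback:
# map a uniform sample u in (0, 1) onto a normal truncated at mean +/- 2*std.
import math
import numpy as np
from scipy.special import erfinv

def truncated_normal(u, mean=0.0, std=1.0):
    # Standard-normal CDF at the truncation points -2 and +2.
    a_cdf = (1.0 + math.erf(-2.0 / math.sqrt(2.0))) / 2.0
    b_cdf = (1.0 + math.erf(2.0 / math.sqrt(2.0))) / 2.0
    # Squeeze u into [cdf(-2), cdf(2)], then invert the normal CDF.
    p = a_cdf + (b_cdf - a_cdf) * u
    return mean + std * math.sqrt(2.0) * erfinv(2.0 * p - 1.0)

# np.finfo(np.float32).tiny mirrors std::numeric_limits<float>::min().
u = np.random.uniform(np.finfo(np.float32).tiny, 1.0, size=1000)
samples = np.array([truncated_normal(x, mean=0.0, std=2.0) for x in u])
assert samples.min() >= -4.0 and samples.max() <= 4.0  # within mean +/- 2*std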
python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestTruncatedNormal(unittest.TestCase):
    def _test(self, run_npu=True):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        scope = paddle.fluid.core.Scope()

        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)
        paddle.seed(SEED)

        with fluid.scope_guard(scope):
            with paddle.static.program_guard(main_prog, startup_prog):
                weight_attr = paddle.framework.ParamAttr(
                    name="linear_weight",
                    initializer=paddle.nn.initializer.TruncatedNormal(
                        mean=0.0, std=2.0))
                linear = paddle.nn.Linear(
                    2, 2, weight_attr=weight_attr, bias_attr=False)

        if run_npu:
            place = paddle.NPUPlace(0)
        else:
            place = paddle.CPUPlace()

        exe = paddle.static.Executor(place)
        w = exe.run(startup_prog, fetch_list=['linear_weight'])
        return w

    def test_npu(self):
        cpu_w = self._test(False)
        npu_w = self._test(True)

        self.assertTrue(np.allclose(npu_w, cpu_w))


if __name__ == '__main__':
    unittest.main()
