Commit 17862b7

[NPU] Support mean npu kernel (#31729)
1 parent 342252c commit 17862b7

3 files changed: 208 additions & 0 deletions

paddle/fluid/operators/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -187,4 +187,6 @@ endif()
 
 if(WITH_ASCEND_CL)
   cc_test(gelu_op_npu_test SRCS gelu_op_npu_test.cc DEPS op_registry gelu_op scope device_context enforce executor)
+  cc_test(mean_op_npu_test SRCS mean_op_npu_test.cc DEPS op_registry mean_op scope device_context enforce executor)
 endif()
+
Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class MeanNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* out = ctx.Output<framework::LoDTensor>("Out");

    // An empty "axes" list tells ReduceMeanD to reduce over every
    // dimension, so Out holds the mean of all elements of X.
    std::vector<int> axes;

    framework::NPUAttributeMap attr_input = {{"keep_dims", false},
                                             {"axes", axes}};

    out->mutable_data<T>(ctx.GetPlace());

    auto runner = NpuOpRunner("ReduceMeanD", {*x}, {*out}, attr_input);

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    runner.Run(stream);
  }
};
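
With axes=[] and keep_dims=false, ReduceMeanD reduces over every dimension, so Out is the scalar mean of all elements of X. A minimal NumPy sketch of that semantics (host-side illustration only; the shape is hypothetical):

    import numpy as np

    # Equivalent of ReduceMeanD with axes=[] and keep_dims=False:
    # reduce over every dimension down to a single scalar.
    x = np.random.random([1, 100]).astype(np.float32)  # hypothetical shape
    out = x.mean()  # same value the NPU kernel writes to Out
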
template <typename DeviceContext, typename T>
class MeanGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto stream =
        context.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    auto grad = context.Input<Tensor>(framework::GradVarName("Out"));

    PADDLE_ENFORCE_EQ(grad->numel(), 1,
                      platform::errors::InvalidArgument(
                          "Mean gradient input Tensor should have 1 element. "
                          "But received Out@Grad's element num is %d.",
                          grad->numel()));

    auto IG = context.Output<Tensor>(framework::GradVarName("X"));
    IG->mutable_data<T>(context.GetPlace());

    // ones: a tensor of ones with the same shape as X
    Tensor ones(grad->type());
    ones.mutable_data<T>(IG->dims(), context.GetPlace());
    auto runner_ones = NpuOpRunner("OnesLike", {*IG}, {ones}, {});
    runner_ones.Run(stream);

    // mean: a 1-element tensor holding 1 / numel(X)
    Tensor mean_tensor(grad->type());
    mean_tensor.Resize({1});
    mean_tensor.mutable_data<T>(context.GetPlace());
    std::vector<float> mean_vec;
    mean_vec.push_back(1.0 / static_cast<float>(IG->numel()));
    framework::TensorFromVector(mean_vec, context.device_context(),
                                &mean_tensor);

    // broadcast the scalar mean over the ones tensor
    Tensor mean_ma(grad->type());
    mean_ma.Resize(IG->dims());
    mean_ma.mutable_data<T>(context.GetPlace());
    auto runner_mul_1 = NpuOpRunner("Mul", {mean_tensor, ones}, {mean_ma}, {});
    runner_mul_1.Run(stream);

    // multiply by the incoming gradient to produce X@Grad
    auto runner_mul_2 = NpuOpRunner("Mul", {mean_ma, *grad}, {*IG}, {});
    runner_mul_2.Run(stream);
  }
};
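
The composition above implements X@Grad = (1 / numel(X)) * ones_like(X) * Out@Grad: since mean reduces to a single element, every entry of X receives an equal share of the upstream gradient. A NumPy sketch of the three NPU ops (OnesLike, Mul, Mul), with a hypothetical shape:

    import numpy as np

    x_shape = (1, 100)                         # hypothetical shape of X
    out_grad = np.float32(1.0)                 # Out@Grad holds one element
    ones = np.ones(x_shape, dtype=np.float32)  # OnesLike
    mean_ma = (1.0 / ones.size) * ones         # Mul: broadcast 1/numel(X)
    x_grad = mean_ma * out_grad                # Mul: the final X@Grad
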
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(
    mean,
    ops::MeanNPUKernel<paddle::platform::NPUDeviceContext, int>,
    ops::MeanNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::MeanNPUKernel<paddle::platform::NPUDeviceContext, double>,
    ops::MeanNPUKernel<paddle::platform::NPUDeviceContext, plat::float16>)

REGISTER_OP_NPU_KERNEL(
    mean_grad,
    ops::MeanGradNPUKernel<paddle::platform::NPUDeviceContext, int>,
    ops::MeanGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::MeanGradNPUKernel<paddle::platform::NPUDeviceContext, double>,
    ops::MeanGradNPUKernel<paddle::platform::NPUDeviceContext, plat::float16>)
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import core

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestMean(OpTest):
    def setUp(self):
        self.set_npu()
        self.place = paddle.NPUPlace(0)
        self.op_type = "mean"
        self.init_dtype()

        x = np.random.random([1, 100]).astype(self.dtype)
        self.inputs = {'X': x}

        self.attrs = {}
        np_out = np.mean(x)
        self.outputs = {'Out': np_out}

    def set_npu(self):
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place, check_dygraph=False)

    def test_check_grad(self):
        self.check_grad_with_place(self.place, ['X'], 'Out', check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestMeanFP16(OpTest):
    def setUp(self):
        self.set_npu()
        self.place = paddle.NPUPlace(0)
        self.op_type = "mean"
        self.init_dtype()

        x = np.random.random([3, 200]).astype(self.dtype)
        self.inputs = {'X': x}

        self.attrs = {}
        np_out = np.mean(x)
        self.outputs = {'Out': np_out}

    def set_npu(self):
        self.__class__.use_npu = True
        self.__class__.no_need_check_grad = True

    def init_dtype(self):
        self.dtype = np.float16

    def test_check_output(self):
        self.check_output_with_place(self.place, check_dygraph=False)


if __name__ == '__main__':
    unittest.main()
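
For context, a hedged end-to-end sketch of driving the new kernel through the static-graph API. It assumes an Ascend build where paddle.NPUPlace(0) is available (as in the tests above) and that fluid.layers.mean appends the "mean" op this commit registers:

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data(name='x', shape=[1, 100], dtype='float32')
        out = fluid.layers.mean(x)  # appends the "mean" op

    # Assumes a WITH_ASCEND_CL build; otherwise use paddle.CPUPlace().
    exe = paddle.static.Executor(paddle.NPUPlace(0))
    res, = exe.run(main,
                   feed={'x': np.random.random([1, 100]).astype('float32')},
                   fetch_list=[out])
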
