Commit 382fc31

【NPU】Support npu op gelu and gelu_grad (#31530)
* Support npu op gelu and gelu_grad
1 parent 5d29a27 commit 382fc31

File tree

4 files changed: +422 −0 lines changed


paddle/fluid/operators/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -179,3 +179,7 @@ if(WITH_UNITY_BUILD)
   # The specified link dependency needs to be displayed here.
   target_link_libraries(paddle_operators_unity ${OP_HEADER_DEPS} ${COMMON_OP_DEPS})
 endif()
+
+if(WITH_ASCEND_CL)
+    cc_test(gelu_op_npu_test SRCS gelu_op_npu_test.cc DEPS op_registry gelu_op scope device_context enforce executor)
+endif()
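Note: the gtest target above is compiled only when the build enables WITH_ASCEND_CL, so the NPU test is skipped entirely on non-Ascend builds; the DEPS list links in gelu_op so that USE_OP(gelu) in gelu_op_npu_test.cc (shown below) can resolve the operator registration.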
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>

#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class GeluNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");

    auto* out = ctx.Output<Tensor>("Out");

    auto place = ctx.GetPlace();

    out->mutable_data<T>(place);

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    auto runner = NpuOpRunner("Gelu", {*x}, {*out}, {});
    runner.Run(stream);
  }
};

template <typename DeviceContext, typename T>
class GeluGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto place = ctx.GetPlace();

    dx->mutable_data<T>(place);

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

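    // NOTE: the Ascend "GeluGrad" op is fed the forward output as its third
    // input, so the forward "Gelu" op is re-run here to produce it rather
    // than reusing a value saved from the forward pass.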
    Tensor out(x->type());
    out.mutable_data<T>(x->dims(), place);
    auto out_runner = NpuOpRunner("Gelu", {*x}, {out}, {});
    out_runner.Run(stream);

    auto dx_runner = NpuOpRunner("GeluGrad", {*dout, *x, out}, {*dx}, {});
    dx_runner.Run(stream);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
    gelu,
    ops::GeluNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::GeluNPUKernel<paddle::platform::NPUDeviceContext,
                       paddle::platform::float16>);

REGISTER_OP_NPU_KERNEL(
    gelu_grad,
    ops::GeluGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::GeluGradNPUKernel<paddle::platform::NPUDeviceContext,
                           paddle::platform::float16>);
Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef _WIN32
#include <sys/time.h>  // for gettimeofday() used by the timing loops below
#include <unistd.h>
#endif

#include <string>
#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"

namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;

USE_OP(gelu);
USE_OP_DEVICE_KERNEL(gelu, NPU);

template <typename T>
void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
  // init
  auto x = scope->Var("X");
  auto tensor_x = x->GetMutable<f::LoDTensor>();

  std::vector<T> init_x;
  for (int64_t i = 0; i < 10 * 10; ++i) {
    init_x.push_back(static_cast<T>(1.0));
  }

  TensorFromVector(init_x, ctx, tensor_x);
  tensor_x->Resize({10, 10});

  auto out = scope->Var("Out");
  auto tensor_out = out->GetMutable<f::LoDTensor>();

  f::AttributeMap attrs;

  ctx.Wait();

  // run
  auto place = ctx.GetPlace();

  auto op = f::OpRegistry::CreateOp("gelu", {{"X", {"X"}}},
                                    {{"Out", {"Out"}}}, attrs);
  op->Run(*scope, place);

  ctx.Wait();

  // eval time
  struct timeval start, end;
  gettimeofday(&start, NULL);

  for (int i = 0; i < 100; i++) {
    op->Run(*scope, place);
  }

  ctx.Wait();

  gettimeofday(&end, NULL);
  int micros =
      (((end.tv_sec - start.tv_sec) * 1000000) + end.tv_usec) - (start.tv_usec);
  printf("used time: %d\n", micros / 100);

  // eval value
  std::vector<T> out_vec;
  TensorToVector(*tensor_out, ctx, &out_vec);

  float expected = 0.841192;
  for (uint32_t i = 0; i < out_vec.size(); i++) {
    EXPECT_FLOAT_EQ(out_vec[i], static_cast<T>(expected));
  }
}

template <typename T>
void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) {
  auto dout = scope->Var("DOut");
  auto tensor_dout = dout->GetMutable<f::LoDTensor>();

  auto x = scope->Var("X");
  auto tensor_x = x->GetMutable<f::LoDTensor>();

  std::vector<T> init_dout;
  for (int64_t i = 0; i < 10 * 10; ++i) {
    init_dout.push_back(static_cast<T>(1.0));
  }

  std::vector<T> init_x;
  for (int64_t i = 0; i < 10 * 10; ++i) {
    init_x.push_back(static_cast<T>(1.0));
  }

  TensorFromVector(init_dout, ctx, tensor_dout);
  tensor_dout->Resize({10, 10});
  TensorFromVector(init_x, ctx, tensor_x);
  tensor_x->Resize({10, 10});

  auto dx = scope->Var("DX");
  auto tensor_dx = dx->GetMutable<f::LoDTensor>();

  f::AttributeMap attrs;

  ctx.Wait();

  // run
  auto place = ctx.GetPlace();

  auto op = f::OpRegistry::CreateOp("gelu_grad",
                                    {{"Out@GRAD", {"DOut"}}, {"X", {"X"}}},
                                    {{"X@GRAD", {"DX"}}}, attrs);
  op->Run(*scope, place);

  ctx.Wait();

  // eval time
  struct timeval start, end;
  gettimeofday(&start, NULL);

  for (int i = 0; i < 100; i++) {
    op->Run(*scope, place);
  }

  ctx.Wait();

  gettimeofday(&end, NULL);
  int micros =
      (((end.tv_sec - start.tv_sec) * 1000000) + end.tv_usec) - (start.tv_usec);
  printf("used time: %d\n", micros / 100);

  // eval value
  std::vector<T> dx_vec;
  TensorToVector(*tensor_dx, ctx, &dx_vec);

  float expected = 1.082964;
  for (uint32_t i = 0; i < dx_vec.size(); i++) {
    EXPECT_FLOAT_EQ(dx_vec[i], static_cast<T>(expected));
  }
}

TEST(gelu, NPU_fp32) {
  f::Scope scope;
  p::NPUDeviceContext ctx(p::NPUPlace(0));
  Compare<float>(&scope, ctx);
}

TEST(gelu_grad, NPU) {
  f::Scope scope;
  p::NPUDeviceContext ctx(p::NPUPlace(0));
  CompareGrad<float>(&scope, ctx);
}
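For reference, the expected values hard-coded in the checks above (0.841192 for the forward pass and 1.082964 for the gradient, both at x = 1) are consistent with the tanh approximation of GELU. The diff does not state which formulation the Ascend "Gelu"/"GeluGrad" operators actually implement, so the standalone, host-only sketch below merely reproduces those constants under that assumption:

#include <cmath>
#include <cstdio>

// Host-side reference: tanh-approximate GELU and its derivative.
// gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
int main() {
  const double kPi = 3.14159265358979323846;
  const double kAlpha = std::sqrt(2.0 / kPi);
  const double x = 1.0;

  const double inner = kAlpha * (x + 0.044715 * x * x * x);
  const double tanh_inner = std::tanh(inner);
  const double gelu = 0.5 * x * (1.0 + tanh_inner);

  // Derivative of the expression above (product + chain rule).
  const double sech2 = 1.0 - tanh_inner * tanh_inner;
  const double d_inner = kAlpha * (1.0 + 3.0 * 0.044715 * x * x);
  const double gelu_grad = 0.5 * (1.0 + tanh_inner) + 0.5 * x * sech2 * d_inner;

  std::printf("gelu(1)      = %f\n", gelu);       // ~0.841192
  std::printf("gelu_grad(1) = %f\n", gelu_grad);  // ~1.082964
  return 0;
}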
