
Commit 1e95600

[NPU] add npu kernel for adam (#31644)
* add npu kernel for adam
* refine code
* disable test
* modify atol
1 parent 795b0f9 commit 1e95600

2 files changed: +309 −0 lines changed

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>

#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/optimizers/adam_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

template <typename DeviceContext, typename T>
class AdamNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Var(%s)'s type should be LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Param").front(),
                          framework::ToTypeName(param_var->Type())));
    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
    auto* param = ctx.Input<LoDTensor>("Param");
    auto* grad_var = ctx.InputVar("Grad");
    PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Grad(%s)'s type should be LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Grad").front(),
                          framework::ToTypeName(param_var->Type())));
    auto* grad = ctx.Input<LoDTensor>("Grad");
    auto* mom1 = ctx.Input<LoDTensor>("Moment1");
    auto* mom2 = ctx.Input<LoDTensor>("Moment2");
    auto* lr = ctx.Input<LoDTensor>("LearningRate");

    auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
    auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");

    auto* param_out = ctx.Output<LoDTensor>("ParamOut");
    auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
    auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
    auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
    auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");

    param_out->mutable_data<T>(ctx.GetPlace());
    mom1_out->mutable_data<T>(ctx.GetPlace());
    mom2_out->mutable_data<T>(ctx.GetPlace());
    beta1_pow_out->mutable_data<T>(ctx.GetPlace());
    beta2_pow_out->mutable_data<T>(ctx.GetPlace());

    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
    if (ctx.HasInput("Beta1Tensor")) {
      auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
      PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta1Tensor) size must be 1, but get %d",
                            beta1_tensor->numel()));
      beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
    }
    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
    if (ctx.HasInput("Beta2Tensor")) {
      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
      PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta2Tensor) size must be 1, but get %d",
                            beta2_tensor->numel()));
      beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
    }
    VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
            << "beta2_pow.numel() : " << beta2_pow->numel();
    VLOG(3) << "param.numel(): " << param->numel();

    PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta1 pow output size should be 1, but received "
                          "value is:%d.",
                          beta1_pow_out->numel()));

    PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta2 pow output size should be 1, but received "
                          "value is:%d.",
                          beta2_pow_out->numel()));

    // reshape
    Tensor beta1_tensor(framework::proto::VarType::FP32);
    beta1_tensor.mutable_data<float>({1}, ctx.GetPlace());
    TensorFromVector(std::vector<T>{beta1}, ctx.device_context(),
                     &beta1_tensor);
    Tensor beta2_tensor(framework::proto::VarType::FP32);
    beta2_tensor.mutable_data<float>({1}, ctx.GetPlace());
    TensorFromVector(std::vector<T>{beta2}, ctx.device_context(),
                     &beta2_tensor);

    Tensor epsilon_tensor(framework::proto::VarType::FP32);
    epsilon_tensor.mutable_data<T>({1}, ctx.GetPlace());
    TensorFromVector(std::vector<T>{epsilon}, ctx.device_context(),
                     &epsilon_tensor);
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    auto runner =
        NpuOpRunner("ApplyAdamD",
                    {
                        *param, *mom1, *mom2, *beta1_pow, *beta2_pow, *lr,
                        beta1_tensor, beta2_tensor, epsilon_tensor, *grad,
                    },
                    {
                        *param_out, *mom1_out, *mom2_out,
                    },
                    {});
    runner.Run(stream);

    // NOTE(zhiqiu): ApplyAdamD updates params inplace, so
    // if param and param_out is not same, we need to do copy.
    if (param_out->data<T>() != param->data<T>()) {
      ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
      framework::TensorCopySync(*param, ctx.GetPlace(), param_out);
    }
    if (mom1_out->data<T>() != mom1->data<T>()) {
      ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
      framework::TensorCopySync(*mom1, ctx.GetPlace(), mom1_out);
    }
    if (mom2_out->data<T>() != mom2->data<T>()) {
      ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
      framework::TensorCopySync(*mom2, ctx.GetPlace(), mom2_out);
    }
    auto runner_m1 =
        NpuOpRunner("Mul", {*beta1_pow, beta1_tensor}, {*beta1_pow_out}, {});
    runner_m1.Run(stream);
    auto runner_m2 =
        NpuOpRunner("Mul", {*beta2_pow, beta2_tensor}, {*beta2_pow_out}, {});
    runner_m2.Run(stream);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
    adam, ops::AdamNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::AdamNPUKernel<paddle::platform::NPUDeviceContext,
                       paddle::platform::float16>);
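For readers of this diff: the kernel delegates the parameter and moment updates to the Ascend "ApplyAdamD" operator and then advances the beta power accumulators with two separate "Mul" runners. A minimal NumPy sketch of the update this is expected to compute, assuming the usual bias-corrected Adam step (the function name and default values below are illustrative, not part of the commit):

import numpy as np

def adam_reference_step(param, grad, mom1, mom2, lr, beta1_pow, beta2_pow,
                        beta1=0.78, beta2=0.836, epsilon=1e-4):
    # First and second moment estimates, as ApplyAdamD is expected to update them.
    mom1_out = beta1 * mom1 + (1.0 - beta1) * grad
    mom2_out = beta2 * mom2 + (1.0 - beta2) * np.square(grad)
    # Bias-corrected step size.
    lr_t = lr * np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)
    param_out = param - lr_t * mom1_out / (np.sqrt(mom2_out) + epsilon)
    # The kernel updates the beta power accumulators with separate Mul ops.
    beta1_pow_out = beta1_pow * beta1
    beta2_pow_out = beta2_pow * beta2
    return param_out, mom1_out, mom2_out, beta1_pow_out, beta2_pow_out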
Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from test_adam_op import adam_step

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestSGD(OpTest):
    def setUp(self):
        self.set_npu()
        self.place = paddle.NPUPlace(0)
        self.op_type = "adam"
        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
        moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
        # The second moment is positive
        moment2 = np.random.random((102, 105)).astype("float32")

        learning_rate = 0.004
        beta1 = 0.78
        beta2 = 0.836
        epsilon = 1e-4
        beta1_pow = beta1**10
        beta2_pow = beta2**10

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': np.array([learning_rate]).astype("float32"),
            'Beta1Pow': np.array([beta1_pow]).astype("float32"),
            'Beta2Pow': np.array([beta2_pow]).astype("float32")
        }

        self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}

        param_out, moment1_out, \
            moment2_out = adam_step(self.inputs, self.attrs)

        self.outputs = {
            'Moment1Out': moment1_out,
            'Moment2Out': moment2_out,
            'ParamOut': param_out,
            'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
            'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
        }

    def set_npu(self):
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False)


'''
# TODO(zhiqiu): The following test may let 0-3 card down.
# we need to analyze it and open it.

@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestNet(unittest.TestCase):
    def _test(self, run_npu=True):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        a_np = np.random.random(size=(32, 32)).astype('float32')
        b_np = np.random.random(size=(32, 32)).astype('float32')
        label_np = np.random.randint(2, size=(32, 1)).astype('int64')

        with paddle.static.program_guard(main_prog, startup_prog):
            a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
            b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
            label = paddle.static.data(
                name="label", shape=[32, 1], dtype='int64')

            sum = paddle.add(a, b)
            z = paddle.pow(sum, 2.0)

            fc_1 = fluid.layers.fc(input=z, size=128)
            prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')

            cost = fluid.layers.cross_entropy(input=prediction, label=label)
            loss = fluid.layers.reduce_mean(cost)
            adam = fluid.optimizer.Adam(learning_rate=0.01)
            adam.minimize(loss)

        if run_npu:
            place = paddle.NPUPlace(0)
        else:
            place = paddle.CPUPlace()

        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        print("Start run on {}".format(place))
        for epoch in range(100):

            pred_res, loss_res = exe.run(
                main_prog,
                feed={"a": a_np,
                      "b": b_np,
                      "label": label_np},
                fetch_list=[prediction, loss])
            if epoch % 10 == 0:
                print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
                    epoch, pred_res[0], loss_res))

        return pred_res, loss_res

    def test_npu(self):
        cpu_pred, cpu_loss = self._test(False)
        npu_pred, npu_loss = self._test(True)

        self.assertTrue(np.allclose(npu_pred, cpu_pred))
        self.assertTrue(np.allclose(npu_loss, cpu_loss))
'''

if __name__ == '__main__':
    unittest.main()
