Commit d5baea8

add momentum_op_npu and test

1 parent: 1f28968

File tree: 2 files changed, +144 -0 lines
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/optimizers/sgd_op.h"

namespace paddle {
namespace operators {

template <typename T>
class NPUMomentumOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();

    auto param = ctx.Input<framework::Tensor>("Param");
    auto velocity = ctx.Input<framework::Tensor>("Velocity");
    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
    auto grad = ctx.Input<framework::Tensor>("Grad");
    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
    param_out->mutable_data<T>(ctx.GetPlace());
    velocity_out->mutable_data<T>(ctx.GetPlace());

    T mu = static_cast<T>(ctx.Attr<float>("mu"));
    bool use_nesterov = ctx.Attr<bool>("use_nesterov");

    // ApplyMomentum takes the momentum coefficient as a 1-element tensor,
    // so wrap the scalar attribute `mu` in an NPU tensor.
    framework::Tensor mu_tensor;
    mu_tensor.mutable_data<T>(framework::make_ddim({1}), ctx.GetPlace());
    FillNpuTensorWithConstant<T>(&mu_tensor, mu);

    // ApplyMomentum updates its var/accum inputs in place, so copy
    // Param and Velocity into the outputs and run the op on those copies.
    framework::TensorCopy(*param, ctx.GetPlace(), dev_ctx, param_out);
    framework::TensorCopy(*velocity, ctx.GetPlace(), dev_ctx, velocity_out);
    const auto& runner = NpuOpRunner(
        "ApplyMomentum",
        {*param_out, *velocity_out, *learning_rate, *grad, mu_tensor},
        {*param_out}, {{"use_nesterov", use_nesterov}});
    auto stream = dev_ctx.stream();
    runner.Run(stream);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(momentum, ops::NPUMomentumOpKernel<float>);
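For reference, the update the NPU ApplyMomentum op is expected to perform is the plain/Nesterov momentum step that the test below checks against via calculate_momentum_by_numpy. A minimal NumPy sketch of that assumed semantics (the helper name and signature here are illustrative, not part of the commit):

    import numpy as np

    def momentum_update(param, grad, velocity, mu, lr, use_nesterov=False):
        # Assumed semantics, mirroring the NumPy reference used by the test.
        velocity_out = mu * velocity + grad
        if use_nesterov:
            param_out = param - (grad + mu * velocity_out) * lr
        else:
            param_out = param - lr * velocity_out
        return param_out, velocity_out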
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from test_momentum_op import calculate_momentum_by_numpy

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestMomentumOp1(OpTest):
    def set_npu(self):
        self.__class__.use_npu = True

    def setUp(self):
        self.set_npu()
        self.op_type = "momentum"
        self.init_dtype()
        self.init_case()

        param = np.random.random(self.shape).astype(self.dtype)
        grad = np.random.random(self.shape).astype(self.dtype)
        velocity = np.zeros(self.shape).astype(self.dtype)
        learning_rate = np.array([0.001]).astype(np.float32)
        mu = 0.0001

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Velocity': velocity,
            'LearningRate': learning_rate
        }

        self.attrs = {'mu': mu, 'use_nesterov': self.use_nesterov}

        # Reference outputs computed on the CPU with NumPy.
        param_out, velocity_out = calculate_momentum_by_numpy(
            param=param,
            grad=grad,
            mu=mu,
            velocity=velocity,
            use_nesterov=self.use_nesterov,
            learning_rate=learning_rate)

        self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}

    def init_case(self):
        self.shape = (123, 321)
        self.use_nesterov = False

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(core.NPUPlace(0))


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestMomentumOp2(TestMomentumOp1):
    def init_case(self):
        self.shape = (123, 321)
        self.use_nesterov = True


if __name__ == "__main__":
    unittest.main()
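Beyond the OpTest above, the kernel can also be exercised end to end through the static-graph optimizer API, since paddle.optimizer.Momentum emits the same 'momentum' op that this commit registers for NPU. A minimal sketch, assuming a Paddle build compiled with NPU support and an available NPU device 0; the network and hyper-parameters are illustrative only:

    import numpy as np
    import paddle
    import paddle.fluid.core as core

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
        y = paddle.static.nn.fc(x, size=1)
        loss = paddle.mean(y)
        # minimize() appends 'momentum' ops, which the NPU kernel executes.
        opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
        opt.minimize(loss)

    place = core.NPUPlace(0)  # assumes device 0 is an NPU
    exe = paddle.static.Executor(place)
    exe.run(startup_prog)
    loss_val, = exe.run(main_prog,
                        feed={'x': np.random.rand(4, 8).astype('float32')},
                        fetch_list=[loss])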
