Skip to content

Commit d3b8bff

Browse files
authored
Implementing the Decayed Adagrad optimizer operator (#4645)
* Implement the DecayedAdagrad optimizer step operator
* Implement the DecayedAdagrad operator
* Remove an unused file
* Small fix
1 parent 2daba04 commit d3b8bff

File tree

4 files changed

+244
-0
lines changed

4 files changed

+244
-0
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/decayed_adagrad_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class DecayedAdagradOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
protected:
25+
void InferShape(framework::InferShapeContextBase *ctx) const override {
26+
PADDLE_ENFORCE(ctx->HasInput("Param"),
27+
"Input(Param) of DecayedAdagradOp should not be null.");
28+
PADDLE_ENFORCE(ctx->HasInput("Grad"),
29+
"Input(Grad) of DecayedAdagradOp should not be null.");
30+
PADDLE_ENFORCE(ctx->HasInput("Moment"),
31+
"Input(Moment) of DecayedAdagradOp should not be null.");
32+
PADDLE_ENFORCE(
33+
ctx->HasInput("LearningRate"),
34+
"Input(LearningRate) of DecayedAdagradOp should not be null.");
35+
36+
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
37+
"Output(ParamOut) of DecayedAdagradOp should not be null.");
38+
PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
39+
"Output(MomentOut) of DecayedAdagradOp should not be null.");
40+
41+
auto lr_dims = ctx->GetInputDim("LearningRate");
42+
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
43+
"LearningRate should have one element");
44+
auto param_dims = ctx->GetInputDim("Param");
45+
PADDLE_ENFORCE_EQ(param_dims, ctx->GetInputDim("Grad"),
46+
"Param and Grad input of DecayedAdagradOp should have "
47+
"the same dimension.");
48+
PADDLE_ENFORCE_EQ(param_dims, ctx->GetInputDim("Moment"),
49+
"Param and Moment input of DecayedAdagradOp should have "
50+
"the same dimension.");
51+
52+
ctx->SetOutputDim("ParamOut", param_dims);
53+
ctx->SetOutputDim("MomentOut", param_dims);
54+
}
55+
};
56+
57+
// Declares the op's proto: its inputs, outputs, attributes (with defaults)
// and the human-readable documentation shown to users.
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  DecayedAdagradOpMaker(framework::OpProto *proto,
                        framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Param", "(Tensor) Input parameter");
    AddInput("Grad", "(Tensor) Input gradient");
    AddInput("Moment", "(Tensor) Second moment");
    AddInput("LearningRate", "(Tensor) Learning rate");

    AddOutput("ParamOut", "(Tensor) Output parameter");
    AddOutput("MomentOut", "(Tensor) Output second moment");

    // 'decay' discounts the accumulated squared-gradient history each step;
    // this is what distinguishes Decayed Adagrad from plain Adagrad.
    AddAttr<float>("decay",
                   "(float, default 0.95) "
                   "Discounting factor for coming gradient")
        .SetDefault(0.95);
    // 'epsilon' guards the division against a zero denominator.
    AddAttr<float>("epsilon",
                   "(float, default 1.0e-6) "
                   "Constant for numerical stability")
        .SetDefault(1.0e-6f);
    AddComment(R"DOC(

Decayed Adagrad

moment_out = decay * moment + (1 - decay) * grad * grad
param_out = param - learning_rate * grad / (sqrt(moment_out) + epsilon)

)DOC");
  }
};
88+
} // namespace operators
89+
} // namespace paddle
90+
91+
namespace ops = paddle::operators;
// An optimizer step has no backward pass, so no gradient op is registered.
REGISTER_OP_WITHOUT_GRADIENT(decayed_adagrad, ops::DecayedAdagradOp,
                             ops::DecayedAdagradOpMaker);
// CPU kernel registration; only float is instantiated for now.
REGISTER_OP_CPU_KERNEL(
    decayed_adagrad,
    ops::DecayedAdagradOpKernel<paddle::platform::CPUPlace, float>);
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
// EIGEN_USE_GPU must be defined before including the op header so the Eigen
// expressions in the kernel are compiled for the GPU device.
#define EIGEN_USE_GPU
#include "paddle/operators/decayed_adagrad_op.h"

namespace ops = paddle::operators;
// GPU kernel registration; only float is instantiated for now.
REGISTER_OP_GPU_KERNEL(
    decayed_adagrad,
    ops::DecayedAdagradOpKernel<paddle::platform::GPUPlace, float>);
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/framework/eigen.h"
17+
#include "paddle/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
template <typename Place, typename T>
23+
class DecayedAdagradOpKernel : public framework::OpKernel<T> {
24+
public:
25+
void Compute(const framework::ExecutionContext& ctx) const override {
26+
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
27+
auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
28+
29+
param_out_tensor->mutable_data<T>(ctx.GetPlace());
30+
moment_out_tensor->mutable_data<T>(ctx.GetPlace());
31+
32+
float decay = ctx.Attr<float>("decay");
33+
float epsilon = ctx.Attr<float>("epsilon");
34+
35+
auto param = framework::EigenVector<T>::Flatten(
36+
*ctx.Input<framework::Tensor>("Param"));
37+
auto grad = framework::EigenVector<T>::Flatten(
38+
*ctx.Input<framework::Tensor>("Grad"));
39+
auto moment = framework::EigenVector<T>::Flatten(
40+
*ctx.Input<framework::Tensor>("Moment"));
41+
auto lr = framework::EigenVector<T>::Flatten(
42+
*ctx.Input<framework::Tensor>("LearningRate"));
43+
44+
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
45+
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
46+
auto place = ctx.GetEigenDevice<Place>();
47+
48+
moment_out.device(place) = decay * moment + (1 - decay) * grad * grad;
49+
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
50+
param_out.device(place) =
51+
param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
52+
}
53+
};
54+
55+
} // namespace operators
56+
} // namespace paddle
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import unittest
2+
import numpy as np
3+
from op_test import OpTest
4+
5+
6+
class TestDecayedAdagradOp1(OpTest):
    '''Test DecayedAdagrad operator with explicit attributes
    '''

    def setUp(self):
        self.op_type = "decayed_adagrad"

        shape = (123, 321)
        param = np.random.random(shape).astype("float32")
        grad = np.random.random(shape).astype("float32")
        moment = np.zeros(shape).astype("float32")
        lr, decay, epsilon = 0.01, 0.80, 1e-8

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment': moment,
            'LearningRate': np.array([lr]).astype("float32")
        }
        self.attrs = {'decay': decay, 'epsilon': epsilon}

        # Reference implementation of the decayed-adagrad update rule.
        expected_moment = decay * moment + (1 - decay) * grad * grad
        expected_param = param - lr * grad / (np.sqrt(expected_moment) +
                                              epsilon)
        self.outputs = {
            'ParamOut': expected_param,
            'MomentOut': expected_moment
        }

    def test_check_output(self):
        self.check_output()
36+
37+
38+
class TestDecayedAdagradOp2(OpTest):
    '''Test DecayedAdagrad operator with default attributes.

    No attrs are passed to the op so that the operator's declared defaults
    (decay=0.95, epsilon=1e-6) are actually exercised; the reference values
    below must stay in sync with DecayedAdagradOpMaker.
    '''

    def setUp(self):
        self.op_type = "decayed_adagrad"

        param = np.random.random((123, 321)).astype("float32")
        grad = np.random.random((123, 321)).astype("float32")
        moment = np.zeros((123, 321)).astype("float32")
        lr = 0.01
        # Defaults declared in DecayedAdagradOpMaker.
        decay = 0.95
        epsilon = 1e-6

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment': moment,
            'LearningRate': np.array([lr]).astype("float32")
        }

        # NOTE(review): the original test explicitly set
        # self.attrs = {'decay': decay, 'epsilon': epsilon}, which meant the
        # operator's default-attribute path was never tested. Omitting attrs
        # makes the op fall back to its declared defaults.

        moment_out = decay * moment + (1 - decay) * grad * grad
        param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)

        self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}

    def test_check_output(self):
        self.check_output()
68+
69+
70+
# Allow running this test file directly (outside the test runner).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)