
Commit 087adda

Merge pull request #4558 from kexinzhao/adagrad_op
Implementing the Adagrad optimizer step operator
2 parents 48f98a6 + 78f4c80 commit 087adda

File tree

4 files changed: +237 additions, -0 deletions


paddle/operators/adagrad_op.cc

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/adagrad_op.h"

namespace paddle {
namespace operators {

class AdagradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContextBase *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Param"),
                   "Input(Param) of AdagradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Grad"),
                   "Input(Grad) of AdagradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Moment"),
                   "Input(Moment) of AdagradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
                   "Input(LearningRate) of AdagradOp should not be null.");

    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                   "Output(ParamOut) of AdagradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
                   "Output(MomentOut) of AdagradOp should not be null.");

    auto lr_dims = ctx->GetInputDim("LearningRate");
    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      "LearningRate should have one element");
    auto param_dims = ctx->GetInputDim("Param");
    PADDLE_ENFORCE_EQ(
        param_dims, ctx->GetInputDim("Grad"),
        "Param and Grad input of AdagradOp should have the same dimension.");
    PADDLE_ENFORCE_EQ(
        param_dims, ctx->GetInputDim("Moment"),
        "Param and Moment input of AdagradOp should have the same dimension.");

    ctx->SetOutputDim("ParamOut", param_dims);
    ctx->SetOutputDim("MomentOut", param_dims);
  }
};

class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  AdagradOpMaker(framework::OpProto *proto,
                 framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Param", "(Tensor) Input parameter");
    AddInput("Grad", "(Tensor) Input gradient");
    AddInput("Moment", "(Tensor) Second moment");
    AddInput("LearningRate", "(Tensor) Learning rate");

    AddOutput("ParamOut", "(Tensor) Output parameter");
    AddOutput("MomentOut", "(Tensor) Output second moment");

    AddAttr<float>("epsilon",
                   "(float, default 1.0e-6) "
                   "Constant for numerical stability")
        .SetDefault(1.0e-6f);
    AddComment(R"DOC(

Adaptive Gradient Algorithm (Adagrad).

moment_out = moment + grad * grad
param_out = param - learning_rate * grad / (sqrt(moment_out) + epsilon)

The original paper(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
does not have the epsilon attribute. It is added here for numerical stability
by avoiding division by zero.

)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker);
REGISTER_OP_CPU_KERNEL(adagrad,
                       ops::AdagradOpKernel<paddle::platform::CPUPlace, float>);
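
To make the update rule in the DOC comment concrete, here is a small NumPy sketch of a single Adagrad step on a toy two-element parameter (illustrative values only, not part of the commit; it simply evaluates the two formulas above):

import numpy as np

# One Adagrad step on toy values, following the DOC formulas.
param = np.array([1.0, 2.0], dtype=np.float32)
grad = np.array([0.5, -0.2], dtype=np.float32)
moment = np.zeros_like(param)      # accumulated squared gradients, starts at zero
lr, epsilon = 0.01, 1.0e-6         # epsilon matches the attribute's default

moment_out = moment + grad * grad
param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
# moment_out is approximately [0.25, 0.04]; param_out is approximately [0.99, 2.01]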

paddle/operators/adagrad_op.cu

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/adagrad_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adagrad,
                       ops::AdagradOpKernel<paddle::platform::GPUPlace, float>);

paddle/operators/adagrad_op.h

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class AdagradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
    auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");

    param_out_tensor->mutable_data<T>(ctx.GetPlace());
    moment_out_tensor->mutable_data<T>(ctx.GetPlace());

    float epsilon = ctx.Attr<float>("epsilon");

    auto param = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Param"));
    auto grad = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Grad"));
    auto moment = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Moment"));
    auto lr = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("LearningRate"));

    auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
    auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
    auto place = ctx.GetEigenDevice<Place>();

    moment_out.device(place) = moment + grad * grad;
    Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
    param_out.device(place) =
        param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
  }
};

}  // namespace operators
}  // namespace paddle
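
For reference, a rough NumPy equivalent of the kernel body above, assuming the inputs are treated as flat 1-D arrays the way EigenVector<T>::Flatten treats them (a sketch only; the helper name adagrad_step is not from the commit):

import numpy as np

def adagrad_step(param, grad, moment, lr, epsilon=1.0e-6):
    # Mirrors AdagradOpKernel::Compute on flattened arrays (sketch).
    param, grad, moment = param.ravel(), grad.ravel(), moment.ravel()
    # LearningRate is a one-element tensor; broadcasting it across the
    # parameter shape mirrors lr.broadcast(m_dsize) in the Eigen expression.
    lr = np.broadcast_to(np.ravel(lr), param.shape)
    moment_out = moment + grad * grad
    param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
    return param_out, moment_out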
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
import unittest
import numpy as np
from op_test import OpTest


class TestAdagradOp1(OpTest):
    ''' Test Adagrad operator with explicit attributes
    '''

    def setUp(self):
        self.op_type = "adagrad"

        param = np.random.random((123, 321)).astype("float32")
        grad = np.random.random((123, 321)).astype("float32")
        moment = np.zeros((123, 321)).astype("float32")
        lr = 0.01
        epsilon = 1e-8

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment': moment,
            'LearningRate': np.array([lr]).astype("float32")
        }

        self.attrs = {'epsilon': epsilon}

        moment_out = moment + grad * grad
        param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)

        self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}

    def test_check_output(self):
        self.check_output()


class TestAdagradOp2(OpTest):
    ''' Test Adagrad operator with default attributes
    '''

    def setUp(self):
        self.op_type = "adagrad"

        param = np.random.random((123, 321)).astype("float32")
        grad = np.random.random((123, 321)).astype("float32")
        moment = np.zeros((123, 321)).astype("float32")
        lr = 0.01
        epsilon = 1e-6

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment': moment,
            'LearningRate': np.array([lr]).astype("float32")
        }

        self.attrs = {'epsilon': epsilon}

        moment_out = moment + grad * grad
        param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)

        self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
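
Both tests start from a zero moment and check a single step. Iterating the same reference formulas shows Adagrad's characteristic behaviour: the accumulated moment grows, so the effective step size shrinks over time (a standalone sketch, not part of the test file):

import numpy as np

param, moment = np.float32(1.0), np.float32(0.0)
lr, epsilon = 0.01, 1e-6
for step in range(3):
    grad = np.float32(0.5)                  # constant toy gradient
    moment = moment + grad * grad           # 0.25, 0.50, 0.75
    update = lr * grad / (np.sqrt(moment) + epsilon)
    param = param - update
    print(step, update)                     # the update shrinks each step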
