@@ -23,33 +23,33 @@ class AdagradOp : public framework::OperatorWithKernel {
2323
2424 protected:
2525 void InferShape (framework::InferShapeContextBase *ctx) const override {
26- PADDLE_ENFORCE (ctx->HasInput (" param " ),
27- " Input(param ) of AdagradOp should not be null." );
28- PADDLE_ENFORCE (ctx->HasInput (" grad " ),
29- " Input(grad ) of AdagradOp should not be null." );
30- PADDLE_ENFORCE (ctx->HasInput (" moment " ),
31- " Input(moment ) of AdagradOp should not be null." );
32- PADDLE_ENFORCE (ctx->HasInput (" learning_rate " ),
33- " Input(learning_rate ) of AdagradOp should not be null." );
34-
35- PADDLE_ENFORCE (ctx->HasOutput (" param_out " ),
36- " Output(param_out ) of AdagradOp should not be null." );
37- PADDLE_ENFORCE (ctx->HasOutput (" moment_out " ),
38- " Output(moment_out ) of AdagradOp should not be null." );
39-
40- auto lr_dims = ctx->GetInputDim (" learning_rate " );
26+ PADDLE_ENFORCE (ctx->HasInput (" Param " ),
27+ " Input(Param ) of AdagradOp should not be null." );
28+ PADDLE_ENFORCE (ctx->HasInput (" Grad " ),
29+ " Input(Grad ) of AdagradOp should not be null." );
30+ PADDLE_ENFORCE (ctx->HasInput (" Moment " ),
31+ " Input(Moment ) of AdagradOp should not be null." );
32+ PADDLE_ENFORCE (ctx->HasInput (" LearningRate " ),
33+ " Input(LearningRate ) of AdagradOp should not be null." );
34+
35+ PADDLE_ENFORCE (ctx->HasOutput (" ParamOut " ),
36+ " Output(ParamOut ) of AdagradOp should not be null." );
37+ PADDLE_ENFORCE (ctx->HasOutput (" MomentOut " ),
38+ " Output(MomentOut ) of AdagradOp should not be null." );
39+
40+ auto lr_dims = ctx->GetInputDim (" LearningRate " );
4141 PADDLE_ENFORCE_EQ (framework::product (lr_dims), 1 ,
42- " learning_rate should have one element" );
43- auto param_dim = ctx->GetInputDim (" param " );
42+ " LearningRate should have one element" );
43+ auto param_dims = ctx->GetInputDim (" Param " );
4444 PADDLE_ENFORCE_EQ (
45- param_dim , ctx->GetInputDim (" grad " ),
46- " Param and grad input of AdagradOp should have the same dimension." );
45+ param_dims , ctx->GetInputDim (" Grad " ),
46+ " Param and Grad input of AdagradOp should have the same dimension." );
4747 PADDLE_ENFORCE_EQ (
48- param_dim , ctx->GetInputDim (" moment " ),
49- " Param and moment input of AdagradOp should have the same dimension." );
48+ param_dims , ctx->GetInputDim (" Moment " ),
49+ " Param and Moment input of AdagradOp should have the same dimension." );
5050
51- ctx->SetOutputDim (" param_out " , param_dim );
52- ctx->SetOutputDim (" moment_out " , param_dim );
51+ ctx->SetOutputDim (" ParamOut " , param_dims );
52+ ctx->SetOutputDim (" MomentOut " , param_dims );
5353 }
5454};
5555
@@ -58,15 +58,18 @@ class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
5858 AdagradOpMaker (framework::OpProto *proto,
5959 framework::OpAttrChecker *op_checker)
6060 : OpProtoAndCheckerMaker(proto, op_checker) {
61- AddInput (" param" , " Input parameter" );
62- AddInput (" grad" , " Input gradient" );
63- AddInput (" moment" , " Second moment" );
64- AddInput (" learning_rate" , " learning rate of adagrad" );
65-
66- AddOutput (" param_out" , " Output parameter" );
67- AddOutput (" moment_out" , " Output second moment" );
68-
69- AddAttr<float >(" epsilon" , " Constant for numerical stability" );
61+ AddInput (" Param" , " (Tensor) Input parameter" );
62+ AddInput (" Grad" , " (Tensor) Input gradient" );
63+ AddInput (" Moment" , " (Tensor) Second moment" );
64+ AddInput (" LearningRate" , " (Tensor) Learning rate" );
65+
66+ AddOutput (" ParamOut" , " (Tensor) Output parameter" );
67+ AddOutput (" MomentOut" , " (Tensor) Output second moment" );
68+
69+ AddAttr<float >(" epsilon" ,
70+ " (float, default 1.0e-6) "
71+ " Constant for numerical stability" )
72+ .SetDefault (1 .0e-6f );
7073 AddComment (R"DOC(
7174
7275Adaptive Gradient Algorithm (Adagrad).
0 commit comments