@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License. */
1414
15- #include "paddle/fluid/operators/optimizers/adadelta_op.h"
15+ #include "paddle/fluid/framework/infershape_utils.h"
16+ #include "paddle/fluid/framework/op_registry.h"
17+ #include "paddle/phi/core/infermeta_utils.h"
18+ #include "paddle/phi/infermeta/multiary.h"
1619
1720namespace paddle {
1821namespace operators {
@@ -23,77 +26,6 @@ class AdadeltaOp : public framework::OperatorWithKernel {
2326 public:
2427 using framework::OperatorWithKernel::OperatorWithKernel;
2528
26-   void InferShape(framework::InferShapeContext *ctx) const override {
27-     PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
28-                       platform::errors::InvalidArgument(
29-                           "Input(Param) of AdadeltaOp should not be null."));
30-     PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
31-                       platform::errors::InvalidArgument(
32-                           "Input(Grad) of AdadeltaOp should not be null."));
33-     PADDLE_ENFORCE_EQ(
34-         ctx->HasInput("AvgSquaredGrad"), true,
35-         platform::errors::InvalidArgument(
36-             "Input(AvgSquaredGrad) of AdadeltaOp should not be null."));
37-     PADDLE_ENFORCE_EQ(
38-         ctx->HasInput("AvgSquaredUpdate"), true,
39-         platform::errors::InvalidArgument(
40-             "Input(AvgSquaredUpdate) of AdadeltaOp should not be null."));
41-     PADDLE_ENFORCE_EQ(
42-         ctx->GetInputsVarType("Param").front() ==
43-             framework::proto::VarType::LOD_TENSOR,
44-         true,
45-         platform::errors::InvalidArgument(
46-             "The input var's type should be LoDTensor, but the received is %s",
47-             ctx->Inputs("Param").front(),
48-             ctx->GetInputsVarType("Param").front()));
49-     PADDLE_ENFORCE_EQ(
50-         ctx->GetInputsVarType("Grad").front() ==
51-             framework::proto::VarType::LOD_TENSOR,
52-         true,
53-         platform::errors::InvalidArgument(
54-             "The input var's type should be LoDTensor, but the received is %s",
55-             ctx->Inputs("Grad").front(),
56-             ctx->GetInputsVarType("Grad").front()));
57-
58-     PADDLE_ENFORCE_EQ(
59-         ctx->HasOutput("ParamOut"), true,
60-         platform::errors::InvalidArgument(
61-             "Output(ParamOut) of AdadeltaOp should not be null."));
62-     PADDLE_ENFORCE_EQ(
63-         ctx->HasOutput("AvgSquaredGradOut"), true,
64-         platform::errors::InvalidArgument(
65-             "Output(AvgSquaredGradOut) of AdadeltaOp should not be null."));
66-     PADDLE_ENFORCE_EQ(
67-         ctx->HasOutput("AvgSquaredUpdateOut"), true,
68-         platform::errors::InvalidArgument(
69-             "Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null."));
70-
71-     auto param_dim = ctx->GetInputDim("Param");
72-     PADDLE_ENFORCE_EQ(
73-         param_dim, ctx->GetInputDim("Grad"),
74-         platform::errors::InvalidArgument(
75-             "Param and grad input of AdadeltaOp should have same dimension."));
76-     PADDLE_ENFORCE_NE(
77-         phi::product(ctx->GetInputDim("AvgSquaredGrad")), 0,
78-         platform::errors::InvalidArgument(
79-             "Maybe the Input variable AvgSquaredGrad has not "
80-             "been initialized. You may need to confirm if you put "
81-             "exe.run(startup_program) after optimizer.minimize "
82-             "function."));
83-     PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredGrad"),
84-                       platform::errors::InvalidArgument(
85-                           "Param and AvgSquaredGrad input of AdadeltaOp "
86-                           "should have same dimension"));
87-     PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredUpdate"),
88-                       platform::errors::InvalidArgument(
89-                           "Param and AvgSquaredUpdate input of AdadeltaOp "
90-                           "should have same dimension"));
91-
92-     ctx->SetOutputDim("ParamOut", param_dim);
93-     ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
94-     ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
95-   }
96-
96-
9729 framework::OpKernelType GetExpectedKernelType (
9830 const framework::ExecutionContext &ctx) const override {
9931 return framework::OpKernelType (
@@ -149,7 +81,10 @@ param\_out = param + param\_update
14981} // namespace paddle
15082
15183namespace ops = paddle::operators;
152- REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker);
153- REGISTER_OP_CPU_KERNEL(
154-     adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, float>,
155-     ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, double>);
84+ DELCARE_INFER_SHAPE_FUNCTOR(adadelta, AdadeltaInferMetaFunctor,
85+                             PT_INFER_META(phi::AdadeltaInferMeta));
86+ REGISTER_OPERATOR(
87+     adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker,
88+     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
89+     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
90+     AdadeltaInferMetaFunctor);