Skip to content

Commit 6843edc

Browse files
committed
Add adadelta kernel
1 parent 953252e commit 6843edc

File tree

9 files changed

+201
-179
lines changed

9 files changed

+201
-179
lines changed

paddle/fluid/operators/optimizers/adadelta_op.cc

Lines changed: 12 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/optimizers/adadelta_op.h"
15+
#include "paddle/fluid/framework/infershape_utils.h"
16+
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/phi/core/infermeta_utils.h"
18+
#include "paddle/phi/infermeta/multiary.h"
1619

1720
namespace paddle {
1821
namespace operators {
@@ -23,77 +26,6 @@ class AdadeltaOp : public framework::OperatorWithKernel {
2326
public:
2427
using framework::OperatorWithKernel::OperatorWithKernel;
2528

26-
void InferShape(framework::InferShapeContext *ctx) const override {
27-
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
28-
platform::errors::InvalidArgument(
29-
"Input(Param) of AdadeltaOp should not be null."));
30-
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
31-
platform::errors::InvalidArgument(
32-
"Input(Grad) of AdadeltaOp should not be null."));
33-
PADDLE_ENFORCE_EQ(
34-
ctx->HasInput("AvgSquaredGrad"), true,
35-
platform::errors::InvalidArgument(
36-
"Input(AvgSquaredGrad) of AdadeltaOp should not be null."));
37-
PADDLE_ENFORCE_EQ(
38-
ctx->HasInput("AvgSquaredUpdate"), true,
39-
platform::errors::InvalidArgument(
40-
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null."));
41-
PADDLE_ENFORCE_EQ(
42-
ctx->GetInputsVarType("Param").front() ==
43-
framework::proto::VarType::LOD_TENSOR,
44-
true,
45-
platform::errors::InvalidArgument(
46-
"The input var's type should be LoDTensor, but the received is %s",
47-
ctx->Inputs("Param").front(),
48-
ctx->GetInputsVarType("Param").front()));
49-
PADDLE_ENFORCE_EQ(
50-
ctx->GetInputsVarType("Grad").front() ==
51-
framework::proto::VarType::LOD_TENSOR,
52-
true,
53-
platform::errors::InvalidArgument(
54-
"The input var's type should be LoDTensor, but the received is %s",
55-
ctx->Inputs("Grad").front(),
56-
ctx->GetInputsVarType("Grad").front()));
57-
58-
PADDLE_ENFORCE_EQ(
59-
ctx->HasOutput("ParamOut"), true,
60-
platform::errors::InvalidArgument(
61-
"Output(ParamOut) of AdadeltaOp should not be null."));
62-
PADDLE_ENFORCE_EQ(
63-
ctx->HasOutput("AvgSquaredGradOut"), true,
64-
platform::errors::InvalidArgument(
65-
"Output(AvgSquaredGradOut) of AdadeltaOp should not be null."));
66-
PADDLE_ENFORCE_EQ(
67-
ctx->HasOutput("AvgSquaredUpdateOut"), true,
68-
platform::errors::InvalidArgument(
69-
"Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null."));
70-
71-
auto param_dim = ctx->GetInputDim("Param");
72-
PADDLE_ENFORCE_EQ(
73-
param_dim, ctx->GetInputDim("Grad"),
74-
platform::errors::InvalidArgument(
75-
"Param and grad input of AdadeltaOp should have same dimension."));
76-
PADDLE_ENFORCE_NE(
77-
phi::product(ctx->GetInputDim("AvgSquaredGrad")), 0,
78-
platform::errors::InvalidArgument(
79-
"Maybe the Input variable AvgSquaredGrad has not "
80-
"been initialized. You may need to confirm if you put "
81-
"exe.run(startup_program) after optimizer.minimize "
82-
"function."));
83-
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredGrad"),
84-
platform::errors::InvalidArgument(
85-
"Param and AvgSquaredGrad input of AdadeltaOp "
86-
"should have same dimension"));
87-
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredUpdate"),
88-
platform::errors::InvalidArgument(
89-
"Param and AvgSquaredUpdate input of AdadeltaOp "
90-
"should have same dimension"));
91-
92-
ctx->SetOutputDim("ParamOut", param_dim);
93-
ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
94-
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
95-
}
96-
9729
framework::OpKernelType GetExpectedKernelType(
9830
const framework::ExecutionContext &ctx) const override {
9931
return framework::OpKernelType(
@@ -149,7 +81,11 @@ param\_out = param + param\_update
14981
} // namespace paddle
15082

15183
namespace ops = paddle::operators;
152-
REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker);
153-
REGISTER_OP_CPU_KERNEL(
154-
adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, float>,
155-
ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, double>);
84+
namespace ops = paddle::operators;
85+
DELCARE_INFER_SHAPE_FUNCTOR(adadelta, AdadeltaInferMetaFunctor,
86+
PT_INFER_META(phi::AdadeltaInferMeta));
87+
REGISTER_OPERATOR(
88+
adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker,
89+
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
90+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
91+
AdadeltaInferMetaFunctor);

paddle/fluid/operators/optimizers/adadelta_op.cu

Lines changed: 0 additions & 19 deletions
This file was deleted.

paddle/fluid/operators/optimizers/adadelta_op.h

Lines changed: 0 additions & 84 deletions
This file was deleted.

paddle/phi/infermeta/multiary.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,42 @@ void AdamaxInferMeta(const MetaTensor& param,
8484
inf_norm_out->set_dtype(inf_norm.dtype());
8585
}
8686

87+
void AdadeltaInferMeta(const MetaTensor& param,
88+
const MetaTensor& grad,
89+
const MetaTensor& avg_squared_grad,
90+
const MetaTensor& avg_squared_update,
91+
float rho,
92+
float epsilon,
93+
MetaTensor* param_out,
94+
MetaTensor* avg_squared_grad_out,
95+
MetaTensor* avg_squared_update_out) {
96+
auto param_dims = param.dims();
97+
PADDLE_ENFORCE_EQ(
98+
param_dims,
99+
grad.dims(),
100+
errors::InvalidArgument(
101+
"Param and grad input of AdadeltaOp should have same dimension."));
102+
PADDLE_ENFORCE_EQ(
103+
param_dims,
104+
avg_squared_grad.dims(),
105+
errors::InvalidArgument("Param and AvgSquaredGrad input of AdadeltaOp "
106+
"should have same dimension"));
107+
PADDLE_ENFORCE_EQ(
108+
param_dims,
109+
avg_squared_update.dims(),
110+
errors::InvalidArgument("Param and AvgSquaredUpdate input of AdadeltaOp "
111+
"should have same dimension"));
112+
113+
param_out->set_dims(param_dims);
114+
param_out->set_dtype(param.dtype());
115+
116+
avg_squared_grad_out->set_dims(param_dims);
117+
avg_squared_grad_out->set_dtype(avg_squared_grad.dtype());
118+
119+
avg_squared_update_out->set_dims(param_dims);
120+
avg_squared_update_out->set_dtype(avg_squared_update.dtype());
121+
}
122+
87123
void BilinearTensorProductInferMeta(const MetaTensor& x,
88124
const MetaTensor& y,
89125
const MetaTensor& weight,

paddle/phi/infermeta/multiary.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,15 @@ void AdamaxInferMeta(const MetaTensor& param,
5252
MetaTensor* param_out,
5353
MetaTensor* moment_out,
5454
MetaTensor* inf_norm_out);
55+
56+
void AdadeltaInferMeta(const MetaTensor& param,
57+
const MetaTensor& grad,
58+
const MetaTensor& avg_squared_grad,
59+
const MetaTensor& avg_squared_update,
60+
float rho,
61+
float epsilon,
62+
MetaTensor* param_out,
63+
MetaTensor* avg_squared_grad_out,
64+
MetaTensor* avg_squared_update_out);
65+
5566
} // namespace phi
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/phi/core/dense_tensor.h"
18+
19+
namespace phi {
20+
21+
template <typename T, typename Context>
22+
void AdadeltaKernel(const Context& dev_ctx,
23+
const DenseTensor& param,
24+
const DenseTensor& grad,
25+
const DenseTensor& avg_squared_grad,
26+
const DenseTensor& avg_squared_update,
27+
float rho,
28+
float epsilon,
29+
DenseTensor* param_out,
30+
DenseTensor* avg_squared_grad_out,
31+
DenseTensor* avg_squared_update_out);
32+
33+
} // namespace phi
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/adadelta_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/adadelta_kernel_impl.h"

// Register the shared Adadelta implementation as the CPU kernel for the
// "adadelta" op, instantiated for float and double.
PD_REGISTER_KERNEL(
    adadelta, CPU, ALL_LAYOUT, phi::AdadeltaKernel, float, double) {}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/adadelta_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/adadelta_kernel_impl.h"

// Register the shared Adadelta implementation as the GPU kernel for the
// "adadelta" op, instantiated for float and double.
PD_REGISTER_KERNEL(
    adadelta, GPU, ALL_LAYOUT, phi::AdadeltaKernel, float, double) {}

0 commit comments

Comments
 (0)