Commit 19e471c

Merge pull request #4 from veyron95/ops_derivative
Ops derivative
2 parents d6e771e + 4109fc5, commit 19e471c

5 files changed: 272 additions & 158 deletions

paddle/fluid/operators/activation_op.cc

Lines changed: 25 additions & 21 deletions
@@ -771,6 +771,10 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
         ctx->ShareDim("Out", "DDOut");
         ctx->ShareLoD("Out", "DDOut");
       }
+      if (ctx->HasOutput("DOutNew")) {
+        ctx->ShareDim("Out", "DOutNew");
+        ctx->ShareLoD("Out", "DOutNew");
+      }
     }
   }

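Note on the new output: DOutNew is the regenerated gradient of Out that the sigmoid double-grad functor (in activation_op.h below) computes elementwise as dOutNew = (1 - 2 * out) * dout * ddx, so it necessarily has Out's shape, and sharing Out's dim and LoD is the right inference.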
@@ -809,12 +813,12 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 };

 template <ActBwdOpFwdDeps kDepValue>
-class ActivationOpTribleGrad : public framework::OperatorWithKernel {
+class ActivationOpTripleGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    VLOG(3) << "=========== in ActivationOpTribleGrad =========";
+    VLOG(3) << "=========== in ActivationOpTripleGrad =========";
     if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
       if (ctx->HasOutput("DX")) {
         ctx->ShareDim("X", "DX");
@@ -874,15 +878,15 @@ class SigmoidDoubleGradMaker
 };

 template <typename T>
-class SigmoidTribleGradMaker
+class SigmoidTripleGradMaker
     : public ::paddle::framework::SingleGradOpMaker<T> {
  public:
   using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  protected:
   void Apply(GradOpPtr<T> op) const override {
-    VLOG(3) << "=========== in SigmoidTribleGradMaker =========";
-    op->SetType("sigmoid_trible_grad");
+    VLOG(3) << "=========== in SigmoidTripleGradMaker =========";
+    op->SetType("sigmoid_triple_grad");
     // Out, DDX, DOut, D_DDOut, D_DOut_New  // input
     // D_OutNew, D_DOut, D_DDx              // output
     // input1: Out
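The hunk is cut off before the Apply body finishes wiring the gradient op. A minimal sketch of what the remaining wiring typically looks like with the SingleGradOpMaker API, matching the input/output comment above (illustrative only; the commit's actual lines are elided here):

    // Illustrative sketch, not the commit's elided code: wire the triple-grad
    // op per the "Out, DDX, DOut, D_DDOut, D_DOut_New -> D_OutNew, D_DOut,
    // D_DDx" comment above.
    op->SetInput("Out", this->Input("Out"));                  // forward output
    op->SetInput("DDX", this->Input("DDX"));                  // double-grad input
    op->SetInput("DOut", this->Input("DOut"));                // first-order grad
    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));       // grad w.r.t. DDOut
    op->SetInput("D_DOut_New", this->OutputGrad("DOutNew"));  // grad w.r.t. DOutNew
    op->SetOutput("D_OutNew", this->InputGrad("Out"));
    op->SetOutput("D_DOut", this->InputGrad("DOut"));
    op->SetOutput("D_DDx", this->InputGrad("DDX"));
    op->SetAttrMap(this->Attrs());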
@@ -1078,7 +1082,7 @@ DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
                            framework::GradVarName("X")});  // dx
 DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
                            {"DDX", "DDOut"});
-DECLARE_INPLACE_OP_INFERER(ActivationTribleGradOpInplaceInferer,
+DECLARE_INPLACE_OP_INFERER(ActivationTripleGradOpInplaceInferer,
                            {"DDX", "D_DOut"});

 template <typename T>
@@ -1209,14 +1213,14 @@ REGISTER_OPERATOR(sigmoid_grad_grad,
                   ops::ActivationOpDoubleGrad<ops::SigmoidGradGradFunctor<
                       float>::FwdDeps()>,  // should be SigmoidGradGradFunctor
                   ops::ActivationDoubleGradOpInplaceInferer,
-                  ops::SigmoidTribleGradMaker<paddle::framework::OpDesc>,
-                  ops::SigmoidTribleGradMaker<paddle::imperative::OpBase>);
+                  ops::SigmoidTripleGradMaker<paddle::framework::OpDesc>,
+                  ops::SigmoidTripleGradMaker<paddle::imperative::OpBase>);

-// 4. Register Sigmoid TribleGrad Operator
-REGISTER_OPERATOR(sigmoid_trible_grad,
-                  ops::ActivationOpTribleGrad<
-                      ops::SigmoidTribleGradFunctor<float>::FwdDeps()>,
-                  ops::ActivationTribleGradOpInplaceInferer);
+// 4. Register Sigmoid TripleGrad Operator
+REGISTER_OPERATOR(sigmoid_triple_grad,
+                  ops::ActivationOpTripleGrad<
+                      ops::SigmoidTripleGradFunctor<float>::FwdDeps()>,
+                  ops::ActivationTripleGradOpInplaceInferer);

 // Register Sigmoid/GradSigmoid Kernels
 REGISTER_ACTIVATION_CPU_KERNEL(sigmoid, Sigmoid, SigmoidFunctor,
@@ -1232,15 +1236,15 @@ REGISTER_OP_CPU_KERNEL(
     ops::SigmoidDoubleGradKernel<plat::CPUDeviceContext,
                                  ops::SigmoidGradGradFunctor<plat::float16>>);

-// Register TribleGrad Kernel
+// Register TripleGrad Kernel
 REGISTER_OP_CPU_KERNEL(
-    sigmoid_trible_grad,
-    ops::SigmoidTribleGradKernel<plat::CPUDeviceContext,
-                                 ops::SigmoidTribleGradFunctor<float>>,
-    ops::SigmoidTribleGradKernel<plat::CPUDeviceContext,
-                                 ops::SigmoidTribleGradFunctor<double>>,
-    ops::SigmoidTribleGradKernel<plat::CPUDeviceContext,
-                                 ops::SigmoidTribleGradFunctor<plat::float16>>);
+    sigmoid_triple_grad,
+    ops::SigmoidTripleGradKernel<plat::CPUDeviceContext,
+                                 ops::SigmoidTripleGradFunctor<float>>,
+    ops::SigmoidTripleGradKernel<plat::CPUDeviceContext,
+                                 ops::SigmoidTripleGradFunctor<double>>,
+    ops::SigmoidTripleGradKernel<plat::CPUDeviceContext,
+                                 ops::SigmoidTripleGradFunctor<plat::float16>>);

 /* ========================================================================== */

paddle/fluid/operators/activation_op.cu

Lines changed: 7 additions & 7 deletions
@@ -1400,13 +1400,13 @@ REGISTER_OP_CUDA_KERNEL(
                             ops::SigmoidGradGradFunctor<plat::float16>>);

 REGISTER_OP_CUDA_KERNEL(
-    sigmoid_trible_grad,
-    ops::SigmoidTribleGradKernel<paddle::platform::CUDADeviceContext,
-                                 ops::SigmoidTribleGradFunctor<float>>,
-    ops::SigmoidTribleGradKernel<paddle::platform::CUDADeviceContext,
-                                 ops::SigmoidTribleGradFunctor<double>>,
-    ops::SigmoidTribleGradKernel<plat::CUDADeviceContext,
-                                 ops::SigmoidTribleGradFunctor<plat::float16>>);
+    sigmoid_triple_grad,
+    ops::SigmoidTripleGradKernel<paddle::platform::CUDADeviceContext,
+                                 ops::SigmoidTripleGradFunctor<float>>,
+    ops::SigmoidTripleGradKernel<paddle::platform::CUDADeviceContext,
+                                 ops::SigmoidTripleGradFunctor<double>>,
+    ops::SigmoidTripleGradKernel<plat::CUDADeviceContext,
+                                 ops::SigmoidTripleGradFunctor<plat::float16>>);
 /* ========================================================================== */

 /* =========================== tanh register ============================ */

paddle/fluid/operators/activation_op.h

Lines changed: 53 additions & 48 deletions
@@ -24,29 +24,29 @@ limitations under the License. */
 #define _USE_MATH_DEFINES
 #endif

+#include <type_traits>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
-
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif

 namespace paddle {
 namespace operators {

-template <typename T>
-void PrintTensor(const framework::Tensor& src,
-                 const framework::ExecutionContext& ctx) {
-  std::vector<T> vec(src.numel());
-  TensorToVector(src, ctx.device_context(), &vec);
-  for (int i = 0; i < static_cast<int>(vec.size()); ++i) {
-    VLOG(3) << "vec[" << i << "] : " << vec[i];
-  }
-}
+// template <typename T>
+// void PrintTensor(const framework::Tensor& src,
+//                  const framework::ExecutionContext& ctx) {
+//   std::vector<T> vec(src.numel());
+//   TensorToVector(src, ctx.device_context(), &vec);
+//   for (int i = 0; i < static_cast<int>(vec.size()); ++i) {
+//     VLOG(3) << "vec[" << i << "] : " << vec[i];
+//   }
+// }

 using framework::To32BitIndex;
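An alternative to commenting the debug helper out entirely would be to keep it compiled but gate the device-to-host copy on the logging level. A minimal sketch, reusing the TensorToVector call from the original helper (DebugPrintTensor is a hypothetical name, not part of this commit):

template <typename T>
void DebugPrintTensor(const framework::Tensor& src,
                      const framework::ExecutionContext& ctx) {
  if (!VLOG_IS_ON(3)) return;  // skip the copy unless -v=3 logging is enabled
  std::vector<T> vec(src.numel());
  TensorToVector(src, ctx.device_context(), &vec);
  for (size_t i = 0; i < vec.size(); ++i) {
    VLOG(3) << "vec[" << i << "] : " << vec[i];
  }
}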

@@ -288,8 +288,10 @@ struct SigmoidGradGradFunctor : public BaseActivationFunctor<T> {
                   const framework::Tensor* ddX, const framework::Tensor* dOut,
                   framework::Tensor* dOutNew, framework::Tensor* ddOut) const {
     VLOG(3) << "=========== in SigmoidGradGradFunctor =========";
-    auto* d = dev.eigen_device();

+    VLOG(3) << " === is double " << std::is_same<T, double>::value;
+    VLOG(3) << " === is float " << std::is_same<T, float>::value;
+    auto* d = dev.eigen_device();
     auto ddx = framework::EigenVector<T>::Flatten(
         GET_DATA_SAFELY(ddX, "Input", "DDX", "SigmoidGradGrad"));
     auto out = framework::EigenVector<T>::Flatten(
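For reference, the math this functor implements: sigmoid_grad computes $\mathrm{dX} = \mathrm{dOut}\cdot y(1-y)$ from $(y, \mathrm{dOut})$ with $y = \mathrm{Out}$. The double-grad op receives $\mathrm{ddX}$, the gradient flowing into $\mathrm{dX}$, and back-propagates it to both factors of that product:

$$\mathrm{ddOut} = y(1-y)\,\mathrm{ddX}, \qquad \mathrm{dOutNew} = (1-2y)\,\mathrm{dOut}\,\mathrm{ddX},$$

which matches the two assignments in the remainder of this functor (elided from the hunk).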
@@ -315,7 +317,7 @@ struct SigmoidGradGradFunctor : public BaseActivationFunctor<T> {
 /*
     Out
     DOut                       D_Dout
-    DDx   -> SigmoidTribleGrad -> D_DDx
+    DDx   -> SigmoidTripleGrad -> D_DDx
     D_DDout                    d_OutNew
     D_Dout_new
@@ -330,45 +332,45 @@ struct SigmoidGradGradFunctor : public BaseActivationFunctor<T> {
     // D_OutNew, D_DOut, D_DDx // output
 */
 template <typename T>
-struct SigmoidTribleGradFunctor : public BaseActivationFunctor<T> {
+struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device>
   void operator()(const Device& dev, const framework::Tensor* Out,
                   const framework::Tensor* ddX, const framework::Tensor* dOut,
                   const framework::Tensor* d_DDOut,
                   const framework::Tensor* d_dOut_New,
                   framework::Tensor* d_d_Out, framework::Tensor* d_Out_New,
                   framework::Tensor* d_DDx) const {
-    VLOG(3) << "=========== in SigmoidTribleGradFunctor =========";
+    VLOG(3) << "=========== in SigmoidTripleGradFunctor =========";
     auto* d = dev.eigen_device();
     auto ddx = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(ddX, "Input", "DDX", "SigmoidTribleGrad"));
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "SigmoidTripleGrad"));
     auto out = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(Out, "Input", "Out", "SigmoidTribleGrad"));
+        GET_DATA_SAFELY(Out, "Input", "Out", "SigmoidTripleGrad"));
     auto dout = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(dOut, "Input", "DOut", "SigmoidTribleGrad"));
+        GET_DATA_SAFELY(dOut, "Input", "DOut", "SigmoidTripleGrad"));
     auto d_ddOut = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "SigmoidTribleGrad"));
+        GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "SigmoidTripleGrad"));
     auto d_dOutNew = framework::EigenVector<T>::Flatten(GET_DATA_SAFELY(
-        d_dOut_New, "Input", "D_DOut_New", "SigmoidTribleGrad"));
+        d_dOut_New, "Input", "D_DOut_New", "SigmoidTripleGrad"));

     if (d_Out_New) {
       VLOG(3) << " ========== in if (d_Out_New) { ==========";
       auto d_OutNew = framework::EigenVector<T>::Flatten(GET_DATA_SAFELY(
-          d_Out_New, "Output", "D_OutNew", "SigmoidTribleGrad"));
+          d_Out_New, "Output", "D_OutNew", "SigmoidTripleGrad"));
       d_OutNew.device(*d) = (ddx - static_cast<T>(2) * out * ddx) * d_ddOut -
                             static_cast<T>(2) * dout * ddx * d_dOutNew;
     }
     if (d_d_Out) {
       VLOG(3) << " ========== in if (d_d_Out) { ==========";
       auto d_dOut = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "SigmoidTribleGrad"));
+          GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "SigmoidTripleGrad"));
       d_dOut.device(*d) =
           (static_cast<T>(1) - static_cast<T>(2) * out) * ddx * d_dOutNew;
     }
     if (d_DDx) {
       VLOG(3) << " ========== in if (d_DDx) { ==========";
       auto d_ddx = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "SigmoidTribleGrad"));
+          GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "SigmoidTripleGrad"));
       d_ddx.device(*d) =
           (static_cast<T>(1) - out) * out * d_ddOut +
           (static_cast<T>(1) - static_cast<T>(2) * out) * dout * d_dOutNew;
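Differentiating the two double-grad outputs once more yields the three expressions in this functor. With $y = \mathrm{Out}$ and incoming gradients $\mathrm{d\_ddOut}$ and $\mathrm{d\_dOutNew}$, taking $\partial/\partial y$, $\partial/\partial\,\mathrm{dOut}$, and $\partial/\partial\,\mathrm{ddX}$ of $\mathrm{ddOut} = y(1-y)\,\mathrm{ddX}$ and $\mathrm{dOutNew} = (1-2y)\,\mathrm{dOut}\,\mathrm{ddX}$ gives

$$
\begin{aligned}
\mathrm{D\_OutNew} &= (1-2y)\,\mathrm{ddX}\cdot\mathrm{d\_ddOut} - 2\,\mathrm{dOut}\,\mathrm{ddX}\cdot\mathrm{d\_dOutNew},\\
\mathrm{D\_DOut} &= (1-2y)\,\mathrm{ddX}\cdot\mathrm{d\_dOutNew},\\
\mathrm{D\_DDx} &= y(1-y)\cdot\mathrm{d\_ddOut} + (1-2y)\,\mathrm{dOut}\cdot\mathrm{d\_dOutNew}.
\end{aligned}
$$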
@@ -1941,6 +1943,9 @@ class SigmoidDoubleGradKernel
     Out = ddX = dOut = nullptr;
     dOutNew = ddOut = nullptr;

+    VLOG(3) << " === is double " << std::is_same<T, double>::value;
+    VLOG(3) << " === is float " << std::is_same<T, float>::value;
+
     // extract ddx(input) and out(input)
     ddX = ctx.Input<framework::Tensor>("DDX");
     Out = ctx.Input<framework::Tensor>("Out");
@@ -1963,17 +1968,17 @@ class SigmoidDoubleGradKernel
                  "Cannot get input Variable dOut, variable name = %s",
                  ctx.InputName("DOut")));

-    VLOG(3) << "================ dOut ===========";
-    PrintTensor<T>(*dOut, ctx);
-    VLOG(3) << "================ dOut ===========";
+    // VLOG(3) << "================ dOut ===========";
+    // PrintTensor<T>(*dOut, ctx);
+    // VLOG(3) << "================ dOut ===========";

-    VLOG(3) << "================ ddX ===========";
-    PrintTensor<T>(*ddX, ctx);
-    VLOG(3) << "================ ddX ===========";
+    // VLOG(3) << "================ ddX ===========";
+    // PrintTensor<T>(*ddX, ctx);
+    // VLOG(3) << "================ ddX ===========";

-    VLOG(3) << "================ Out ===========";
-    PrintTensor<T>(*Out, ctx);
-    VLOG(3) << "================ Out ===========";
+    // VLOG(3) << "================ Out ===========";
+    // PrintTensor<T>(*Out, ctx);
+    // VLOG(3) << "================ Out ===========";
     // set output dout_new
     dOutNew = ctx.Output<framework::Tensor>("DOutNew");

@@ -1988,12 +1993,12 @@ class SigmoidDoubleGradKernel
 // Out, DDX, DOut, D_DDOut, D_DOut_New  // input
 // D_OutNew, D_DOut, D_DDx              // output
 template <typename DeviceContext, typename Functor>
-class SigmoidTribleGradKernel
+class SigmoidTripleGradKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
  public:
   using T = typename Functor::ELEMENT_TYPE;
   void Compute(const framework::ExecutionContext& ctx) const override {
-    VLOG(3) << "=========== in SigmoidTribleGradKernel =========";
+    VLOG(3) << "=========== in SigmoidTripleGradKernel =========";
     const framework::Tensor *Out, *ddX, *dOut, *d_ddOut, *d_dOutNew;
     framework::Tensor *d_OutNew, *d_dOut, *d_ddx;
     Out = ddX = dOut = d_ddOut = d_dOutNew = nullptr;
@@ -2007,25 +2012,25 @@ class SigmoidTribleGradKernel
     d_ddOut = ctx.Input<framework::Tensor>("D_DDOut");
     d_dOutNew = ctx.Input<framework::Tensor>("D_DOut_New");

-    VLOG(3) << "================ ddx ===========";
-    PrintTensor<T>(*ddX, ctx);
-    VLOG(3) << "================ ddx ===========";
+    // VLOG(3) << "================ ddx ===========";
+    // PrintTensor<T>(*ddX, ctx);
+    // VLOG(3) << "================ ddx ===========";

-    VLOG(3) << "================ Out ===========";
-    PrintTensor<T>(*Out, ctx);
-    VLOG(3) << "================ Out ===========";
+    // VLOG(3) << "================ Out ===========";
+    // PrintTensor<T>(*Out, ctx);
+    // VLOG(3) << "================ Out ===========";

-    VLOG(3) << "================ dOut ===========";
-    PrintTensor<T>(*dOut, ctx);
-    VLOG(3) << "================ dOut ===========";
+    // VLOG(3) << "================ dOut ===========";
+    // PrintTensor<T>(*dOut, ctx);
+    // VLOG(3) << "================ dOut ===========";

-    VLOG(3) << "================ d_ddOut ===========";
-    PrintTensor<T>(*d_ddOut, ctx);
-    VLOG(3) << "================ d_ddOut ===========";
+    // VLOG(3) << "================ d_ddOut ===========";
+    // PrintTensor<T>(*d_ddOut, ctx);
+    // VLOG(3) << "================ d_ddOut ===========";

-    VLOG(3) << "================ d_dOutNew ===========";
-    PrintTensor<T>(*d_dOutNew, ctx);
-    VLOG(3) << "================ d_dOutNew ===========";
+    // VLOG(3) << "================ d_dOutNew ===========";
+    // PrintTensor<T>(*d_dOutNew, ctx);
+    // VLOG(3) << "================ d_dOutNew ===========";

     PADDLE_ENFORCE_NOT_NULL(
         ddX, platform::errors::NotFound(
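As a sanity check on the whole chain, the closed-form derivatives of $y = \sigma(x)$ that sigmoid, sigmoid_grad, sigmoid_grad_grad, and sigmoid_triple_grad collectively realize are

$$
\frac{dy}{dx} = y(1-y),\qquad
\frac{d^2y}{dx^2} = y(1-y)(1-2y),\qquad
\frac{d^3y}{dx^3} = y(1-y)(1-6y+6y^2),
$$

each expressible in terms of the forward output alone, which is why every op in the chain depends only on Out rather than X (the FwdDeps() used at registration).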
