67 changes: 67 additions & 0 deletions paddle/fluid/operators/activation_op.cc
@@ -597,10 +597,57 @@ REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasOutput("DOut")) {
ctx->ShareDim("Out", "DOut");
ctx->ShareLoD("Out", "DOut");
}
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("Out", "DDOut");
ctx->ShareLoD("Out", "DDOut");
}
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "Out");
}
};

//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
public:
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

protected:
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
auto* op = new ::paddle::framework::OpDesc();
op->SetType("relu_grad_grad");
// input1: Out
op->SetInput("Out", Input("Out"));
// X@GRAD@GRAD: ddx
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(Attrs());
// Out@GRAD: dout
op->SetOutput("DOut", InputGrad("Out"));
// Out@GRAD@GRAD: ddy
op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<::paddle::framework::OpDesc>(op);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
REGISTER_OPERATOR( \
@@ -632,3 +679,23 @@ namespace ops = paddle::operators;

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);

REGISTER_OPERATOR(
relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut,
ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(relu_grad_grad, ops::ActivationOpDoubleGrad);

REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

REGISTER_OP_CPU_KERNEL(
relu_grad_grad,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::ReluGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::ReluGradGradFunctor<plat::float16>>);
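
Note for reviewers: the identities implemented by `ReluGradGradFunctor` (registered above as `relu_grad_grad`) are easy to sanity-check outside Paddle. A minimal NumPy sketch, illustration only and not part of this change:

```python
import numpy as np

# ReluGrad:     dx  = dy  * (y > 0)
# ReluGradGrad: ddy = ddx * (y > 0), and dout = 0 (relu'' vanishes a.e.)
x = np.array([-2.0, -0.5, 0.5, 2.0], dtype=np.float32)
y = np.maximum(x, 0.0)                                  # forward "Out"
ddx = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)  # incoming "DDX"

mask = (y > 0).astype(np.float32)
ddout = ddx * mask       # expected "DDOut"
dout = np.zeros_like(y)  # expected "DOut"

print(ddout)  # [0. 0. 3. 4.]
print(dout)   # [0. 0. 0. 0.]
```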
11 changes: 11 additions & 0 deletions paddle/fluid/operators/activation_op.cu
@@ -32,3 +32,14 @@ namespace plat = paddle::platform;
ops::grad_functor<plat::float16>>);

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);

REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::float16>>);
121 changes: 120 additions & 1 deletion paddle/fluid/operators/activation_op.h
@@ -1198,14 +1198,133 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

/*
* Extract the tensors needed by activation double-grad kernels.
* in arguments: x, out, ddx
* out arguments: ddout, dout, dx
*/
template <ActBwdOpFwdDeps kDepValue>
inline void ExtractActivationDoubleGradTensor(
const framework::ExecutionContext& ctx, const framework::Tensor** X,
const framework::Tensor** Out, const framework::Tensor** ddX,
framework::Tensor** dX, framework::Tensor** dOut,
framework::Tensor** ddOut) {
auto out_var = ctx.InputVar("Out");
auto ddx_var = ctx.InputVar("DDX");
auto ddo_var = ctx.OutputVar("DDOut");
auto do_var = ctx.OutputVar("DOut");
PADDLE_ENFORCE(out_var != nullptr,
"Cannot get input Variable Out, variable name = %s",
ctx.op().Input("Out"));
PADDLE_ENFORCE(ddx_var != nullptr,
"Cannot get input Variable %s, variable name = %s", "DDX",
ctx.op().Input("DDX"));
if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
*Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
*ddX = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*ddx_var);
if (ddo_var) {
*ddOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
ddo_var);
}
if (do_var) {
*dOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
do_var);
}
} else {
*Out = ctx.Input<framework::Tensor>("Out");
*ddX = ctx.Input<framework::Tensor>("DDX");
if (ddo_var) {
*ddOut = ctx.Output<framework::Tensor>("DDOut");
}
if (do_var) {
*dOut = ctx.Output<framework::Tensor>("DOut");
}
}
PADDLE_ENFORCE(*ddX != nullptr,
"Cannot get input tensor %s, variable name = %s", "DDX",
ctx.op().Input("DDX"));

if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE(x_var != nullptr,
"Cannot get input tensor X, variable name = %s",
ctx.op().Input("X"));
auto dx_var = ctx.OutputVar("DX");
if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
*X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
if (dx_var) {
*dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
dx_var);
}
} else {
*X = ctx.Input<framework::Tensor>("X");
if (dx_var) {
*dX = ctx.Output<framework::Tensor>("DX");
}
}
} else {
VLOG(10) << "Inplace activation of Op: " << ctx.op().Type();
*X = *ddX;
}
}

template <typename DeviceContext, typename Functor>
class ActivationDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *X, *Out, *ddX;
X = Out = ddX = nullptr;
framework::Tensor *ddOut, *dOut, *dX;
ddOut = dOut = dX = nullptr;

ExtractActivationDoubleGradTensor<Functor::FwdDeps()>(ctx, &X, &Out, &ddX,
&dX, &dOut, &ddOut);

if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
if (dOut) dOut->mutable_data<T>(ctx.GetPlace());
if (dX) dX->mutable_data<T>(Out->dims(), ctx.GetPlace());

auto& place = ctx.template device_context<DeviceContext>();

Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
functor(place, X, Out, ddX, ddOut, dOut, dX);
}
};

template <typename T>
struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* Out, const framework::Tensor* ddX,
framework::Tensor* ddOut, framework::Tensor* dOut,
framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
}
if (dOut) {
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
dout.device(*d) = dout.constant(static_cast<T>(0));
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

} // namespace operators
} // namespace paddle

#define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(exp, Exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, Relu, ReluFunctor, ReluGradFunctor); \
__macro(gelu, Gelu, GeluFunctor, GeluGradFunctor); \
__macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \
__macro(atan, Atan, AtanFunctor, AtanGradFunctor); \
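
How this wiring gets exercised end to end: `ReluDoubleGradMaker` lets the framework append `relu_grad_grad` when a gradient of a gradient is requested. A hypothetical Python usage sketch, assuming a `fluid.gradients`-style entry point (the exact API name is an assumption, not fixed by this diff):

```python
import paddle.fluid as fluid

# Hypothetical sketch: the double-grad entry point is assumed here.
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
x.stop_gradient = False
y = fluid.layers.relu(x)

dx = fluid.gradients([y], [x])   # appends relu_grad
ddx = fluid.gradients(dx, [x])   # ReluDoubleGradMaker appends relu_grad_grad
```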
2 changes: 1 addition & 1 deletion python/paddle/fluid/backward.py
@@ -611,7 +611,7 @@ def _find_op_path_(block, outputs, inputs, no_grad_set):
if inputs:
for op in op_path:
for name in op.desc.input_arg_names():
if name not in input_names:
if name not in input_names and block.vars[name].stop_gradient:
no_grad_set.add(name)

return op_path
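
The one-line change above is what keeps intermediate variables differentiable when pruning the op path for higher-order gradients. A toy sketch of the new condition (plain Python, not Paddle code):

```python
# Before: every op input not listed in `inputs` went into no_grad_set.
# After: only variables also marked stop_gradient are excluded, so
# intermediates like relu's "Out" stay differentiable for double grad.
class Var:
    def __init__(self, name, stop_gradient):
        self.name = name
        self.stop_gradient = stop_gradient

block_vars = {'X': Var('X', False), 'Out': Var('Out', False),
              'Label': Var('Label', True)}
input_names = {'X'}

no_grad_set = set()
for name in ['X', 'Out', 'Label']:  # op.desc.input_arg_names()
    if name not in input_names and block_vars[name].stop_gradient:
        no_grad_set.add(name)

print(no_grad_set)  # {'Label'}  ('Out' is kept)
```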
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -28,7 +28,7 @@ list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com
list(REMOVE_ITEM TEST_OPS test_cond_op) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957

list(REMOVE_ITEM TEST_OPS op_test) # op_test is a helper python file, not a test
list(REMOVE_ITEM TEST_OPS decorators) # decorators is a helper python file, not a test
list(REMOVE_ITEM TEST_OPS decorator_helper) # decorator_helper is a helper python file, not a test
if(APPLE)
if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_desc_clone)