From 823b0e9553b96748c1196ef5f8c0e987d9e97f11 Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Mon, 19 Apr 2021 07:22:32 +0000 Subject: [PATCH 01/11] rebase --- paddle/fluid/operators/activation_op.cu | 762 ++++++++++++++---------- 1 file changed, 462 insertions(+), 300 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 781a97c1ffcc17..eede79ec924858 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -10,337 +10,435 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/math/math_cuda_utils.h" #include "paddle/fluid/platform/cuda_device_function.h" -#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using float16 = paddle::platform::float16; +template +struct BaseCudaActiveFunctor { + using ELEMENT_TYPE = T; + using AttrPair = std::vector>; + AttrPair GetAttrs() { return AttrPair(); } +}; +// For forward, args[0] means the input x; +// For backward, args[0] means the input dout, args[1] means the input x or out, +// which depends on the FwdDeps; +/********************Relu Begin********************/ template -struct CudaVecType { - using type = T; - static constexpr int vecsize = 1; +struct CudaReluFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] > zero ? args[0] : zero; + } }; -template <> -struct CudaVecType { - using type = __half2; - static constexpr int vecsize = 2; +template +struct CudaReluGradFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + + __device__ __forceinline__ T operator()(const T* args) const { + return args[1] > zero ? args[0] : zero; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; +/********************Relu End********************/ + +/********************LeakyRelu Begin********************/ +template +struct CudaLeakyReluFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + float alpha; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } -template <> -struct CudaVecType { - using type = float4; - static constexpr int vecsize = 4; + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] > zero ? args[0] : static_cast(alpha) * args[0]; + } }; template -class BaseGPUFunctor { - public: - using ELEMENT_TYPE = T; +struct CudaLeakyReluGradFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + float alpha; - using AttrPair = std::vector>; + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } - AttrPair GetAttrs() { return AttrPair(); } + __device__ __forceinline__ T operator()(const T* args) const { + return args[1] > zero ? args[0] : static_cast(alpha) * args[0]; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +/********************LeakyRelu End********************/ -/* ========================================================================== */ +/********************Sigmoid Begin********************/ +template +struct CudaSigmoidFunctor : public BaseCudaActiveFunctor { + // CT means Compute Type + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(one / (one + exp(-x))); + } +}; -/* =========================== relu forward ============================ */ template -class ReluGPUFunctor : public BaseGPUFunctor { - private: - T zero_; +struct CudaSigmoidGradFunctor : public BaseCudaActiveFunctor { + T one = static_cast(1.0f); - public: - ReluGPUFunctor() { zero_ = static_cast(0.0f); } - - // for relu forward when T is double - __device__ __forceinline__ typename CudaVecType::type Compute( - const typename CudaVecType::type in) { - // relu forward : out = max(x, 0) - return in > zero_ ? in : zero_; - } - - // when num % vecsize != 0 this func will be used - __device__ __forceinline__ T ComputeRemainder(const T in) { - // relu forward : out = max(x, 0) - return in > zero_ ? in : zero_; - } -}; - -template <> -__device__ __forceinline__ CudaVecType::type -ReluGPUFunctor::Compute(const CudaVecType::type in) { - // relu forward : out = max(in, 0) - return make_float4((in.x > zero_) * (in.x), (in.y > zero_) * (in.y), - (in.z > zero_) * (in.z), (in.w > zero_) * (in.w)); -} - -template <> -__device__ __forceinline__ CudaVecType::type -ReluGPUFunctor::Compute(const CudaVecType::type in) { -// relu forward : out = max(in, 0) -#ifdef __HIPCC__ || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - const half2 kzero = __float2half2_rn(0.0f); - return __hmul2(__hgt2(in, kzero), in); -#else - const float2 xx = __half22float2(in); - return __floats2half2_rn((xx.x > 0.0f) * static_cast(xx.x), - (xx.y > 0.0f) * static_cast(xx.y)); -#endif -} -/* ========================================================================== */ + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] * args[1] * (one - args[1]); + } -/* =========================== relu backward ============================ - */ + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; +/********************Sigmoid End********************/ +/********************LogSigmoid Begin********************/ template -class ReluGradGPUFunctor : public BaseGPUFunctor { - private: - T zero_; +struct CudaLogSigmoidFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT zero = static_cast(0.0f); + + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + CT temp = x > zero ? zero : -x; + return T(-temp - log(exp(-temp) + exp(-x - temp))); + } +}; - public: - ReluGradGPUFunctor() { zero_ = static_cast(0.0f); } +template +struct CudaLogSigmoidGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT zero = static_cast(0.0f); + + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + CT temp = x > zero ? zero : -x; + return T(dout * (exp(-x - temp) / (exp(-temp) + exp(-x - temp)))); + } - // for relu backward when T is double - __device__ __forceinline__ typename CudaVecType::type Compute( - const typename CudaVecType::type out, - const typename CudaVecType::type dout) { - return out > zero_ ? dout : zero_; + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************LogSigmoid End********************/ + +/********************Atan Begin********************/ +template +struct CudaAtanFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(atan(x)); } +}; - // when num % vecsize != 0 this func will be used - __device__ __forceinline__ T ComputeRemainder(const T out, const T dout) { - // relu backward : dx = out > 0 ? dout : 0 - return out > zero_ ? dout : zero_; +template +struct CudaAtanGradFunctor : public BaseCudaActiveFunctor { + T one = static_cast(1.0f); + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] * one / (one + args[1] * args[1]); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +/********************Atan End********************/ -template <> -__device__ __forceinline__ CudaVecType::type -ReluGradGPUFunctor::Compute(const CudaVecType::type out, - const CudaVecType::type dout) { - // relu backward : dx = out > 0 ? dout : 0; - return make_float4((out.x > zero_) * (dout.x), (out.y > zero_) * (dout.y), - (out.z > zero_) * (dout.z), (out.w > zero_) * (dout.w)); -} - -template <> -__device__ __forceinline__ CudaVecType::type -ReluGradGPUFunctor::Compute(const CudaVecType::type out, - const CudaVecType::type dout) { -// relu backward : dx = out > 0 ? dout : 0; -#ifdef __HIPCC__ || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - const half2 kzero = __float2half2_rn(0.0f); - return __hmul2(__hgt2(out, kzero), dout); -#else - const float2 xx = __half22float2(out); - const float2 yy = __half22float2(dout); - return __floats2half2_rn((xx.x > 0.0f) * static_cast(yy.x), - (xx.y > 0.0f) * static_cast(yy.y)); -#endif -} - -/* ========================================================================== */ -/* ======================== leaky relu forward ======================== - */ +/********************SoftShrink Begin********************/ template -class LeakyReluGPUFunctor : public BaseGPUFunctor { - private: - T zero_; - float alpha_; +struct CudaSoftShrinkFunctor : public BaseCudaActiveFunctor { + float lambda; - public: - LeakyReluGPUFunctor() { zero_ = static_cast(0.0f); } + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"lambda", &lambda}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + T lambdaT = static_cast(lambda); + T temp1 = static_cast(args[0] > lambdaT); + T temp2 = static_cast(args[0] < -lambdaT); + return temp1 * (args[0] - lambdaT) + temp2 * (args[0] + lambdaT); + } +}; + +template +struct CudaSoftShrinkGradFunctor : public BaseCudaActiveFunctor { + float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha_}}; - } - // leakyrelu forward : out = x > 0 ? x : x * alpha - __device__ __forceinline__ typename CudaVecType::type Compute( - const typename CudaVecType::type in) { - return in > zero_ ? in : static_cast(alpha_) * in; - } - - __device__ __forceinline__ T ComputeRemainder(const T in) { - // leakyrelu forward : out = x > 0 ? x : x * alpha - return in > zero_ ? in : static_cast(alpha_) * in; - } -}; - -template <> -__device__ __forceinline__ CudaVecType::type -LeakyReluGPUFunctor::Compute(const CudaVecType::type in) { - // leakyrelu forward : out = x > 0 ? x : x * alpha - return make_float4((in.x > zero_) ? (in.x) : (in.x) * alpha_, - (in.y > zero_) ? (in.y) : (in.y) * alpha_, - (in.z > zero_) ? (in.z) : (in.z) * alpha_, - (in.w > zero_) ? (in.w) : (in.w) * alpha_); -} - -template <> -__device__ __forceinline__ CudaVecType::type -LeakyReluGPUFunctor::Compute(const CudaVecType::type in) { - // leakyrelu forward : out = x > 0 ? x : x * alpha - const float2 xx = __half22float2(in); - return __floats2half2_rn((xx.x > 0.0f) ? xx.x : xx.x * alpha_, - (xx.y > 0.0f) ? xx.y : xx.y * alpha_); -} -/* ========================================================================== */ + return {{"lambda", &lambda}}; + } -/* =========================== leaky relu backward ======================= - */ + __device__ __forceinline__ T operator()(const T* args) const { + T lambdaT = static_cast(lambda); + T temp1 = static_cast(args[1] > lambdaT); + T temp2 = static_cast(args[1] < -lambdaT); + return args[0] * static_cast(temp1 + temp2); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************SoftShrink End********************/ + +/********************Ceil Begin********************/ template -class LeakyReluGradGPUFunctor : public BaseGPUFunctor { - private: - T zero_; - float alpha_; +struct CudaCeilFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(ceil(x)); + } +}; +/********************Ceil End********************/ - public: - LeakyReluGradGPUFunctor() { zero_ = static_cast(0.0f); } +/********************Floor Begin********************/ +template +struct CudaFloorFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(floor(x)); + } +}; +/********************Floor End********************/ - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha_}}; +/********************Round Begin********************/ +template +struct CudaRoundFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(round(x)); } +}; +/********************Floor End********************/ - // for leaky relu backward when T is double - __device__ __forceinline__ typename CudaVecType::type Compute( - const typename CudaVecType::type in, - const typename CudaVecType::type dout) { - // leakyrelu backward : dx = x > 0 ? dout : alpha * dout - return in > zero_ ? dout : static_cast(alpha_) * dout; +/********************Zero Begin********************/ +template +struct CudaZeroGradFunctor : public BaseCudaActiveFunctor { + __device__ __forceinline__ T operator()(const T* args) const { + return static_cast(0.0f); } - // when num % vecsize != 0 this func will be used - __device__ __forceinline__ T ComputeRemainder(const T in, const T dout) { - // leakyrelu backward : dx = x > 0 ? dout : alpha * dout - return in > zero_ ? dout : static_cast(alpha_) * dout; + static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; } +}; +/********************Zero End********************/ + +/********************Cos Begin********************/ +template +struct CudaCosFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(cos(x)); + } +}; + +template +struct CudaCosGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + return T(-dout * sin(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +/********************Cos End********************/ + +/********************Sin Begin********************/ +template +struct CudaSinFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(sin(x)); + } +}; -template <> -__device__ __forceinline__ CudaVecType::type -LeakyReluGradGPUFunctor::Compute(const CudaVecType::type in, - const CudaVecType::type dout) { - // leakyrelu backward : dx = x > 0 ? dout : alpha * dout - return make_float4((in.x > zero_) ? (dout.x) : alpha_ * (dout.x), - (in.y > zero_) ? (dout.y) : alpha_ * (dout.y), - (in.z > zero_) ? (dout.z) : alpha_ * (dout.z), - (in.w > zero_) ? (dout.w) : alpha_ * (dout.w)); -} - -template <> -__device__ __forceinline__ CudaVecType::type LeakyReluGradGPUFunctor< - float16>::Compute(const CudaVecType::type in, - const CudaVecType::type dout) { - // leakyrelu backward : dx = x > 0 ? dout : alpha * dout - const float2 xx = __half22float2(in); - const float2 yy = __half22float2(dout); - return __floats2half2_rn((xx.x > 0.0f) ? yy.x : alpha_ * yy.x, - (xx.y > 0.0f) ? yy.y : alpha_ * yy.y); -} +template +struct CudaSinGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + return T(dout * cos(x)); + } -/* ========================================================================== */ + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Sin End********************/ + +/********************Tan Begin********************/ +template +struct CudaTanFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(tan(x)); + } +}; + +template +struct CudaTanGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + return T(dout / (cos(x) * cos(x))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Tan End********************/ + +/********************Asin Begin********************/ +template +struct CudaAsinFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(asin(x)); + } +}; + +template +struct CudaAsinGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + return T(dout * one / sqrt(one - x * x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Asin End********************/ + +/********************Acos Begin********************/ +template +struct CudaAcosFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(acos(x)); + } +}; + +template +struct CudaAcosGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + return T(-dout * one / sqrt(one - x * x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Acos End********************/ + +/********************Cosh Begin********************/ +template +struct CudaCoshFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(cosh(x)); + } +}; + +template +struct CudaCoshGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + return T(dout * sinh(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Cosh End********************/ + +/********************Sinh Begin********************/ +template +struct CudaSinhFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(sinh(x)); + } +}; + +template +struct CudaSinhGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + return T(dout * cosh(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Sinh End********************/ + +/********************Reciprocal Begin********************/ +template +struct CudaReciprocalFunctor : public BaseCudaActiveFunctor { + T one = static_cast(1.0f); + __device__ __forceinline__ T operator()(const T* args) const { + return one / args[0]; + } +}; -template -__global__ void ActivationGradKernelVec(const T* forward_data, const T* dout, - T* dx, int num, Functor functor) { - using VecType = typename CudaVecType::type; - constexpr int vecsize = CudaVecType::vecsize; - int idx = threadIdx.x + blockIdx.x * blockDim.x; - int stride = blockDim.x * gridDim.x; - int loop = num / vecsize; - int tail = num % vecsize; - const VecType* in_forward = reinterpret_cast(forward_data); - const VecType* in_dout = reinterpret_cast(dout); - VecType* out = reinterpret_cast(dx); - VecType forward_vec, dout_vec; - T in_data, dout_data; - for (int i = idx; i < loop; i += stride) { -#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350 - forward_vec = __ldg(in_forward + i); - dout_vec = __ldg(in_dout + i); -#else - forward_vec = in_forward[i]; - dout_vec = in_dout[i]; -#endif - out[i] = functor.Compute(forward_vec, dout_vec); - } - - while (idx == loop && tail) { - in_data = forward_data[num - tail]; - dout_data = dout[num - tail]; - dx[num - tail] = functor.ComputeRemainder(in_data, dout_data); - --tail; - } -} - -template -__global__ void ActivationkernelVec(const T* src, T* dst, int num, - Functor functor) { - constexpr int vecsize = CudaVecType::vecsize; - using VecType = typename CudaVecType::type; - int idx = threadIdx.x + blockIdx.x * blockDim.x; - int stride = blockDim.x * gridDim.x; - int loop = num / vecsize; - int tail = num % vecsize; - const VecType* in = reinterpret_cast(src); - VecType* out = reinterpret_cast(dst); - VecType x_vec; - for (int i = idx; i < loop; i += stride) { -#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350 - x_vec = __ldg(in + i); -#else - x_vec = in[i]; -#endif - out[i] = functor.Compute(x_vec); - } - - while (idx == loop && tail) { - dst[num - tail] = functor.ComputeRemainder(src[num - tail]); - --tail; - } -} +template +struct CudaReciprocalGradFunctor : public BaseCudaActiveFunctor { + __device__ __forceinline__ T operator()(const T* args) const { + return -args[0] * args[1] * args[1]; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; +/********************Reciprocal End********************/ template class ActivationGPUKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; - void Compute(const framework::ExecutionContext& context) const override { - const framework::Tensor* in_x = nullptr; + void Compute(const framework::ExecutionContext& ctx) const override { + const framework::Tensor* x = nullptr; framework::Tensor* out = nullptr; - ExtractActivationTensor(context, &in_x, &out); - auto& dev_ctx = context.template device_context(); - - int num = in_x->numel(); - const T* input_data = in_x->data(); - T* output_data = out->mutable_data(dev_ctx.GetPlace(), - static_cast(num * sizeof(T))); - - int block = 512; -#ifdef __HIPCC__ - block = 256; -#endif - Functor functor; + ExtractActivationTensor(ctx, &x, &out); + out->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + std::vector ins = {x}; + std::vector outs = {out}; + auto functor = Functor(); auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { - *attr.second = context.Attr(attr.first); + *attr.second = ctx.Attr(attr.first); } - constexpr int vecsize = CudaVecType::vecsize; - int grid = max((num / vecsize + block - 1) / block, 1); - auto stream = context.cuda_device_context().stream(); - ActivationkernelVec<<>>( - input_data, output_data, num, functor); + LaunchElementwiseCudaKernel(dev_ctx, ins, &outs, + functor); } }; @@ -349,43 +447,38 @@ class ActivationGradGPUKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; - void Compute(const framework::ExecutionContext& context) const override { + void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *x, *out, *d_out; framework::Tensor* d_x = nullptr; x = out = d_out = nullptr; - ExtractActivationGradTensor(context, &x, &out, &d_out, + ExtractActivationGradTensor(ctx, &x, &out, &d_out, &d_x); - int numel = d_out->numel(); - auto& dev_ctx = context.template device_context(); - auto* dx_data = d_x->mutable_data( - dev_ctx.GetPlace(), static_cast(numel * sizeof(T))); - auto* dout_data = d_out->data(); + d_x->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + auto functor = Functor(); + auto attrs = functor.GetAttrs(); + for (auto& attr : attrs) { + *attr.second = ctx.Attr(attr.first); + } + + std::vector ins = {d_out}; + std::vector outs = {d_x}; - auto* forward_data = dout_data; if (static_cast(Functor::FwdDeps()) == static_cast(kDepOut)) { // Only need forward output Out - forward_data = out->data(); + ins.push_back(out); + LaunchElementwiseCudaKernel(dev_ctx, ins, + &outs, functor); } else if (static_cast(Functor::FwdDeps()) == static_cast(kDepX)) { // Only need forward input X - forward_data = x->data(); - } - - int block = 512; -#ifdef __HIPCC__ - block = 256; -#endif - - Functor functor; - auto attrs = functor.GetAttrs(); - for (auto& attr : attrs) { - *attr.second = context.Attr(attr.first); + ins.push_back(x); + LaunchElementwiseCudaKernel(dev_ctx, ins, + &outs, functor); + } else { + LaunchElementwiseCudaKernel(dev_ctx, ins, + &outs, functor); } - constexpr int vecsize = CudaVecType::vecsize; - int grid = max((numel / vecsize + block - 1) / block, 1); - auto stream = context.cuda_device_context().stream(); - ActivationGradKernelVec<<>>( - forward_data, dout_data, dx_data, numel, functor); } }; @@ -410,7 +503,6 @@ namespace plat = paddle::platform; ops::grad_functor>, \ ops::ActivationGradKernel>); -FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL); #define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor, \ grad_functor) \ @@ -430,8 +522,8 @@ FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL); ops::grad_functor>); /* ======================== leaky relu register ============================ */ -REGISTER_ACTIVATION_GPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluGPUFunctor, - LeakyReluGradGPUFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(leaky_relu, LeakyRelu, CudaLeakyReluFunctor, + CudaLeakyReluGradFunctor); REGISTER_OP_CUDA_KERNEL( leaky_relu_grad_grad, @@ -456,7 +548,8 @@ REGISTER_OP_CUDA_KERNEL( /* ========================================================================== */ /* =========================== relu register ============================ */ -REGISTER_ACTIVATION_GPU_KERNEL(relu, Relu, ReluGPUFunctor, ReluGradGPUFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(relu, Relu, CudaReluFunctor, + CudaReluGradFunctor); REGISTER_OP_CUDA_KERNEL( relu_grad_grad, @@ -594,3 +687,72 @@ REGISTER_OP_CUDA_KERNEL( ops::LogDoubleGradKernel>); /* ========================================================================== */ +REGISTER_ACTIVATION_GPU_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor, + CudaSigmoidGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, + CudaLogSigmoidGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(atan, Atan, CudaAtanFunctor, + CudaAtanGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(softshrink, SoftShrink, CudaSoftShrinkFunctor, + CudaSoftShrinkGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(ceil, Ceil, CudaCeilFunctor, + CudaZeroGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(floor, Floor, CudaFloorFunctor, + CudaZeroGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(cos, Cos, CudaCosFunctor, CudaCosGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(tan, Tan, CudaTanFunctor, CudaTanGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(acos, Acos, CudaAcosFunctor, + CudaAcosGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(sin, Sin, CudaSinFunctor, CudaSinGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(asin, Asin, CudaAsinFunctor, + CudaAsinGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(sinh, Sinh, CudaSinhFunctor, + CudaSinhGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(cosh, Cosh, CudaCoshFunctor, + CudaCoshGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(round, Round, CudaRoundFunctor, + CudaZeroGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, SigmoidFunctor, +// SigmoidGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(logsigmoid, LogSigmoid, LogSigmoidFunctor, +// LogSigmoidGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(atan, Atan, AtanFunctor, AtanGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(softshrink, SoftShrink, SoftShrinkFunctor, +// SoftShrinkGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(ceil, Ceil, CeilFunctor, ZeroGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(floor, Floor, FloorFunctor, +// ZeroGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(cos, Cos, CosFunctor, CosGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(tan, Tan, TanFunctor, TanGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(acos, Acos, AcosFunctor, AcosGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(sin, Sin, SinFunctor, SinGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(asin, Asin, AsinFunctor, AsinGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(sinh, Sinh, SinhFunctor, SinhGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(cosh, Cosh, CoshFunctor, CoshGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(round, Round, RoundFunctor, ZeroGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(reciprocal, Reciprocal, ReciprocalFunctor, + ReciprocalGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(log2, Log2, Log2Functor, Log2GradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(log10, Log10, Log10Functor, Log10GradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(brelu, BRelu, BReluFunctor, BReluGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(soft_relu, SoftRelu, SoftReluFunctor, + SoftReluGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(stanh, STanh, STanhFunctor, STanhGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(softplus, Softplus, SoftplusFunctor, + SoftplusGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(softsign, Softsign, SoftsignFunctor, + SoftsignGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(relu6, Relu6, Relu6Functor, Relu6GradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(tanh_shrink, TanhShrink, TanhShrinkFunctor, + TanhShrinkGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(hard_shrink, HardShrink, HardShrinkFunctor, + HardShrinkGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, + HardSigmoidGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(swish, Swish, SwishFunctor, SwishGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(thresholded_relu, ThresholdedRelu, + ThresholdedReluFunctor, + ThresholdedReluGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(hard_swish, HardSwish, HardSwishFunctor, + HardSwishGradFunctor); From ec22be6e74c659e4ec7cf857813343534a90cd66 Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Thu, 22 Apr 2021 13:15:42 +0000 Subject: [PATCH 02/11] add 12 op --- paddle/fluid/operators/activation_op.cu | 444 ++++++++++++++++++++++-- 1 file changed, 415 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index eede79ec924858..e62aebce7cbf6c 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -150,7 +150,7 @@ template struct CudaAtanGradFunctor : public BaseCudaActiveFunctor { T one = static_cast(1.0f); __device__ __forceinline__ T operator()(const T* args) const { - return args[0] * one / (one + args[1] * args[1]); + return args[0] / (one + args[1] * args[1]); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -167,10 +167,11 @@ struct CudaSoftShrinkFunctor : public BaseCudaActiveFunctor { } __device__ __forceinline__ T operator()(const T* args) const { - T lambdaT = static_cast(lambda); - T temp1 = static_cast(args[0] > lambdaT); - T temp2 = static_cast(args[0] < -lambdaT); - return temp1 * (args[0] - lambdaT) + temp2 * (args[0] + lambdaT); + T x = args[0]; + T l = static_cast(lambda); + T temp1 = static_cast(x > l); + T temp2 = static_cast(x < -l); + return temp1 * (x - l) + temp2 * (x + l); } }; @@ -183,10 +184,11 @@ struct CudaSoftShrinkGradFunctor : public BaseCudaActiveFunctor { } __device__ __forceinline__ T operator()(const T* args) const { - T lambdaT = static_cast(lambda); - T temp1 = static_cast(args[1] > lambdaT); - T temp2 = static_cast(args[1] < -lambdaT); - return args[0] * static_cast(temp1 + temp2); + T x = args[1]; + T l = static_cast(lambda); + T temp1 = static_cast(x > l); + T temp2 = static_cast(x < -l); + return args[0] * (temp1 + temp2); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -323,7 +325,7 @@ struct CudaAsinGradFunctor : public BaseCudaActiveFunctor { __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); - return T(dout * one / sqrt(one - x * x)); + return T(dout / sqrt(one - x * x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -347,7 +349,7 @@ struct CudaAcosGradFunctor : public BaseCudaActiveFunctor { __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); - return T(-dout * one / sqrt(one - x * x)); + return T(-dout / sqrt(one - x * x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -419,6 +421,361 @@ struct CudaReciprocalGradFunctor : public BaseCudaActiveFunctor { }; /********************Reciprocal End********************/ +/********************Log1p Begin********************/ +template +struct CudaLog1pFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(log(one + x)); + } +}; + +template +struct CudaLog1pGradFunctor : public BaseCudaActiveFunctor { + T one = static_cast(1.0f); + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] / (one + args[1]); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Log1p End********************/ + +/********************Log2 Begin********************/ +template +struct CudaLog2Functor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(log2(x)); + } +}; + +template +struct CudaLog2GradFunctor : public BaseCudaActiveFunctor { + T log_two = static_cast(log(2)); + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] / (args[1] * log_two); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Log2 End********************/ + +/********************Log10 Begin********************/ +template +struct CudaLog10Functor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(log10(x)); + } +}; + +template +struct CudaLog10GradFunctor : public BaseCudaActiveFunctor { + T log_ten = static_cast(log(10)); + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] / (args[1] * log_ten); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Log10 End********************/ + +/********************BRelu Begin********************/ +template +struct CudaBReluFunctor : public BaseCudaActiveFunctor { + float t_min; + float t_max; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"t_min", &t_min}, {"t_max", &t_max}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + T x = args[0]; + T t_min_cast = static_cast(t_min); + T t_max_cast = static_cast(t_max); + return (x > t_min_cast && x < t_max_cast) + ? x + : (x <= t_min_cast ? t_min_cast : t_max_cast); + } +}; + +template +struct CudaBReluGradFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + float t_min; + float t_max; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"t_min", &t_min}, {"t_max", &t_max}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + T dout = args[0]; + T x = args[1]; + T t_min_cast = static_cast(t_min); + T t_max_cast = static_cast(t_max); + return (x <= t_min_cast || x >= t_max_cast) ? zero : dout; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************BRelu End********************/ + +/********************SoftRelu Begin********************/ +template +struct CudaSoftReluFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + float threshold; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + CT t = static_cast(threshold); + CT temp = (x > -t && x < t) ? x : (x <= -t ? -t : t); + return T(log(one + exp(temp))); + } +}; + +template +struct CudaSoftReluGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + float threshold; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT out = static_cast(args[1]); + CT t = static_cast(threshold); + return (out <= -t || out >= t) ? static_cast(0.0f) + : T(dout * (one - exp(-out))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; +/********************SoftRelu End********************/ + +/********************STanh Begin********************/ +template +struct CudaSTanhFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + float scale_a; + float scale_b; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + CT a = static_cast(scale_a); + CT b = static_cast(scale_b); + return T(b * tanh(a * x)); + } +}; + +template +struct CudaSTanhGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + float scale_a; + float scale_b; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + CT a = static_cast(scale_a); + CT b = static_cast(scale_b); + CT temp = tanh(a * x) * tanh(a * x); + return T(dout * a * b * (one - temp)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************STanh End********************/ + +/********************Softplus Begin********************/ +template +struct CudaSoftplusFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + float beta; + float threshold; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}, {"threshold", &threshold}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + CT b = static_cast(beta); + CT t = static_cast(threshold); + CT x_beta = x * beta; + return T(x_beta > t ? x : log(one + exp(x_beta)) / b); + } +}; + +template +struct CudaSoftplusGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + float beta; + float threshold; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}, {"threshold", &threshold}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + CT b = static_cast(beta); + CT t = static_cast(threshold); + CT x_beta = x * beta; + return x_beta > t ? args[0] : T(dout / (one + exp(-x_beta))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Softplus End********************/ + +/********************Softsign Begin********************/ +template +struct CudaSoftsignFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(x / (one + abs(x))); + } +}; + +template +struct CudaSoftsignGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + return T(dout / ((one + abs(x)) * (one + abs(x)))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Softsign End********************/ + +/********************Relu6 Begin********************/ +template +struct CudaRelu6Functor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + T t = static_cast(threshold); + return args[0] <= zero ? zero : (args[0] < t ? args[0] : t); + } +}; + +template +struct CudaRelu6GradFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + T t = static_cast(threshold); + return (args[1] > zero && args[1] < t) ? args[0] : zero; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; +/********************Relu6 End********************/ + +/********************TanhShrink Begin********************/ +template +struct CudaTanhShrinkFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(x - tanh(x)); + } +}; + +template +struct CudaTanhShrinkGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + return T(dout * tanh(x) * tanh(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************TanhShrink End********************/ + +/********************HardShrink Begin********************/ +template +struct CudaHardShrinkFunctor : public BaseCudaActiveFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + T x = args[0]; + T t = static_cast(threshold); + T temp1 = static_cast(x > t); + T temp2 = static_cast(x < -t); + return x * (temp1 + temp2); + } +}; + +template +struct CudaHardShrinkGradFunctor : public BaseCudaActiveFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + T x = args[1]; + T t = static_cast(threshold); + T temp1 = static_cast(x > t); + T temp2 = static_cast(x < -t); + return args[0] * (temp1 + temp2); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************HardShrink End********************/ + template class ActivationGPUKernel : public framework::OpKernel { @@ -712,6 +1069,30 @@ REGISTER_ACTIVATION_GPU_KERNEL(cosh, Cosh, CudaCoshFunctor, CudaCoshGradFunctor); REGISTER_ACTIVATION_GPU_KERNEL(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(reciprocal, Reciprocal, CudaReciprocalFunctor, + CudaReciprocalGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(log1p, Log1p, CudaLog1pFunctor, + CudaLog1pGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(log2, Log2, CudaLog2Functor, + CudaLog2GradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(log10, Log10, CudaLog10Functor, + CudaLog10GradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(brelu, BRelu, CudaBReluFunctor, + CudaBReluGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(soft_relu, SoftRelu, CudaSoftReluFunctor, + CudaSoftReluGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(stanh, STanh, CudaSTanhFunctor, + CudaSTanhGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(softplus, Softplus, CudaSoftplusFunctor, + CudaSoftplusGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(softsign, Softsign, CudaSoftsignFunctor, + CudaSoftsignGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(relu6, Relu6, CudaRelu6Functor, + CudaRelu6GradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor, + CudaTanhShrinkGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(hard_shrink, HardShrink, CudaHardShrinkFunctor, + CudaHardShrinkGradFunctor); // REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, SigmoidFunctor, // SigmoidGradFunctor); // REGISTER_ACTIVATION_CUDA_KERNEL(logsigmoid, LogSigmoid, LogSigmoidFunctor, @@ -730,24 +1111,29 @@ REGISTER_ACTIVATION_GPU_KERNEL(round, Round, CudaRoundFunctor, // REGISTER_ACTIVATION_CUDA_KERNEL(sinh, Sinh, SinhFunctor, SinhGradFunctor); // REGISTER_ACTIVATION_CUDA_KERNEL(cosh, Cosh, CoshFunctor, CoshGradFunctor); // REGISTER_ACTIVATION_CUDA_KERNEL(round, Round, RoundFunctor, ZeroGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(reciprocal, Reciprocal, ReciprocalFunctor, - ReciprocalGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(log2, Log2, Log2Functor, Log2GradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(log10, Log10, Log10Functor, Log10GradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(brelu, BRelu, BReluFunctor, BReluGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(soft_relu, SoftRelu, SoftReluFunctor, - SoftReluGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(stanh, STanh, STanhFunctor, STanhGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(softplus, Softplus, SoftplusFunctor, - SoftplusGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(softsign, Softsign, SoftsignFunctor, - SoftsignGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(relu6, Relu6, Relu6Functor, Relu6GradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(tanh_shrink, TanhShrink, TanhShrinkFunctor, - TanhShrinkGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(hard_shrink, HardShrink, HardShrinkFunctor, - HardShrinkGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(reciprocal, Reciprocal, ReciprocalFunctor, +// ReciprocalGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(log1p, Log1p, Log1pFunctor, +// Log1pGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(log2, Log2, Log2Functor, Log2GradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(log10, Log10, Log10Functor, +// Log10GradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(brelu, BRelu, BReluFunctor, +// BReluGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(soft_relu, SoftRelu, SoftReluFunctor, +// SoftReluGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(stanh, STanh, STanhFunctor, +// STanhGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(softplus, Softplus, SoftplusFunctor, +// SoftplusGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(softsign, Softsign, SoftsignFunctor, +// SoftsignGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(relu6, Relu6, Relu6Functor, +// Relu6GradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(tanh_shrink, TanhShrink, TanhShrinkFunctor, +// TanhShrinkGradFunctor); +// REGISTER_ACTIVATION_CUDA_KERNEL(hard_shrink, HardShrink, HardShrinkFunctor, +// HardShrinkGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(swish, Swish, SwishFunctor, SwishGradFunctor); From 59b16b9dcc107548e439834f1d90dcff0561309e Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Fri, 23 Apr 2021 08:07:06 +0000 Subject: [PATCH 03/11] add all activation op --- paddle/fluid/operators/activation_op.cu | 622 +++++++++++++++++------- 1 file changed, 451 insertions(+), 171 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index e62aebce7cbf6c..84454f574c9c30 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -402,6 +402,29 @@ struct CudaSinhGradFunctor : public BaseCudaActiveFunctor { }; /********************Sinh End********************/ +/********************Tanh Begin********************/ +template +struct CudaTanhFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(tanh(x)); + } +}; + +template +struct CudaTanhGradFunctor : public BaseCudaActiveFunctor { + T one = static_cast(1.0f); + __device__ __forceinline__ T operator()(const T* args) const { + T dout = static_cast(args[0]); + T out = static_cast(args[1]); + return dout * (one - out * out); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; +/********************Tanh End********************/ + /********************Reciprocal Begin********************/ template struct CudaReciprocalFunctor : public BaseCudaActiveFunctor { @@ -421,6 +444,26 @@ struct CudaReciprocalGradFunctor : public BaseCudaActiveFunctor { }; /********************Reciprocal End********************/ +/********************Exp Begin********************/ +template +struct CudaExpFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(exp(x)); + } +}; + +template +struct CudaExpGradFunctor : public BaseCudaActiveFunctor { + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] * args[1]; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; +/********************Exp End********************/ + /********************Log1p Begin********************/ template struct CudaLog1pFunctor : public BaseCudaActiveFunctor { @@ -443,6 +486,26 @@ struct CudaLog1pGradFunctor : public BaseCudaActiveFunctor { }; /********************Log1p End********************/ +/********************Log Begin********************/ +template +struct CudaLogFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(log(x)); + } +}; + +template +struct CudaLogGradFunctor : public BaseCudaActiveFunctor { + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] / args[1]; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Log End********************/ + /********************Log2 Begin********************/ template struct CudaLog2Functor : public BaseCudaActiveFunctor { @@ -776,8 +839,268 @@ struct CudaHardShrinkGradFunctor : public BaseCudaActiveFunctor { }; /********************HardShrink End********************/ +/********************HardSigmoid Begin********************/ +template +struct CudaHardSigmoidFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + T one = static_cast(1.0f); + float slope; + float offset; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"slope", &slope}, {"offset", &offset}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + T temp = args[0] * static_cast(slope) + static_cast(offset); + return (temp > zero && temp < one) ? temp : (temp <= zero ? zero : one); + } +}; + +template +struct CudaHardSigmoidGradFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + T one = static_cast(1.0f); + float slope; + float offset; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"slope", &slope}, {"offset", &offset}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + T out = args[1]; + return (out > zero && out < one) ? args[0] * static_cast(slope) : zero; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; +/********************HardSigmoid End********************/ + +/********************Swish Begin********************/ +template +struct CudaSwishFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + float beta; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + CT b = static_cast(beta); + return T(x / (one + exp(-b * x))); + } +}; + +template +struct CudaSwishGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + float beta; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + CT b = static_cast(beta); + CT temp1 = one / (one + exp(-b * x)); + CT out = x * temp1; + CT temp2 = temp1 * (one - b * x); + return T(dout * (b * out + temp2)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Swish End********************/ + +/********************ThresholdedRelu Begin********************/ +template +struct CudaThresholdedReluFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + float threshold; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] > static_cast(threshold) ? args[0] : zero; + } +}; + +template +struct CudaThresholdedReluGradFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + float threshold; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + return args[1] > static_cast(threshold) ? args[0] : zero; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************ThresholdedRelu End********************/ + +/********************HardSwish Begin********************/ +template +struct CudaHardSwishFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + float threshold; + float scale; + float offset; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + T x = args[0]; + T t = static_cast(threshold); + T temp1 = x + static_cast(offset); + T temp2 = (temp1 > zero && temp1 < t) ? temp1 : (temp1 <= zero ? zero : t); + return temp2 * x / static_cast(scale); + } +}; + +template +struct CudaHardSwishGradFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); + T one = static_cast(1.0f); + T two = static_cast(2.0f); + float threshold; + float scale; + float offset; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + T x = args[1]; + T o = static_cast(offset); + T s = static_cast(scale); + T temp1 = static_cast(x + o > zero); + T temp2 = static_cast(x + o < static_cast(threshold)); + return args[0] * (temp1 * temp2 * (two * x + o) / s + one - temp2); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************HardSwish End********************/ + +/********************ELU Begin********************/ +template +struct CudaELUFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT zero = static_cast(0.0f); + CT one = static_cast(1.0f); + float alpha; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return x >= zero ? args[0] : T(static_cast(alpha) * (exp(x) - one)); + } +}; + +template +struct CudaELUGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT zero = static_cast(0.0f); + CT one = static_cast(1.0f); + float alpha; + + typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + return x >= zero ? args[0] : T(dout * static_cast(alpha) * exp(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************ELU End********************/ + +/********************Square Begin********************/ +template +struct CudaSquareFunctor : public BaseCudaActiveFunctor { + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] * args[0]; + } +}; + +template +struct CudaSquareGradFunctor : public BaseCudaActiveFunctor { + T two = static_cast(2.0f); + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] * two * args[1]; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Square End********************/ + +/********************Sqrt Begin********************/ +template +struct CudaSqrtFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(sqrt(x)); + } +}; + +template +struct CudaSqrtGradFunctor : public BaseCudaActiveFunctor { + T one_half = static_cast(0.5f); + __device__ __forceinline__ T operator()(const T* args) const { + return one_half * args[0] / args[1]; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; +/********************Sqrt End********************/ + +/********************Rsqrt Begin********************/ +template +struct CudaRsqrtFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(rsqrt(x)); + } +}; + +template +struct CudaRsqrtGradFunctor : public BaseCudaActiveFunctor { + T minus_one_half = static_cast(-0.5f); + __device__ __forceinline__ T operator()(const T* args) const { + T out = args[1]; + return minus_one_half * args[0] * out * out * out; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; +/********************Rsqrt End********************/ + template -class ActivationGPUKernel +class ActivationCudaKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; @@ -800,7 +1123,7 @@ class ActivationGPUKernel }; template -class ActivationGradGPUKernel +class ActivationGradCudaKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; @@ -845,42 +1168,27 @@ class ActivationGradGPUKernel namespace ops = paddle::operators; namespace plat = paddle::platform; -#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \ - grad_functor) \ - REGISTER_OP_CUDA_KERNEL( \ - act_type, \ - ops::ActivationKernel>, \ - ops::ActivationKernel>, \ - ops::ActivationKernel>); \ - REGISTER_OP_CUDA_KERNEL( \ - act_type##_grad, ops::ActivationGradKernel>, \ - ops::ActivationGradKernel>, \ - ops::ActivationGradKernel>); - -#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor, \ - grad_functor) \ +#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \ + grad_functor) \ REGISTER_OP_CUDA_KERNEL( \ - act_type, ops::ActivationGPUKernel>, \ - ops::ActivationGPUKernel>, \ - ops::ActivationGPUKernel>); \ + act_type, ops::ActivationCudaKernel>, \ + ops::ActivationCudaKernel>, \ + ops::ActivationCudaKernel>); \ REGISTER_OP_CUDA_KERNEL( \ - act_type##_grad, ops::ActivationGradGPUKernel>, \ - ops::ActivationGradGPUKernel>, \ - ops::ActivationGradGPUKernel>); + act_type##_grad, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>, \ + ops::ActivationGradCudaKernel>); /* ======================== leaky relu register ============================ */ -REGISTER_ACTIVATION_GPU_KERNEL(leaky_relu, LeakyRelu, CudaLeakyReluFunctor, - CudaLeakyReluGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, CudaLeakyReluFunctor, + CudaLeakyReluGradFunctor); REGISTER_OP_CUDA_KERNEL( leaky_relu_grad_grad, @@ -893,7 +1201,7 @@ REGISTER_OP_CUDA_KERNEL( /* ========================================================================== */ /* ======================== elu register ============================ */ -REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, CudaELUFunctor, CudaELUGradFunctor); REGISTER_OP_CUDA_KERNEL( elu_grad_grad, ops::ELUDoubleGradKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>); + square, ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>); REGISTER_OP_CUDA_KERNEL( - square_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>); + square_grad, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>); REGISTER_OP_CUDA_KERNEL( square_grad_grad, @@ -1014,27 +1329,31 @@ REGISTER_OP_CUDA_KERNEL( /* ========================== exp register ============================ */ REGISTER_OP_CUDA_KERNEL( - exp, ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>); + exp, ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>); REGISTER_OP_CUDA_KERNEL( - exp_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>); + exp_grad, ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>); /* ========================================================================== */ /* ========================== Log register ==================================*/ -REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, LogFunctor, LogGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor); REGISTER_OP_CUDA_KERNEL( log_grad_grad, ops::LogDoubleGradKernel>); /* ========================================================================== */ -REGISTER_ACTIVATION_GPU_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor, - CudaSigmoidGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, - CudaLogSigmoidGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(atan, Atan, CudaAtanFunctor, - CudaAtanGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(softshrink, SoftShrink, CudaSoftShrinkFunctor, - CudaSoftShrinkGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(ceil, Ceil, CudaCeilFunctor, - CudaZeroGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(floor, Floor, CudaFloorFunctor, - CudaZeroGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(cos, Cos, CudaCosFunctor, CudaCosGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(tan, Tan, CudaTanFunctor, CudaTanGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(acos, Acos, CudaAcosFunctor, - CudaAcosGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(sin, Sin, CudaSinFunctor, CudaSinGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(asin, Asin, CudaAsinFunctor, - CudaAsinGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(sinh, Sinh, CudaSinhFunctor, - CudaSinhGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(cosh, Cosh, CudaCoshFunctor, - CudaCoshGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(round, Round, CudaRoundFunctor, - CudaZeroGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(reciprocal, Reciprocal, CudaReciprocalFunctor, - CudaReciprocalGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(log1p, Log1p, CudaLog1pFunctor, - CudaLog1pGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(log2, Log2, CudaLog2Functor, - CudaLog2GradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(log10, Log10, CudaLog10Functor, - CudaLog10GradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(brelu, BRelu, CudaBReluFunctor, - CudaBReluGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(soft_relu, SoftRelu, CudaSoftReluFunctor, - CudaSoftReluGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(stanh, STanh, CudaSTanhFunctor, - CudaSTanhGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(softplus, Softplus, CudaSoftplusFunctor, - CudaSoftplusGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(softsign, Softsign, CudaSoftsignFunctor, - CudaSoftsignGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(relu6, Relu6, CudaRelu6Functor, - CudaRelu6GradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor, - CudaTanhShrinkGradFunctor); -REGISTER_ACTIVATION_GPU_KERNEL(hard_shrink, HardShrink, CudaHardShrinkFunctor, - CudaHardShrinkGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, SigmoidFunctor, -// SigmoidGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(logsigmoid, LogSigmoid, LogSigmoidFunctor, -// LogSigmoidGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(atan, Atan, AtanFunctor, AtanGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(softshrink, SoftShrink, SoftShrinkFunctor, -// SoftShrinkGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(ceil, Ceil, CeilFunctor, ZeroGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(floor, Floor, FloorFunctor, -// ZeroGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(cos, Cos, CosFunctor, CosGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(tan, Tan, TanFunctor, TanGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(acos, Acos, AcosFunctor, AcosGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(sin, Sin, SinFunctor, SinGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(asin, Asin, AsinFunctor, AsinGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(sinh, Sinh, SinhFunctor, SinhGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(cosh, Cosh, CoshFunctor, CoshGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(round, Round, RoundFunctor, ZeroGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(reciprocal, Reciprocal, ReciprocalFunctor, -// ReciprocalGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(log1p, Log1p, Log1pFunctor, -// Log1pGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(log2, Log2, Log2Functor, Log2GradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(log10, Log10, Log10Functor, -// Log10GradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(brelu, BRelu, BReluFunctor, -// BReluGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(soft_relu, SoftRelu, SoftReluFunctor, -// SoftReluGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(stanh, STanh, STanhFunctor, -// STanhGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(softplus, Softplus, SoftplusFunctor, -// SoftplusGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(softsign, Softsign, SoftsignFunctor, -// SoftsignGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(relu6, Relu6, Relu6Functor, -// Relu6GradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(tanh_shrink, TanhShrink, TanhShrinkFunctor, -// TanhShrinkGradFunctor); -// REGISTER_ACTIVATION_CUDA_KERNEL(hard_shrink, HardShrink, HardShrinkFunctor, -// HardShrinkGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, - HardSigmoidGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(swish, Swish, SwishFunctor, SwishGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor, + CudaSigmoidGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, + CudaLogSigmoidGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(atan, Atan, CudaAtanFunctor, + CudaAtanGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(softshrink, SoftShrink, CudaSoftShrinkFunctor, + CudaSoftShrinkGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(ceil, Ceil, CudaCeilFunctor, + CudaZeroGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(floor, Floor, CudaFloorFunctor, + CudaZeroGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(cos, Cos, CudaCosFunctor, CudaCosGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(tan, Tan, CudaTanFunctor, CudaTanGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(acos, Acos, CudaAcosFunctor, + CudaAcosGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(sin, Sin, CudaSinFunctor, CudaSinGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(asin, Asin, CudaAsinFunctor, + CudaAsinGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(sinh, Sinh, CudaSinhFunctor, + CudaSinhGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(cosh, Cosh, CudaCoshFunctor, + CudaCoshGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(round, Round, CudaRoundFunctor, + CudaZeroGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(reciprocal, Reciprocal, CudaReciprocalFunctor, + CudaReciprocalGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(log1p, Log1p, CudaLog1pFunctor, + CudaLog1pGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(log2, Log2, CudaLog2Functor, + CudaLog2GradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(log10, Log10, CudaLog10Functor, + CudaLog10GradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(brelu, BRelu, CudaBReluFunctor, + CudaBReluGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(soft_relu, SoftRelu, CudaSoftReluFunctor, + CudaSoftReluGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(stanh, STanh, CudaSTanhFunctor, + CudaSTanhGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(softplus, Softplus, CudaSoftplusFunctor, + CudaSoftplusGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(softsign, Softsign, CudaSoftsignFunctor, + CudaSoftsignGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(relu6, Relu6, CudaRelu6Functor, + CudaRelu6GradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor, + CudaTanhShrinkGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(hard_shrink, HardShrink, CudaHardShrinkFunctor, + CudaHardShrinkGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(hard_sigmoid, HardSigmoid, + CudaHardSigmoidFunctor, + CudaHardSigmoidGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(swish, Swish, CudaSwishFunctor, + CudaSwishGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(thresholded_relu, ThresholdedRelu, - ThresholdedReluFunctor, - ThresholdedReluGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(hard_swish, HardSwish, HardSwishFunctor, - HardSwishGradFunctor); + CudaThresholdedReluFunctor, + CudaThresholdedReluGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(hard_swish, HardSwish, CudaHardSwishFunctor, + CudaHardSwishGradFunctor); From a51d16f8176565284ba325d6c364db6f386b9822 Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Fri, 23 Apr 2021 10:18:36 +0000 Subject: [PATCH 04/11] fix --- paddle/fluid/operators/activation_op.cu | 16 +++++++--------- paddle/fluid/operators/activation_op.h | 4 ++-- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 84454f574c9c30..a412235aadf64c 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -623,8 +623,8 @@ struct CudaSoftReluGradFunctor : public BaseCudaActiveFunctor { CT dout = static_cast(args[0]); CT out = static_cast(args[1]); CT t = static_cast(threshold); - return (out <= -t || out >= t) ? static_cast(0.0f) - : T(dout * (one - exp(-out))); + return (out > -t && out < t) ? T(dout * (one - exp(-out))) + : static_cast(0.0f); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } @@ -804,6 +804,7 @@ struct CudaTanhShrinkGradFunctor : public BaseCudaActiveFunctor { /********************HardShrink Begin********************/ template struct CudaHardShrinkFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { @@ -813,14 +814,13 @@ struct CudaHardShrinkFunctor : public BaseCudaActiveFunctor { __device__ __forceinline__ T operator()(const T* args) const { T x = args[0]; T t = static_cast(threshold); - T temp1 = static_cast(x > t); - T temp2 = static_cast(x < -t); - return x * (temp1 + temp2); + return (x > -t && x < t) ? zero : x; } }; template struct CudaHardShrinkGradFunctor : public BaseCudaActiveFunctor { + T zero = static_cast(0.0f); float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { @@ -830,9 +830,7 @@ struct CudaHardShrinkGradFunctor : public BaseCudaActiveFunctor { __device__ __forceinline__ T operator()(const T* args) const { T x = args[1]; T t = static_cast(threshold); - T temp1 = static_cast(x > t); - T temp2 = static_cast(x < -t); - return args[0] * (temp1 + temp2); + return (x > -t && x < t) ? zero : args[0]; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -911,7 +909,7 @@ struct CudaSwishGradFunctor : public BaseCudaActiveFunctor { CT b = static_cast(beta); CT temp1 = one / (one + exp(-b * x)); CT out = x * temp1; - CT temp2 = temp1 * (one - b * x); + CT temp2 = temp1 * (one - b * out); return T(dout * (b * out + temp2)); } diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 7245dea9cf9499..ccd5bf528ba58c 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -455,7 +455,7 @@ struct HardShrinkFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out) const { auto temp1 = x < static_cast(threshold * -1.f); auto temp2 = x > static_cast(threshold); - out.device(d) = x * (temp1 + temp2).template cast(); + out.device(d) = x * (temp1 || temp2).template cast(); } }; @@ -472,7 +472,7 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = x < static_cast(threshold * -1.f); auto temp2 = x > static_cast(threshold); - dx.device(d) = dout * (temp1 + temp2).template cast(); + dx.device(d) = dout * (temp1 || temp2).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } From c17f3aa317aba5aed677301e7aacd0b9bd9be074 Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Sun, 25 Apr 2021 06:50:37 +0000 Subject: [PATCH 05/11] fix --- paddle/fluid/operators/activation_op.cu | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index a412235aadf64c..3aaa14fd3fcbe9 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -1331,10 +1331,8 @@ REGISTER_OP_CUDA_KERNEL( ops::CudaExpFunctor>, ops::ActivationCudaKernel>, - ops::ActivationCudaKernel>, - ops::ActivationCudaKernel>, + ops::ActivationKernel>, + ops::ActivationKernel>, ops::ActivationCudaKernel>); REGISTER_OP_CUDA_KERNEL( @@ -1361,6 +1359,7 @@ REGISTER_OP_CUDA_KERNEL( ops::LogDoubleGradKernel>); /* ========================================================================== */ + REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor, CudaSigmoidGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, From 88d29139a77fedb606a7cb608a712858069afe79 Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Sun, 25 Apr 2021 07:31:24 +0000 Subject: [PATCH 06/11] add silu --- paddle/fluid/operators/activation_op.cu | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 3aaa14fd3fcbe9..ab0ec790ccc88e 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -107,6 +107,36 @@ struct CudaSigmoidGradFunctor : public BaseCudaActiveFunctor { }; /********************Sigmoid End********************/ +/********************Silu Begin********************/ +template +struct CudaSiluFunctor : public BaseCudaActiveFunctor { + // CT means Compute Type + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + + __device__ __forceinline__ T operator()(const T* args) const { + CT x = static_cast(args[0]); + return T(x / (one + exp(-x))); + } +}; + +template +struct CudaSiluGradFunctor : public BaseCudaActiveFunctor { + using CT = typename details::MPTypeTrait::Type; + CT one = static_cast(1.0f); + + __device__ __forceinline__ T operator()(const T* args) const { + CT dout = static_cast(args[0]); + CT x = static_cast(args[1]); + CT temp1 = one + exp(-x); + CT temp2 = x * exp(-x); + return T(dout * ((one / temp1) * (one + temp2 / temp1))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; +/********************Silu End********************/ + /********************LogSigmoid Begin********************/ template struct CudaLogSigmoidFunctor : public BaseCudaActiveFunctor { @@ -1362,6 +1392,8 @@ REGISTER_OP_CUDA_KERNEL( REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor, CudaSigmoidGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(silu, Silu, CudaSiluFunctor, + CudaSiluGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, CudaLogSigmoidGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(atan, Atan, CudaAtanFunctor, From 95aad4b857e187059bad526eadf6ed5a8b237c2b Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Sun, 25 Apr 2021 08:01:29 +0000 Subject: [PATCH 07/11] revert swish and softrelu --- paddle/fluid/operators/activation_op.cu | 38 ++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index ab0ec790ccc88e..4ddb777cdddeb3 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -1390,6 +1390,40 @@ REGISTER_OP_CUDA_KERNEL( ops::LogGradGradFunctor>); /* ========================================================================== */ +/* ========================== softrelu register ============================ + */ +REGISTER_OP_CUDA_KERNEL( + soft_relu, + ops::ActivationKernel>, + ops::ActivationKernel>, + ops::ActivationKernel>); +REGISTER_OP_CUDA_KERNEL( + soft_relu_grad, ops::ActivationGradKernel>, + ops::ActivationGradKernel>, + ops::ActivationGradKernel>); +/* ========================================================================== */ + +/* ========================== swish register ============================ */ +REGISTER_OP_CUDA_KERNEL( + swish, + ops::ActivationKernel>, + ops::ActivationKernel>, + ops::ActivationKernel>); +REGISTER_OP_CUDA_KERNEL( + swish_grad, ops::ActivationGradKernel>, + ops::ActivationGradKernel>, + ops::ActivationGradKernel>); +/* ========================================================================== */ + REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor, CudaSigmoidGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(silu, Silu, CudaSiluFunctor, @@ -1427,8 +1461,6 @@ REGISTER_ACTIVATION_CUDA_KERNEL(log10, Log10, CudaLog10Functor, CudaLog10GradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(brelu, BRelu, CudaBReluFunctor, CudaBReluGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(soft_relu, SoftRelu, CudaSoftReluFunctor, - CudaSoftReluGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(softplus, Softplus, CudaSoftplusFunctor, @@ -1444,8 +1476,6 @@ REGISTER_ACTIVATION_CUDA_KERNEL(hard_shrink, HardShrink, CudaHardShrinkFunctor, REGISTER_ACTIVATION_CUDA_KERNEL(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor, CudaHardSigmoidGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(swish, Swish, CudaSwishFunctor, - CudaSwishGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(thresholded_relu, ThresholdedRelu, CudaThresholdedReluFunctor, CudaThresholdedReluGradFunctor); From 0d09b3efa0171cde81a2b9879c7c82b545ba0df0 Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Sun, 25 Apr 2021 08:12:42 +0000 Subject: [PATCH 08/11] fix --- paddle/fluid/operators/activation_op.cu | 615 ++---------------------- 1 file changed, 41 insertions(+), 574 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 4ddb777cdddeb3..81ddfe71ddd380 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -494,28 +494,6 @@ struct CudaExpGradFunctor : public BaseCudaActiveFunctor { }; /********************Exp End********************/ -/********************Log1p Begin********************/ -template -struct CudaLog1pFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); - __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(log(one + x)); - } -}; - -template -struct CudaLog1pGradFunctor : public BaseCudaActiveFunctor { - T one = static_cast(1.0f); - __device__ __forceinline__ T operator()(const T* args) const { - return args[0] / (one + args[1]); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************Log1p End********************/ - /********************Log Begin********************/ template struct CudaLogFunctor : public BaseCudaActiveFunctor { @@ -536,496 +514,6 @@ struct CudaLogGradFunctor : public BaseCudaActiveFunctor { }; /********************Log End********************/ -/********************Log2 Begin********************/ -template -struct CudaLog2Functor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(log2(x)); - } -}; - -template -struct CudaLog2GradFunctor : public BaseCudaActiveFunctor { - T log_two = static_cast(log(2)); - __device__ __forceinline__ T operator()(const T* args) const { - return args[0] / (args[1] * log_two); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************Log2 End********************/ - -/********************Log10 Begin********************/ -template -struct CudaLog10Functor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(log10(x)); - } -}; - -template -struct CudaLog10GradFunctor : public BaseCudaActiveFunctor { - T log_ten = static_cast(log(10)); - __device__ __forceinline__ T operator()(const T* args) const { - return args[0] / (args[1] * log_ten); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************Log10 End********************/ - -/********************BRelu Begin********************/ -template -struct CudaBReluFunctor : public BaseCudaActiveFunctor { - float t_min; - float t_max; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"t_min", &t_min}, {"t_max", &t_max}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - T x = args[0]; - T t_min_cast = static_cast(t_min); - T t_max_cast = static_cast(t_max); - return (x > t_min_cast && x < t_max_cast) - ? x - : (x <= t_min_cast ? t_min_cast : t_max_cast); - } -}; - -template -struct CudaBReluGradFunctor : public BaseCudaActiveFunctor { - T zero = static_cast(0.0f); - float t_min; - float t_max; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"t_min", &t_min}, {"t_max", &t_max}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - T dout = args[0]; - T x = args[1]; - T t_min_cast = static_cast(t_min); - T t_max_cast = static_cast(t_max); - return (x <= t_min_cast || x >= t_max_cast) ? zero : dout; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************BRelu End********************/ - -/********************SoftRelu Begin********************/ -template -struct CudaSoftReluFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); - float threshold; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - CT t = static_cast(threshold); - CT temp = (x > -t && x < t) ? x : (x <= -t ? -t : t); - return T(log(one + exp(temp))); - } -}; - -template -struct CudaSoftReluGradFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); - float threshold; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT out = static_cast(args[1]); - CT t = static_cast(threshold); - return (out > -t && out < t) ? T(dout * (one - exp(-out))) - : static_cast(0.0f); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } -}; -/********************SoftRelu End********************/ - -/********************STanh Begin********************/ -template -struct CudaSTanhFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - float scale_a; - float scale_b; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - CT a = static_cast(scale_a); - CT b = static_cast(scale_b); - return T(b * tanh(a * x)); - } -}; - -template -struct CudaSTanhGradFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); - float scale_a; - float scale_b; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - CT a = static_cast(scale_a); - CT b = static_cast(scale_b); - CT temp = tanh(a * x) * tanh(a * x); - return T(dout * a * b * (one - temp)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************STanh End********************/ - -/********************Softplus Begin********************/ -template -struct CudaSoftplusFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); - float beta; - float threshold; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}, {"threshold", &threshold}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - CT b = static_cast(beta); - CT t = static_cast(threshold); - CT x_beta = x * beta; - return T(x_beta > t ? x : log(one + exp(x_beta)) / b); - } -}; - -template -struct CudaSoftplusGradFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); - float beta; - float threshold; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}, {"threshold", &threshold}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - CT b = static_cast(beta); - CT t = static_cast(threshold); - CT x_beta = x * beta; - return x_beta > t ? args[0] : T(dout / (one + exp(-x_beta))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************Softplus End********************/ - -/********************Softsign Begin********************/ -template -struct CudaSoftsignFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); - __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(x / (one + abs(x))); - } -}; - -template -struct CudaSoftsignGradFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); - __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - return T(dout / ((one + abs(x)) * (one + abs(x)))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************Softsign End********************/ - -/********************Relu6 Begin********************/ -template -struct CudaRelu6Functor : public BaseCudaActiveFunctor { - T zero = static_cast(0.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - T t = static_cast(threshold); - return args[0] <= zero ? zero : (args[0] < t ? args[0] : t); - } -}; - -template -struct CudaRelu6GradFunctor : public BaseCudaActiveFunctor { - T zero = static_cast(0.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - T t = static_cast(threshold); - return (args[1] > zero && args[1] < t) ? args[0] : zero; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } -}; -/********************Relu6 End********************/ - -/********************TanhShrink Begin********************/ -template -struct CudaTanhShrinkFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(x - tanh(x)); - } -}; - -template -struct CudaTanhShrinkGradFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - return T(dout * tanh(x) * tanh(x)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************TanhShrink End********************/ - -/********************HardShrink Begin********************/ -template -struct CudaHardShrinkFunctor : public BaseCudaActiveFunctor { - T zero = static_cast(0.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - T x = args[0]; - T t = static_cast(threshold); - return (x > -t && x < t) ? zero : x; - } -}; - -template -struct CudaHardShrinkGradFunctor : public BaseCudaActiveFunctor { - T zero = static_cast(0.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - T x = args[1]; - T t = static_cast(threshold); - return (x > -t && x < t) ? zero : args[0]; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************HardShrink End********************/ - -/********************HardSigmoid Begin********************/ -template -struct CudaHardSigmoidFunctor : public BaseCudaActiveFunctor { - T zero = static_cast(0.0f); - T one = static_cast(1.0f); - float slope; - float offset; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"slope", &slope}, {"offset", &offset}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - T temp = args[0] * static_cast(slope) + static_cast(offset); - return (temp > zero && temp < one) ? temp : (temp <= zero ? zero : one); - } -}; - -template -struct CudaHardSigmoidGradFunctor : public BaseCudaActiveFunctor { - T zero = static_cast(0.0f); - T one = static_cast(1.0f); - float slope; - float offset; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"slope", &slope}, {"offset", &offset}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - T out = args[1]; - return (out > zero && out < one) ? args[0] * static_cast(slope) : zero; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } -}; -/********************HardSigmoid End********************/ - -/********************Swish Begin********************/ -template -struct CudaSwishFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); - float beta; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - CT b = static_cast(beta); - return T(x / (one + exp(-b * x))); - } -}; - -template -struct CudaSwishGradFunctor : public BaseCudaActiveFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); - float beta; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - CT b = static_cast(beta); - CT temp1 = one / (one + exp(-b * x)); - CT out = x * temp1; - CT temp2 = temp1 * (one - b * out); - return T(dout * (b * out + temp2)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************Swish End********************/ - -/********************ThresholdedRelu Begin********************/ -template -struct CudaThresholdedReluFunctor : public BaseCudaActiveFunctor { - T zero = static_cast(0.0f); - float threshold; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - return args[0] > static_cast(threshold) ? args[0] : zero; - } -}; - -template -struct CudaThresholdedReluGradFunctor : public BaseCudaActiveFunctor { - T zero = static_cast(0.0f); - float threshold; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - return args[1] > static_cast(threshold) ? args[0] : zero; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************ThresholdedRelu End********************/ - -/********************HardSwish Begin********************/ -template -struct CudaHardSwishFunctor : public BaseCudaActiveFunctor { - T zero = static_cast(0.0f); - float threshold; - float scale; - float offset; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - T x = args[0]; - T t = static_cast(threshold); - T temp1 = x + static_cast(offset); - T temp2 = (temp1 > zero && temp1 < t) ? temp1 : (temp1 <= zero ? zero : t); - return temp2 * x / static_cast(scale); - } -}; - -template -struct CudaHardSwishGradFunctor : public BaseCudaActiveFunctor { - T zero = static_cast(0.0f); - T one = static_cast(1.0f); - T two = static_cast(2.0f); - float threshold; - float scale; - float offset; - - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; - } - - __device__ __forceinline__ T operator()(const T* args) const { - T x = args[1]; - T o = static_cast(offset); - T s = static_cast(scale); - T temp1 = static_cast(x + o > zero); - T temp2 = static_cast(x + o < static_cast(threshold)); - return args[0] * (temp1 * temp2 * (two * x + o) / s + one - temp2); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; -/********************HardSwish End********************/ - /********************ELU Begin********************/ template struct CudaELUFunctor : public BaseCudaActiveFunctor { @@ -1196,6 +684,23 @@ class ActivationGradCudaKernel namespace ops = paddle::operators; namespace plat = paddle::platform; +#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor, \ + grad_functor) \ + REGISTER_OP_CUDA_KERNEL( \ + act_type, ops::ActivationKernel>, \ + ops::ActivationKernel>, \ + ops::ActivationKernel>); \ + REGISTER_OP_CUDA_KERNEL( \ + act_type##_grad, ops::ActivationGradKernel>, \ + ops::ActivationGradKernel>, \ + ops::ActivationGradKernel>); + #define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \ grad_functor) \ REGISTER_OP_CUDA_KERNEL( \ @@ -1390,40 +895,6 @@ REGISTER_OP_CUDA_KERNEL( ops::LogGradGradFunctor>); /* ========================================================================== */ -/* ========================== softrelu register ============================ - */ -REGISTER_OP_CUDA_KERNEL( - soft_relu, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>); -REGISTER_OP_CUDA_KERNEL( - soft_relu_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>); -/* ========================================================================== */ - -/* ========================== swish register ============================ */ -REGISTER_OP_CUDA_KERNEL( - swish, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>); -REGISTER_OP_CUDA_KERNEL( - swish_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>); -/* ========================================================================== */ - REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor, CudaSigmoidGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(silu, Silu, CudaSiluFunctor, @@ -1453,31 +924,27 @@ REGISTER_ACTIVATION_CUDA_KERNEL(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(reciprocal, Reciprocal, CudaReciprocalFunctor, CudaReciprocalGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(log1p, Log1p, CudaLog1pFunctor, - CudaLog1pGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(log2, Log2, CudaLog2Functor, - CudaLog2GradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(log10, Log10, CudaLog10Functor, - CudaLog10GradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(brelu, BRelu, CudaBReluFunctor, - CudaBReluGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(stanh, STanh, CudaSTanhFunctor, - CudaSTanhGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(softplus, Softplus, CudaSoftplusFunctor, - CudaSoftplusGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(softsign, Softsign, CudaSoftsignFunctor, - CudaSoftsignGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(relu6, Relu6, CudaRelu6Functor, - CudaRelu6GradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor, - CudaTanhShrinkGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(hard_shrink, HardShrink, CudaHardShrinkFunctor, - CudaHardShrinkGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(hard_sigmoid, HardSigmoid, - CudaHardSigmoidFunctor, - CudaHardSigmoidGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(thresholded_relu, ThresholdedRelu, - CudaThresholdedReluFunctor, - CudaThresholdedReluGradFunctor); -REGISTER_ACTIVATION_CUDA_KERNEL(hard_swish, HardSwish, CudaHardSwishFunctor, - CudaHardSwishGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(log2, Log2, Log2Functor, Log2GradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(log10, Log10, Log10Functor, Log10GradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(brelu, BRelu, BReluFunctor, BReluGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(soft_relu, SoftRelu, SoftReluFunctor, + SoftReluGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(stanh, STanh, STanhFunctor, STanhGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(softplus, Softplus, SoftplusFunctor, + SoftplusGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(softsign, Softsign, SoftsignFunctor, + SoftsignGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(relu6, Relu6, Relu6Functor, Relu6GradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(tanh_shrink, TanhShrink, TanhShrinkFunctor, + TanhShrinkGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(hard_shrink, HardShrink, HardShrinkFunctor, + HardShrinkGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, + HardSigmoidGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(swish, Swish, SwishFunctor, SwishGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(thresholded_relu, ThresholdedRelu, + ThresholdedReluFunctor, + ThresholdedReluGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(hard_swish, HardSwish, HardSwishFunctor, + HardSwishGradFunctor); From f67e8a4421b02015e8ebc59e3fcadf7f54981c12 Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Mon, 26 Apr 2021 11:12:11 +0000 Subject: [PATCH 09/11] add notes --- paddle/fluid/operators/activation_op.cu | 287 ++++++++++++++---------- 1 file changed, 172 insertions(+), 115 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 81ddfe71ddd380..3db64136425dcf 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -19,76 +19,73 @@ namespace paddle { namespace operators { template -struct BaseCudaActiveFunctor { - using ELEMENT_TYPE = T; - using AttrPair = std::vector>; - AttrPair GetAttrs() { return AttrPair(); } -}; - -// For forward, args[0] means the input x; -// For backward, args[0] means the input dout, args[1] means the input x or out, -// which depends on the FwdDeps; -/********************Relu Begin********************/ -template -struct CudaReluFunctor : public BaseCudaActiveFunctor { +struct CudaReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); + // relu(x) = max(x, 0) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { return args[0] > zero ? args[0] : zero; } }; template -struct CudaReluGradFunctor : public BaseCudaActiveFunctor { +struct CudaReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); + // dx = dout * (out > 0) + // Inputs: args[0], the input dout + // args[1], the input out __device__ __forceinline__ T operator()(const T* args) const { return args[1] > zero ? args[0] : zero; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -/********************Relu End********************/ -/********************LeakyRelu Begin********************/ template -struct CudaLeakyReluFunctor : public BaseCudaActiveFunctor { +struct CudaLeakyReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); float alpha; - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } + // leakyrelu(x) = x > 0 ? x : alpha * x + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { return args[0] > zero ? args[0] : static_cast(alpha) * args[0]; } }; template -struct CudaLeakyReluGradFunctor : public BaseCudaActiveFunctor { +struct CudaLeakyReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); float alpha; - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } + // dx = dout * (x > 0 ? 1 : alpha) + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { return args[1] > zero ? args[0] : static_cast(alpha) * args[0]; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************LeakyRelu End********************/ -/********************Sigmoid Begin********************/ template -struct CudaSigmoidFunctor : public BaseCudaActiveFunctor { +struct CudaSigmoidFunctor : public BaseActivationFunctor { // CT means Compute Type using CT = typename details::MPTypeTrait::Type; CT one = static_cast(1.0f); + // sigmoid(x) = 1 / (1 + exp(-x)) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(one / (one + exp(-x))); @@ -96,24 +93,27 @@ struct CudaSigmoidFunctor : public BaseCudaActiveFunctor { }; template -struct CudaSigmoidGradFunctor : public BaseCudaActiveFunctor { +struct CudaSigmoidGradFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); + // dx = dout * out * (1 - out) + // Inputs: args[0], the input dout + // args[1], the input out __device__ __forceinline__ T operator()(const T* args) const { return args[0] * args[1] * (one - args[1]); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -/********************Sigmoid End********************/ -/********************Silu Begin********************/ template -struct CudaSiluFunctor : public BaseCudaActiveFunctor { +struct CudaSiluFunctor : public BaseActivationFunctor { // CT means Compute Type using CT = typename details::MPTypeTrait::Type; CT one = static_cast(1.0f); + // silu(x) = x / (1 + exp(-x)) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(x / (one + exp(-x))); @@ -121,10 +121,13 @@ struct CudaSiluFunctor : public BaseCudaActiveFunctor { }; template -struct CudaSiluGradFunctor : public BaseCudaActiveFunctor { +struct CudaSiluGradFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; CT one = static_cast(1.0f); + // dx = dout * (1 + exp(-x) + x * exp(-x) / (1 + exp(-x))^2) + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); @@ -135,14 +138,14 @@ struct CudaSiluGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************Silu End********************/ -/********************LogSigmoid Begin********************/ template -struct CudaLogSigmoidFunctor : public BaseCudaActiveFunctor { +struct CudaLogSigmoidFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; CT zero = static_cast(0.0f); + // logsigmoid(x) = log(1 / (1 + exp(-x))) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); CT temp = x > zero ? zero : -x; @@ -151,10 +154,13 @@ struct CudaLogSigmoidFunctor : public BaseCudaActiveFunctor { }; template -struct CudaLogSigmoidGradFunctor : public BaseCudaActiveFunctor { +struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; CT zero = static_cast(0.0f); + // dx = dout * exp(-x) / (1 + exp(-x)) + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); @@ -164,12 +170,12 @@ struct CudaLogSigmoidGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************LogSigmoid End********************/ -/********************Atan Begin********************/ template -struct CudaAtanFunctor : public BaseCudaActiveFunctor { +struct CudaAtanFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // atan(x) = atan(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(atan(x)); @@ -177,25 +183,30 @@ struct CudaAtanFunctor : public BaseCudaActiveFunctor { }; template -struct CudaAtanGradFunctor : public BaseCudaActiveFunctor { +struct CudaAtanGradFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); + // dx = dout / (1 + x^2) + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { return args[0] / (one + args[1] * args[1]); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************Atan End********************/ -/********************SoftShrink Begin********************/ template -struct CudaSoftShrinkFunctor : public BaseCudaActiveFunctor { +struct CudaSoftShrinkFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } + // softshrink(x) = x - lambda, if x > lambda; + // x + lambda, if x < -lambda; + // 0, otherwise. + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { T x = args[0]; T l = static_cast(lambda); @@ -206,13 +217,16 @@ struct CudaSoftShrinkFunctor : public BaseCudaActiveFunctor { }; template -struct CudaSoftShrinkGradFunctor : public BaseCudaActiveFunctor { +struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } + // dx = dout, if x > lambda or x < -lambda else 0 + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { T x = args[1]; T l = static_cast(lambda); @@ -223,56 +237,55 @@ struct CudaSoftShrinkGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************SoftShrink End********************/ -/********************Ceil Begin********************/ template -struct CudaCeilFunctor : public BaseCudaActiveFunctor { +struct CudaCeilFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // ceil(x) = ceil(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(ceil(x)); } }; -/********************Ceil End********************/ -/********************Floor Begin********************/ template -struct CudaFloorFunctor : public BaseCudaActiveFunctor { +struct CudaFloorFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // floor(x) = floor(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(floor(x)); } }; -/********************Floor End********************/ -/********************Round Begin********************/ template -struct CudaRoundFunctor : public BaseCudaActiveFunctor { +struct CudaRoundFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // round(x) = round(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(round(x)); } }; -/********************Floor End********************/ -/********************Zero Begin********************/ +// grad functor for ceil. floor and round template -struct CudaZeroGradFunctor : public BaseCudaActiveFunctor { +struct CudaZeroGradFunctor : public BaseActivationFunctor { __device__ __forceinline__ T operator()(const T* args) const { return static_cast(0.0f); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; } }; -/********************Zero End********************/ -/********************Cos Begin********************/ template -struct CudaCosFunctor : public BaseCudaActiveFunctor { +struct CudaCosFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // cos(x) = cos(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(cos(x)); @@ -280,8 +293,11 @@ struct CudaCosFunctor : public BaseCudaActiveFunctor { }; template -struct CudaCosGradFunctor : public BaseCudaActiveFunctor { +struct CudaCosGradFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // dx = dout * (-sin(x)) + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); @@ -290,12 +306,12 @@ struct CudaCosGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************Cos End********************/ -/********************Sin Begin********************/ template -struct CudaSinFunctor : public BaseCudaActiveFunctor { +struct CudaSinFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // sin(x) = sin(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(sin(x)); @@ -303,8 +319,11 @@ struct CudaSinFunctor : public BaseCudaActiveFunctor { }; template -struct CudaSinGradFunctor : public BaseCudaActiveFunctor { +struct CudaSinGradFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // dx = dout * cos(x) + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); @@ -313,12 +332,12 @@ struct CudaSinGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************Sin End********************/ -/********************Tan Begin********************/ template -struct CudaTanFunctor : public BaseCudaActiveFunctor { +struct CudaTanFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // tan(x) = tan(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(tan(x)); @@ -326,8 +345,11 @@ struct CudaTanFunctor : public BaseCudaActiveFunctor { }; template -struct CudaTanGradFunctor : public BaseCudaActiveFunctor { +struct CudaTanGradFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // dx = dout / cos(x)^2 + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); @@ -336,12 +358,12 @@ struct CudaTanGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************Tan End********************/ -/********************Asin Begin********************/ template -struct CudaAsinFunctor : public BaseCudaActiveFunctor { +struct CudaAsinFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // asin(x) = asin(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(asin(x)); @@ -349,9 +371,12 @@ struct CudaAsinFunctor : public BaseCudaActiveFunctor { }; template -struct CudaAsinGradFunctor : public BaseCudaActiveFunctor { +struct CudaAsinGradFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; CT one = static_cast(1.0f); + // dx = dout / sqrt(1 - x^2) + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); @@ -360,12 +385,12 @@ struct CudaAsinGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************Asin End********************/ -/********************Acos Begin********************/ template -struct CudaAcosFunctor : public BaseCudaActiveFunctor { +struct CudaAcosFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // acos(x) = acos(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(acos(x)); @@ -373,9 +398,12 @@ struct CudaAcosFunctor : public BaseCudaActiveFunctor { }; template -struct CudaAcosGradFunctor : public BaseCudaActiveFunctor { +struct CudaAcosGradFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; CT one = static_cast(1.0f); + // dx = -dout / sqrt(1 - x^2) + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); @@ -384,12 +412,12 @@ struct CudaAcosGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************Acos End********************/ -/********************Cosh Begin********************/ template -struct CudaCoshFunctor : public BaseCudaActiveFunctor { +struct CudaCoshFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // cosh(x) = cosh(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(cosh(x)); @@ -397,8 +425,11 @@ struct CudaCoshFunctor : public BaseCudaActiveFunctor { }; template -struct CudaCoshGradFunctor : public BaseCudaActiveFunctor { +struct CudaCoshGradFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // dx = dout * sinh(x) + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); @@ -407,12 +438,12 @@ struct CudaCoshGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************Cosh End********************/ -/********************Sinh Begin********************/ template -struct CudaSinhFunctor : public BaseCudaActiveFunctor { +struct CudaSinhFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // sinh(x) = sinh(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(sinh(x)); @@ -420,8 +451,11 @@ struct CudaSinhFunctor : public BaseCudaActiveFunctor { }; template -struct CudaSinhGradFunctor : public BaseCudaActiveFunctor { +struct CudaSinhGradFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // dx = dout * cosh(x) + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); @@ -430,12 +464,12 @@ struct CudaSinhGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************Sinh End********************/ -/********************Tanh Begin********************/ template -struct CudaTanhFunctor : public BaseCudaActiveFunctor { +struct CudaTanhFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // tanh(x) = tanh(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(tanh(x)); @@ -443,8 +477,11 @@ struct CudaTanhFunctor : public BaseCudaActiveFunctor { }; template -struct CudaTanhGradFunctor : public BaseCudaActiveFunctor { +struct CudaTanhGradFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); + // dx = dout * (1 - out^2) + // Inputs: args[0], the input dout + // args[1], the input out __device__ __forceinline__ T operator()(const T* args) const { T dout = static_cast(args[0]); T out = static_cast(args[1]); @@ -453,31 +490,34 @@ struct CudaTanhGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -/********************Tanh End********************/ -/********************Reciprocal Begin********************/ template -struct CudaReciprocalFunctor : public BaseCudaActiveFunctor { +struct CudaReciprocalFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); + // reciprocal(x) = 1 / x + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { return one / args[0]; } }; template -struct CudaReciprocalGradFunctor : public BaseCudaActiveFunctor { +struct CudaReciprocalGradFunctor : public BaseActivationFunctor { + // dx = -dout * out^2 + // Inputs: args[0], the input dout + // args[1], the input out __device__ __forceinline__ T operator()(const T* args) const { return -args[0] * args[1] * args[1]; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -/********************Reciprocal End********************/ -/********************Exp Begin********************/ template -struct CudaExpFunctor : public BaseCudaActiveFunctor { +struct CudaExpFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // exp(x) = exp(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(exp(x)); @@ -485,19 +525,22 @@ struct CudaExpFunctor : public BaseCudaActiveFunctor { }; template -struct CudaExpGradFunctor : public BaseCudaActiveFunctor { +struct CudaExpGradFunctor : public BaseActivationFunctor { + // dx = dout * out + // Inputs: args[0], the input dout + // args[1], the input out __device__ __forceinline__ T operator()(const T* args) const { return args[0] * args[1]; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -/********************Exp End********************/ -/********************Log Begin********************/ template -struct CudaLogFunctor : public BaseCudaActiveFunctor { +struct CudaLogFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // log(x) = log(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(log(x)); @@ -505,27 +548,30 @@ struct CudaLogFunctor : public BaseCudaActiveFunctor { }; template -struct CudaLogGradFunctor : public BaseCudaActiveFunctor { +struct CudaLogGradFunctor : public BaseActivationFunctor { + // dx = dout / x + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { return args[0] / args[1]; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************Log End********************/ -/********************ELU Begin********************/ template -struct CudaELUFunctor : public BaseCudaActiveFunctor { +struct CudaELUFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; CT zero = static_cast(0.0f); CT one = static_cast(1.0f); float alpha; - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } + // elu(x) = x >= 0 ? x : alpha * (exp(x) - 1) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return x >= zero ? args[0] : T(static_cast(alpha) * (exp(x) - one)); @@ -533,16 +579,19 @@ struct CudaELUFunctor : public BaseCudaActiveFunctor { }; template -struct CudaELUGradFunctor : public BaseCudaActiveFunctor { +struct CudaELUGradFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; CT zero = static_cast(0.0f); CT one = static_cast(1.0f); float alpha; - typename BaseCudaActiveFunctor::AttrPair GetAttrs() { + typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } + // dx = x >= 0 ? dout : dout * alpha * exp(x) + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { CT dout = static_cast(args[0]); CT x = static_cast(args[1]); @@ -551,31 +600,34 @@ struct CudaELUGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************ELU End********************/ -/********************Square Begin********************/ template -struct CudaSquareFunctor : public BaseCudaActiveFunctor { +struct CudaSquareFunctor : public BaseActivationFunctor { + // square(x) = x * x + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { return args[0] * args[0]; } }; template -struct CudaSquareGradFunctor : public BaseCudaActiveFunctor { +struct CudaSquareGradFunctor : public BaseActivationFunctor { T two = static_cast(2.0f); + // dx = dout * 2 * x + // Inputs: args[0], the input dout + // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { return args[0] * two * args[1]; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -/********************Square End********************/ -/********************Sqrt Begin********************/ template -struct CudaSqrtFunctor : public BaseCudaActiveFunctor { +struct CudaSqrtFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // sqrt(x) = sqrt(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(sqrt(x)); @@ -583,20 +635,23 @@ struct CudaSqrtFunctor : public BaseCudaActiveFunctor { }; template -struct CudaSqrtGradFunctor : public BaseCudaActiveFunctor { +struct CudaSqrtGradFunctor : public BaseActivationFunctor { T one_half = static_cast(0.5f); + // dx = dout * 0.5 / out + // Inputs: args[0], the input dout + // args[1], the input out __device__ __forceinline__ T operator()(const T* args) const { return one_half * args[0] / args[1]; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -/********************Sqrt End********************/ -/********************Rsqrt Begin********************/ template -struct CudaRsqrtFunctor : public BaseCudaActiveFunctor { +struct CudaRsqrtFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; + // rsqrt(x) = rsqrt(x) + // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast(args[0]); return T(rsqrt(x)); @@ -604,8 +659,11 @@ struct CudaRsqrtFunctor : public BaseCudaActiveFunctor { }; template -struct CudaRsqrtGradFunctor : public BaseCudaActiveFunctor { +struct CudaRsqrtGradFunctor : public BaseActivationFunctor { T minus_one_half = static_cast(-0.5f); + // dx = dout * -0.5 / out^3 + // Inputs: args[0], the input dout + // args[1], the input out __device__ __forceinline__ T operator()(const T* args) const { T out = args[1]; return minus_one_half * args[0] * out * out * out; @@ -613,7 +671,6 @@ struct CudaRsqrtGradFunctor : public BaseCudaActiveFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -/********************Rsqrt End********************/ template class ActivationCudaKernel From 63e938dd701e795a07957474716aa0aae6a21b61 Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Tue, 27 Apr 2021 03:30:42 +0000 Subject: [PATCH 10/11] fix --- paddle/fluid/operators/activation_op.cu | 297 ++++++++++++------------ 1 file changed, 144 insertions(+), 153 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 3db64136425dcf..23d309357d5705 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -80,15 +80,14 @@ struct CudaLeakyReluGradFunctor : public BaseActivationFunctor { template struct CudaSigmoidFunctor : public BaseActivationFunctor { - // CT means Compute Type - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); // sigmoid(x) = 1 / (1 + exp(-x)) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(one / (one + exp(-x))); + MPType x = static_cast(args[0]); + return static_cast(one / (one + exp(-x))); } }; @@ -108,32 +107,31 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor { template struct CudaSiluFunctor : public BaseActivationFunctor { - // CT means Compute Type - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); + // MPType means Compute Type + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); // silu(x) = x / (1 + exp(-x)) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(x / (one + exp(-x))); + MPType x = static_cast(args[0]); + return static_cast(x / (one + exp(-x))); } }; template struct CudaSiluGradFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); // dx = dout * (1 + exp(-x) + x * exp(-x) / (1 + exp(-x))^2) // Inputs: args[0], the input dout // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - CT temp1 = one + exp(-x); - CT temp2 = x * exp(-x); - return T(dout * ((one / temp1) * (one + temp2 / temp1))); + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + MPType temp = one / (one + exp(-x)); + return static_cast(dout * (temp * (one + x * (one - temp)))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -141,31 +139,38 @@ struct CudaSiluGradFunctor : public BaseActivationFunctor { template struct CudaLogSigmoidFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; - CT zero = static_cast(0.0f); + using MPType = typename details::MPTypeTrait::Type; + MPType zero = static_cast(0.0f); // logsigmoid(x) = log(1 / (1 + exp(-x))) + // For numerical stability, + // logsigmoid(x) = + // - (max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - CT temp = x > zero ? zero : -x; - return T(-temp - log(exp(-temp) + exp(-x - temp))); + MPType x = static_cast(args[0]); + MPType temp = x > zero ? zero : -x; + return static_cast(-temp - log(exp(-temp) + exp(-x - temp))); } }; template struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; - CT zero = static_cast(0.0f); + using MPType = typename details::MPTypeTrait::Type; + MPType zero = static_cast(0.0f); // dx = dout * exp(-x) / (1 + exp(-x)) + // For numerical stability: + // dx = dout * exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + exp(-x - max(-x, + // 0))) // Inputs: args[0], the input dout // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - CT temp = x > zero ? zero : -x; - return T(dout * (exp(-x - temp) / (exp(-temp) + exp(-x - temp)))); + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + MPType temp1 = x > zero ? zero : -x; + MPType temp2 = exp(-x - temp); + return static_cast(dout * (temp2 / (exp(-temp1) + temp2))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -173,18 +178,20 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor { template struct CudaAtanFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // atan(x) = atan(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(atan(x)); + MPType x = static_cast(args[0]); + return static_cast(atan(x)); } }; template struct CudaAtanGradFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); + // dx = dout / (1 + x^2) // Inputs: args[0], the input dout // args[1], the input x @@ -218,6 +225,7 @@ struct CudaSoftShrinkFunctor : public BaseActivationFunctor { template struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { @@ -230,9 +238,7 @@ struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor { __device__ __forceinline__ T operator()(const T* args) const { T x = args[1]; T l = static_cast(lambda); - T temp1 = static_cast(x > l); - T temp2 = static_cast(x < -l); - return args[0] * (temp1 + temp2); + return (x >= -l && x <= l) ? zero : args[0]; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -240,38 +246,41 @@ struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor { template struct CudaCeilFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // ceil(x) = ceil(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(ceil(x)); + MPType x = static_cast(args[0]); + return static_cast(ceil(x)); } }; template struct CudaFloorFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // floor(x) = floor(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(floor(x)); + MPType x = static_cast(args[0]); + return static_cast(floor(x)); } }; template struct CudaRoundFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // round(x) = round(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(round(x)); + MPType x = static_cast(args[0]); + return static_cast(round(x)); } }; -// grad functor for ceil. floor and round +// grad functor for ceil, floor and round template struct CudaZeroGradFunctor : public BaseActivationFunctor { __device__ __forceinline__ T operator()(const T* args) const { @@ -283,25 +292,27 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor { template struct CudaCosFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // cos(x) = cos(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(cos(x)); + MPType x = static_cast(args[0]); + return static_cast(cos(x)); } }; template struct CudaCosGradFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // dx = dout * (-sin(x)) // Inputs: args[0], the input dout // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - return T(-dout * sin(x)); + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(-dout * sin(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -309,25 +320,27 @@ struct CudaCosGradFunctor : public BaseActivationFunctor { template struct CudaSinFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // sin(x) = sin(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(sin(x)); + MPType x = static_cast(args[0]); + return static_cast(sin(x)); } }; template struct CudaSinGradFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // dx = dout * cos(x) // Inputs: args[0], the input dout // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - return T(dout * cos(x)); + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(dout * cos(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -335,25 +348,27 @@ struct CudaSinGradFunctor : public BaseActivationFunctor { template struct CudaTanFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // tan(x) = tan(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(tan(x)); + MPType x = static_cast(args[0]); + return static_cast(tan(x)); } }; template struct CudaTanGradFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // dx = dout / cos(x)^2 // Inputs: args[0], the input dout // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - return T(dout / (cos(x) * cos(x))); + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(dout / (cos(x) * cos(x))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -361,26 +376,28 @@ struct CudaTanGradFunctor : public BaseActivationFunctor { template struct CudaAsinFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // asin(x) = asin(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(asin(x)); + MPType x = static_cast(args[0]); + return static_cast(asin(x)); } }; template struct CudaAsinGradFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + // dx = dout / sqrt(1 - x^2) // Inputs: args[0], the input dout // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - return T(dout / sqrt(one - x * x)); + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(dout / sqrt(one - x * x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -388,26 +405,28 @@ struct CudaAsinGradFunctor : public BaseActivationFunctor { template struct CudaAcosFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // acos(x) = acos(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(acos(x)); + MPType x = static_cast(args[0]); + return static_cast(acos(x)); } }; template struct CudaAcosGradFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; - CT one = static_cast(1.0f); + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + // dx = -dout / sqrt(1 - x^2) // Inputs: args[0], the input dout // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - return T(-dout / sqrt(one - x * x)); + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(-dout / sqrt(one - x * x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -415,25 +434,27 @@ struct CudaAcosGradFunctor : public BaseActivationFunctor { template struct CudaCoshFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // cosh(x) = cosh(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(cosh(x)); + MPType x = static_cast(args[0]); + return static_cast(cosh(x)); } }; template struct CudaCoshGradFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // dx = dout * sinh(x) // Inputs: args[0], the input dout // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - return T(dout * sinh(x)); + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(dout * sinh(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -441,25 +462,27 @@ struct CudaCoshGradFunctor : public BaseActivationFunctor { template struct CudaSinhFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // sinh(x) = sinh(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(sinh(x)); + MPType x = static_cast(args[0]); + return static_cast(sinh(x)); } }; template struct CudaSinhGradFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // dx = dout * cosh(x) // Inputs: args[0], the input dout // args[1], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - return T(dout * cosh(x)); + MPType dout = static_cast(args[0]); + MPType x = static_cast(args[1]); + return static_cast(dout * cosh(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } @@ -467,18 +490,20 @@ struct CudaSinhGradFunctor : public BaseActivationFunctor { template struct CudaTanhFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // tanh(x) = tanh(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(tanh(x)); + MPType x = static_cast(args[0]); + return static_cast(tanh(x)); } }; template struct CudaTanhGradFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); + // dx = dout * (1 - out^2) // Inputs: args[0], the input dout // args[1], the input out @@ -494,6 +519,7 @@ struct CudaTanhGradFunctor : public BaseActivationFunctor { template struct CudaReciprocalFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); + // reciprocal(x) = 1 / x // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { @@ -515,12 +541,13 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor { template struct CudaExpFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // exp(x) = exp(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(exp(x)); + MPType x = static_cast(args[0]); + return static_cast(exp(x)); } }; @@ -538,12 +565,13 @@ struct CudaExpGradFunctor : public BaseActivationFunctor { template struct CudaLogFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // log(x) = log(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(log(x)); + MPType x = static_cast(args[0]); + return static_cast(log(x)); } }; @@ -559,48 +587,6 @@ struct CudaLogGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -template -struct CudaELUFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; - CT zero = static_cast(0.0f); - CT one = static_cast(1.0f); - float alpha; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha}}; - } - - // elu(x) = x >= 0 ? x : alpha * (exp(x) - 1) - // Inputs: args[0], the input x - __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return x >= zero ? args[0] : T(static_cast(alpha) * (exp(x) - one)); - } -}; - -template -struct CudaELUGradFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; - CT zero = static_cast(0.0f); - CT one = static_cast(1.0f); - float alpha; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha}}; - } - - // dx = x >= 0 ? dout : dout * alpha * exp(x) - // Inputs: args[0], the input dout - // args[1], the input x - __device__ __forceinline__ T operator()(const T* args) const { - CT dout = static_cast(args[0]); - CT x = static_cast(args[1]); - return x >= zero ? args[0] : T(dout * static_cast(alpha) * exp(x)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - template struct CudaSquareFunctor : public BaseActivationFunctor { // square(x) = x * x @@ -613,6 +599,7 @@ struct CudaSquareFunctor : public BaseActivationFunctor { template struct CudaSquareGradFunctor : public BaseActivationFunctor { T two = static_cast(2.0f); + // dx = dout * 2 * x // Inputs: args[0], the input dout // args[1], the input x @@ -625,18 +612,20 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor { template struct CudaSqrtFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // sqrt(x) = sqrt(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(sqrt(x)); + MPType x = static_cast(args[0]); + return static_cast(sqrt(x)); } }; template struct CudaSqrtGradFunctor : public BaseActivationFunctor { T one_half = static_cast(0.5f); + // dx = dout * 0.5 / out // Inputs: args[0], the input dout // args[1], the input out @@ -649,18 +638,20 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor { template struct CudaRsqrtFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; + using MPType = typename details::MPTypeTrait::Type; + // rsqrt(x) = rsqrt(x) // Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T* args) const { - CT x = static_cast(args[0]); - return T(rsqrt(x)); + MPType x = static_cast(args[0]); + return static_cast(rsqrt(x)); } }; template struct CudaRsqrtGradFunctor : public BaseActivationFunctor { T minus_one_half = static_cast(-0.5f); + // dx = dout * -0.5 / out^3 // Inputs: args[0], the input dout // args[1], the input out @@ -791,7 +782,7 @@ REGISTER_OP_CUDA_KERNEL( /* ========================================================================== */ /* ======================== elu register ============================ */ -REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, CudaELUFunctor, CudaELUGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor); REGISTER_OP_CUDA_KERNEL( elu_grad_grad, ops::ELUDoubleGradKernel Date: Tue, 27 Apr 2021 03:32:25 +0000 Subject: [PATCH 11/11] fix --- paddle/fluid/operators/activation_op.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 23d309357d5705..836c5fa06f6dfe 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -169,7 +169,7 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor { MPType dout = static_cast(args[0]); MPType x = static_cast(args[1]); MPType temp1 = x > zero ? zero : -x; - MPType temp2 = exp(-x - temp); + MPType temp2 = exp(-x - temp1); return static_cast(dout * (temp2 / (exp(-temp1) + temp2))); }