8 changes: 8 additions & 0 deletions paddle/fluid/operators/activation_op.cc
@@ -568,6 +568,10 @@ class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
"The output is a multi-dimensional Tensor which has same "
"dimension and data type as the ``x``.");
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddComment(R"DOC(
ELU Activation Operator.

@@ -743,6 +747,10 @@ class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(6.0f);
AddAttr<float>("offset", "The offset parameter of HardSwish operator")
.SetDefault(3.0f);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddComment(R"DOC(
HardSwish Activation Operator.

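For context, ELU computes x for x > 0 and alpha * (exp(x) - 1) otherwise. A minimal NumPy sketch of this reference formula (the same expression the new tests below use to build their expected output; it is not part of the diff):

import numpy as np

def elu_ref(x, alpha=1.0):
    # Elementwise ELU: x where x > 0, alpha * (exp(x) - 1) where x <= 0;
    # written with max/min to mirror the tests' expected-output expression.
    return np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
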
25 changes: 17 additions & 8 deletions paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -209,6 +209,10 @@ template <typename T>
using AbsMKLDNNFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;

template <typename T>
using EluMKLDNNFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_elu>;

template <typename T>
using ReluMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;
@@ -240,6 +244,10 @@ using SqrtMKLDNNGradFunctor =
template <typename T>
using AbsMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;

template <typename T>
using EluMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_elu>;
} // namespace operators
} // namespace paddle

@@ -264,14 +272,15 @@ namespace ops = paddle::operators;
ops::MKLDNNActivationGradKernel< \
ops::grad_functor<paddle::platform::bfloat16>>);

#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor); \
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(hardswish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \
__macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor); \
__macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor); \
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \
__macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor); \
__macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor); \
__macro(elu, EluMKLDNNFunctor, EluMKLDNNGradFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
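A minimal sketch of how the newly registered oneDNN ELU kernel can be exercised end to end (not part of this PR; assumes a CPU build of Paddle with oneDNN support and, for kernel selection, the FLAGS_use_mkldnn=1 environment flag):

import numpy as np
import paddle

paddle.enable_static()

x = paddle.static.data(name='x', shape=[5, 5, 4], dtype='float32')
y = paddle.nn.functional.elu(x, alpha=1.0)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
out, = exe.run(feed={'x': np.random.random((5, 5, 4)).astype('float32')},
               fetch_list=[y])

Whether this dispatches to the oneDNN kernel or the reference CPU kernel depends on the build and on FLAGS_use_mkldnn; the unit tests below instead force the oneDNN path by setting the use_mkldnn attribute directly.
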
@@ -326,6 +326,29 @@ def setUp(self):
self.attrs = {"use_mkldnn": True}


class TestMKLDNNEluDefaultAlpha(TestActivation):
def setUp(self):
self.op_type = "elu"
self.set_alpha()

x = np.random.random((5, 5, 4)).astype("float32")

self.inputs = {'X': x}
self.attrs = {'use_mkldnn': True, 'alpha': self.alpha}
self.outputs = {
'Out':
np.maximum(0, x) + np.minimum(0, self.alpha * (np.exp(x) - 1))
}

def set_alpha(self):
self.alpha = 1.0


class TestMKLDNNEluCustomAlpha(TestMKLDNNEluDefaultAlpha):
def set_alpha(self):
self.alpha = 2.5


# Check if primitives already exist in backward
class TestMKLDNNAbsPrimitivesAlreadyExist(unittest.TestCase):
def setUp(self):
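Following the pattern above, further alpha values could be covered by subclassing TestMKLDNNEluDefaultAlpha and overriding only set_alpha. A hypothetical sketch (not part of the PR):

class TestMKLDNNEluSmallAlpha(TestMKLDNNEluDefaultAlpha):
    def set_alpha(self):
        # Hypothetical extra variant: any positive alpha exercises the same
        # oneDNN eltwise_elu path; setUp recomputes the expected output.
        self.alpha = 0.5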