From c9ced59d7f1f0117813b739223694643238a4fe7 Mon Sep 17 00:00:00 2001
From: "Xinyi1, Li"
Date: Mon, 15 Apr 2024 13:51:58 +0800
Subject: [PATCH 1/4] add softplus_activation_fuse_pass

---
 .../inference/api/paddle_pass_builder.cc           |   1 +
 .../onednn/softplus_activation_fuse_pass.cc        | 285 ++++++++++++++++++
 .../onednn/softplus_activation_fuse_pass.h         |  26 ++
 paddle/fluid/pir/transforms/passes.h               |   1 +
 .../test_softplus_activation_fuse_pass.py          | 210 +++++++++++++
 5 files changed, 523 insertions(+)
 create mode 100644 paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.cc
 create mode 100644 paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.h
 create mode 100644 test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py

diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index e503f1133cb7bd..6f6c0be6419a41 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -635,6 +635,7 @@ const std::vector<std::string> kPirMkldnnPasses{
       "matmul_transpose_reshape_fuse_pass",
       "matmul_elementwise_add_fuse_pass",
       "matmul_activation_fuse_pass",
+      "softplus_activation_fuse_pass",
       "conv_elementwise_add_onednn_fuse_pass",
       "conv_activation_onednn_fuse_pass",
       "conv_concat_activation_onednn_fuse_pass"};
diff --git a/paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.cc b/paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.cc
new file mode 100644
index 00000000000000..054c740ade3719
--- /dev/null
+++ b/paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.cc
@@ -0,0 +1,285 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.h" + +#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/utils/general_functions.h" + +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace { +std::set act_ops = {{paddle::dialect::AbsOp::name()}, + {paddle::dialect::GeluOp::name()}, + {paddle::dialect::HardsigmoidOp::name()}, + {paddle::dialect::HardswishOp::name()}, + {paddle::dialect::LeakyReluOp::name()}, + {paddle::dialect::MishOp::name()}, + {paddle::dialect::ReluOp::name()}, + {paddle::dialect::Relu6Op::name()}, + {paddle::dialect::SigmoidOp::name()}, + {paddle::dialect::SqrtOp::name()}, + {paddle::dialect::SwishOp::name()}, + {paddle::dialect::TanhOp::name()}}; + +std::unordered_map activation_type = { + {paddle::dialect::AbsOp::name(), "abs"}, + {paddle::dialect::GeluOp::name(), "gelu"}, + {paddle::dialect::HardsigmoidOp::name(), "hard_sigmoid"}, + {paddle::dialect::HardswishOp::name(), "hard_swish"}, + {paddle::dialect::LeakyReluOp::name(), "leaky_relu"}, + {paddle::dialect::MishOp::name(), "mish"}, + {paddle::dialect::ReluOp::name(), "relu"}, + {paddle::dialect::Relu6Op::name(), "relu6"}, + {paddle::dialect::SigmoidOp::name(), "sigmoid"}, + {paddle::dialect::SqrtOp::name(), "sqrt"}, + {paddle::dialect::SwishOp::name(), "swish"}, + {paddle::dialect::TanhOp::name(), "tanh"}}; + +class SoftplusActivationFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string softplus_name_; + std::string fused_softplus_name_; + uint32_t benefit_; + std::string act_type_; + + public: + SoftplusActivationFusePattern(const std::string &softplus_name, + const std::string &fused_softplus_name, + uint32_t benefit, + const std::string &act_type) + : softplus_name_(softplus_name), + fused_softplus_name_(fused_softplus_name), + benefit_(benefit), + act_type_(act_type) {} + + std::string name() const override { return "SoftplusActivationFusePattern"; } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &softplus = pat.Op( + softplus_name_, + {{"beta", pat.Attr("beta")}, {"threshold", pat.Attr("threshold")}}); + + std::unordered_map act_attrs; + if (act_type_ == paddle::dialect::HardsigmoidOp::name()) { + act_attrs.emplace("slope", pat.Attr("fuse_alpha")); + act_attrs.emplace("offset", pat.Attr("fuse_beta")); + } else if (act_type_ == paddle::dialect::LeakyReluOp::name()) { + act_attrs.emplace("negative_slope", pat.Attr("fuse_alpha")); + } else if (act_type_ == paddle::dialect::GeluOp::name()) { + act_attrs.emplace("approximate", pat.Attr("approximate")); + } + + const auto &act = pat.Op(act_type_, act_attrs); + softplus({&pat.Tensor("x")}, {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = act(pat.Tensor("Out")); + + if (act_type_ == paddle::dialect::GeluOp::name()) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto result_gelu = match_ctx.Attr("approximate"); + if (result_gelu) return false; + return true; + }); + } + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"beta", pat.Attr("beta")}, {"threshold", pat.Attr("threshold")}}; + + if (act_type_ == paddle::dialect::HardswishOp::name()) { + fused_attrs.emplace("fuse_alpha", 
res.Float32Attr(1.0f / 6.0f)); + fused_attrs.emplace("fuse_beta", res.Float32Attr(1.0f / 2.0f)); + } else if (act_type_ == paddle::dialect::HardsigmoidOp::name()) { + fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha")); + fused_attrs.emplace("fuse_beta", pat.Attr("fuse_beta")); + } else if (act_type_ == paddle::dialect::LeakyReluOp::name()) { + fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha")); + } else if (act_type_ == paddle::dialect::SwishOp::name()) { + fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f)); + } else if (act_type_ == paddle::dialect::Relu6Op::name()) { + fused_attrs.emplace("fuse_beta", res.Float32Attr(6.0f)); + } + + fused_attrs.insert(std::make_pair("fuse_activation", + res.StrAttr(activation_type[act_type_]))); + fused_attrs.insert(std::make_pair("fuse_alpha", res.Float32Attr(0.0f))); + fused_attrs.insert(std::make_pair("fuse_beta", res.Float32Attr(0.0f))); + + const auto &fused_softplus = res.Op(fused_softplus_name_, fused_attrs); + + fused_softplus({&res.Tensor("x")}, {&res.Tensor("act_out")}); + } +}; + +class SoftplusGeluTanhFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string softplus_name_; + std::string fused_softplus_name_; + uint32_t benefit_; + + public: + SoftplusGeluTanhFusePattern(const std::string &softplus_name, + const std::string &fused_softplus_name, + uint32_t benefit) + : softplus_name_(softplus_name), + fused_softplus_name_(fused_softplus_name), + benefit_(benefit) {} + + std::string name() const override { return "SoftplusActivationFusePattern"; } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &softplus = pat.Op( + softplus_name_, + {{"beta", pat.Attr("beta")}, {"threshold", pat.Attr("threshold")}}); + + const auto &act = pat.Op(paddle::dialect::GeluOp::name(), + {{"approximate", pat.Attr("approximate")}}); + softplus({&pat.Tensor("x")}, {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = act(pat.Tensor("Out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto result_gelu = match_ctx.Attr("approximate"); + if (!result_gelu) return false; + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"beta", pat.Attr("beta")}, + {"threshold", pat.Attr("threshold")}, + {"fuse_activation", res.StrAttr("gelu_tanh")}, + {"fuse_alpha", res.Float32Attr(0.0f)}, + {"fuse_beta", res.Float32Attr(0.0f)}}; + + const auto &fused_softplus = res.Op(fused_softplus_name_, fused_attrs); + + fused_softplus({&res.Tensor("x")}, {&res.Tensor("act_out")}); + } +}; + +class SoftplusClipFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string softplus_name_; + std::string fused_softplus_name_; + uint32_t benefit_; + + public: + SoftplusClipFusePattern(const std::string &softplus_name, + const std::string &fused_softplus_name, + uint32_t benefit) + : softplus_name_(softplus_name), + fused_softplus_name_(fused_softplus_name), + benefit_(benefit) {} + + std::string name() const override { return "SoftplusActivationFusePattern"; } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &softplus = pat.Op( + softplus_name_, + {{"beta", pat.Attr("beta")}, {"threshold", pat.Attr("threshold")}}); + + const auto &full1 = + 
pat.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("shape1")}, {"value", pat.Attr("value1")}}); + const auto &full2 = + pat.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("shape2")}, {"value", pat.Attr("value2")}}); + pat.Tensor("min") = full1(); + pat.Tensor("max") = full2(); + + const auto &act = pat.Op(paddle::dialect::ClipOp::name()); + softplus({&pat.Tensor("x")}, {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = + act(pat.Tensor("Out"), pat.Tensor("min"), pat.Tensor("max")); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"beta", pat.Attr("beta")}, + {"threshold", pat.Attr("threshold")}, + {"fuse_activation", res.StrAttr("clip")}, + {"fuse_alpha", pat.Attr("value1")}, + {"fuse_beta", pat.Attr("value2")}}; + + const auto &fused_softplus = res.Op(fused_softplus_name_, fused_attrs); + + fused_softplus({&res.Tensor("x")}, {&res.Tensor("act_out")}); + } +}; + +class SoftplusActivationFusePass : public pir::PatternRewritePass { + public: + SoftplusActivationFusePass() + : pir::PatternRewritePass("softplus_activation_fuse_pass", 3) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + int benefit_idx = 1; + // There is no pattern for "fused_softplus + activation" since currently no + // pass will output fused_softplus. We will add fused patterns when such + // pass exists. + for (auto act_op : act_ops) { + ps.Add(paddle::drr::Create( + context, + paddle::dialect::SoftplusOp::name(), + paddle::onednn::dialect::SoftplusOp::name(), + benefit_idx, + act_op)); + benefit_idx++; + } + ps.Add(paddle::drr::Create( + context, + paddle::dialect::SoftplusOp::name(), + paddle::onednn::dialect::SoftplusOp::name(), + benefit_idx++)); + ps.Add(paddle::drr::Create( + context, + paddle::dialect::SoftplusOp::name(), + paddle::onednn::dialect::SoftplusOp::name(), + benefit_idx++)); + return ps; + } +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateSoftplusActivationFusePass() { + // pd_op.softplus + pd_op.relu(act) -> onednn_op.softplus + return std::make_unique(); +} +} // namespace pir + +REGISTER_IR_PASS(softplus_activation_fuse_pass, SoftplusActivationFusePass); diff --git a/paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.h b/paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.h new file mode 100644 index 00000000000000..c56cfc5f22579e --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include <memory>
+#include "paddle/pir/include/core/dll_decl.h"
+
+namespace pir {
+
+class Pass;
+
+IR_API std::unique_ptr<Pass> CreateSoftplusActivationFusePass();
+
+}  // namespace pir
diff --git a/paddle/fluid/pir/transforms/passes.h b/paddle/fluid/pir/transforms/passes.h
index bc15794c45ec6b..c18165de293180 100644
--- a/paddle/fluid/pir/transforms/passes.h
+++ b/paddle/fluid/pir/transforms/passes.h
@@ -53,6 +53,7 @@ USE_PIR_PASS(reshape_transpose_matmul_fuse_pass);
 USE_PIR_PASS(matmul_transpose_reshape_fuse_pass);
 USE_PIR_PASS(matmul_elementwise_add_fuse_pass);
 USE_PIR_PASS(matmul_activation_fuse_pass);
+USE_PIR_PASS(softplus_activation_fuse_pass);
 USE_PIR_PASS(conv_elementwise_add_onednn_fuse_pass);
 USE_PIR_PASS(conv_activation_onednn_fuse_pass);
 USE_PIR_PASS(conv_concat_activation_onednn_fuse_pass);
diff --git a/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py b/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py
new file mode 100644
index 00000000000000..e0b1d3ecd6f792
--- /dev/null
+++ b/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py
@@ -0,0 +1,210 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from pass_test import PassTest
+
+import paddle
+
+paddle.enable_static()
+
+activation_type = [
+    "abs",
+    "gelu",
+    "hard_sigmoid",
+    "hard_swish",
+    "leaky_relu",
+    "mish",
+    "relu",
+    "relu6",
+    "sigmoid",
+    "sqrt",
+    "swish",
+    "tanh",
+]
+
+
+class TestSoftplusActivationFusePattern(PassTest):
+    r"""
+    x
+    |
+    softplus
+    |
+    act
+    |
+    out
+    """
+
+    def is_program_valid(self, program=None):
+        return True
+
+    def sample_program(self):
+        x_shape = [3, 2]
+        for act_op in activation_type:
+            with paddle.pir_utils.IrGuard():
+                start_prog = paddle.static.Program()
+                main_prog = paddle.static.Program()
+                with paddle.pir.core.program_guard(main_prog, start_prog):
+                    x = paddle.static.data(
+                        name='x', shape=x_shape, dtype='float32'
+                    )
+
+                    softplus_out = paddle.nn.functional.softplus(x)
+
+                    if act_op == "abs":
+                        out = paddle.abs(softplus_out)
+                    elif act_op == "gelu":
+                        out = paddle.nn.functional.gelu(softplus_out)
+                    elif act_op == "hard_sigmoid":
+                        out = paddle.nn.functional.hardsigmoid(softplus_out)
+                    elif act_op == "hard_swish":
+                        out = paddle.nn.functional.hardswish(softplus_out)
+                    elif act_op == "leaky_relu":
+                        out = paddle.nn.functional.leaky_relu(softplus_out)
+                    elif act_op == "mish":
+                        out = paddle.nn.functional.mish(softplus_out)
+                    elif act_op == "relu":
+                        out = paddle.nn.functional.relu(softplus_out)
+                    elif act_op == "relu6":
+                        out = paddle.nn.functional.relu6(softplus_out)
+                    elif act_op == "sigmoid":
+                        out = paddle.nn.functional.sigmoid(softplus_out)
+                    elif act_op == "sqrt":
+                        out = paddle.sqrt(softplus_out)
+                    elif act_op == "swish":
+                        out = paddle.nn.functional.swish(softplus_out)
+                    elif act_op == "tanh":
+                        out = paddle.nn.functional.tanh(softplus_out)
+
+                    out = paddle.assign(out)
+                    self.pass_list = ["softplus_activation_fuse_pass"]
+                    self.feeds = {
+                        "x": np.random.random(x_shape).astype("float32")
+                    }
+                    self.fetch_list = [out]
+                    self.valid_op_map = {
+                        "onednn_op.softplus": 1,
+                        "pd_op.matmul": 0,
+                        "pd_op.add": 0,
+                        "pd_op.abs": 0,
+                        "pd_op.gelu": 0,
+                        "pd_op.hard_sigmoid": 0,
+                        "pd_op.hard_swish": 0,
+                        "pd_op.leaky_relu": 0,
+                        "pd_op.mish": 0,
+                        "pd_op.relu": 0,
+                        "pd_op.relu6": 0,
+                        "pd_op.sigmoid": 0,
+                        "pd_op.sqrt": 0,
+                        "pd_op.swish": 0,
+                        "pd_op.tanh": 0,
+                    }
+
+                    yield [main_prog, start_prog], False
+
+    def setUp(self):
+        self.places.append(paddle.CPUPlace())
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+class TestSoftplusGeluTanhFusePattern(PassTest):
+    r"""
+    x
+    |
+    softplus
+    |
+    gelu_tanh
+    |
+    out
+    """
+
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        with paddle.pir_utils.IrGuard():
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.pir.core.program_guard(main_prog, start_prog):
+                x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
+                softplus_out = paddle.nn.functional.softplus(x)
+                out = paddle.nn.functional.gelu(softplus_out, approximate=True)
+                out = paddle.assign(out)
+                self.pass_list = ['softplus_activation_fuse_pass']
+                self.feeds = {"x": np.random.random((3, 2)).astype("float32")}
+                self.fetch_list = [out]
+                self.valid_op_map = {
+                    "onednn_op.softplus": 1,
+                    "pd_op.gelu": 0,
+                }
+                return [main_prog, start_prog]
+
+    def sample_program(self):
+        yield self.build_ir_program(), False
+
+    def setUp(self):
+        self.places.append(paddle.CPUPlace())
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+class TestSoftplusClipFusePattern(PassTest):
+    r"""
+    x
+    |
+    softplus
+    |
+    clip
+    |
+    out
+    """
+
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        with paddle.pir_utils.IrGuard():
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.pir.core.program_guard(main_prog, start_prog):
+                x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
+                softplus_out = paddle.nn.functional.softplus(x)
+                out = paddle.clip(softplus_out)
+                out = paddle.assign(out)
+                self.pass_list = ['softplus_activation_fuse_pass']
+                self.feeds = {"x": np.random.random((3, 2)).astype("float32")}
+                self.fetch_list = [out]
+                self.valid_op_map = {
+                    "onednn_op.softplus": 1,
+                    "pd_op.clip": 0,
+                }
+                return [main_prog, start_prog]
+
+    def sample_program(self):
+        yield self.build_ir_program(), False
+
+    def setUp(self):
+        self.places.append(paddle.CPUPlace())
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+if __name__ == "__main__":
+    unittest.main()

From 71e575919a45d62c04680847c32db80bce4fa233 Mon Sep 17 00:00:00 2001
From: "Xinyi1, Li"
Date: Thu, 18 Apr 2024 08:38:56 +0800
Subject: [PATCH 2/4] change test case

---
 .../fused_pass/onednn/test_softplus_activation_fuse_pass.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py b/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py
index e0b1d3ecd6f792..575f06815f8538 100644
--- a/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py
+++ b/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py
@@ -90,7 +90,7 @@ def sample_program(self):
                         out = paddle.nn.functional.tanh(softplus_out)
 
                     out = paddle.assign(out)
-                    self.pass_list = ["softplus_activation_fuse_pass"]
+                    self.pass_list = [{"softplus_activation_fuse_pass": {}}]
                     self.feeds = {
                         "x": np.random.random(x_shape).astype("float32")
                     }
@@ -145,7 +145,7 @@ def build_ir_program(self):
                 softplus_out = paddle.nn.functional.softplus(x)
                 out = paddle.nn.functional.gelu(softplus_out, approximate=True)
                 out = paddle.assign(out)
-                self.pass_list = ['softplus_activation_fuse_pass']
+                self.pass_list = [{"softplus_activation_fuse_pass": {}}]
                 self.feeds = {"x": np.random.random((3, 2)).astype("float32")}
                 self.fetch_list = [out]
                 self.valid_op_map = {
@@ -187,7 +187,7 @@ def build_ir_program(self):
                 softplus_out = paddle.nn.functional.softplus(x)
                 out = paddle.clip(softplus_out)
                 out = paddle.assign(out)
-                self.pass_list = ['softplus_activation_fuse_pass']
+                self.pass_list = [{"softplus_activation_fuse_pass": {}}]
                 self.feeds = {"x": np.random.random((3, 2)).astype("float32")}
                 self.fetch_list = [out]
                 self.valid_op_map = {

From 5616f5756801e14e8c90c45146f081c96c26cd8a Mon Sep 17 00:00:00 2001
From: "Xinyi1, Li"
Date: Thu, 18 Apr 2024 13:34:51 +0800
Subject: [PATCH 3/4] modify test case

---
 .../fused_pass/onednn/test_softplus_activation_fuse_pass.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py b/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py
index 575f06815f8538..7b0bb8b8a2e5a8 100644
--- a/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py
+++ b/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py
@@ -90,7 +90,9 @@ def sample_program(self):
                         out = paddle.nn.functional.tanh(softplus_out)
 
                     out = paddle.assign(out)
-                    self.pass_list = [{"softplus_activation_fuse_pass": {}}]
+                    self.pass_attr_list = [
+                        {"softplus_activation_fuse_pass": {}}
+                    ]
                     self.feeds = {
                         "x": np.random.random(x_shape).astype("float32")
                     }

From e720768d1e43acdb36a2978875415d62c311c1a4 Mon Sep 17 00:00:00 2001
From: "Xinyi1, Li"
Date: Thu, 18 Apr 2024 14:21:04 +0800
Subject: [PATCH 4/4] fix code

---
 .../onednn/softplus_activation_fuse_pass.cc        |  6 +++---
 .../onednn/test_softplus_activation_fuse_pass.py   | 13 ++++++++-----
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.cc b/paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.cc
index 054c740ade3719..f059115aea867d 100644
--- a/paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.cc
+++ b/paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.cc
@@ -253,7 +253,7 @@ class SoftplusActivationFusePass : public pir::PatternRewritePass {
       ps.Add(paddle::drr::Create<SoftplusActivationFusePattern>(
           context,
          paddle::dialect::SoftplusOp::name(),
-          paddle::onednn::dialect::SoftplusOp::name(),
+          paddle::onednn::dialect::FusedSoftplusOp::name(),
          benefit_idx,
          act_op));
       benefit_idx++;
@@ -261,12 +261,12 @@ class SoftplusActivationFusePass : public pir::PatternRewritePass {
     ps.Add(paddle::drr::Create<SoftplusGeluTanhFusePattern>(
        context,
        paddle::dialect::SoftplusOp::name(),
-        paddle::onednn::dialect::SoftplusOp::name(),
+        paddle::onednn::dialect::FusedSoftplusOp::name(),
        benefit_idx++));
     ps.Add(paddle::drr::Create<SoftplusClipFusePattern>(
        context,
        paddle::dialect::SoftplusOp::name(),
-        paddle::onednn::dialect::SoftplusOp::name(),
+        paddle::onednn::dialect::FusedSoftplusOp::name(),
        benefit_idx++));
     return ps;
   }
diff --git a/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py b/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py
index 7b0bb8b8a2e5a8..e4d67dd2395cb1 100644
--- a/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py
+++ b/test/ir/pir/fused_pass/onednn/test_softplus_activation_fuse_pass.py
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import sys
 import unittest
 
 import numpy as np
+
+sys.path.append("../")
 from pass_test import PassTest
 
 import paddle
@@ -98,7 +101,7 @@ def sample_program(self):
                    }
                    self.fetch_list = [out]
                    self.valid_op_map = {
-                        "onednn_op.softplus": 1,
+                        "onednn_op.fused_softplus": 1,
                        "pd_op.matmul": 0,
                        "pd_op.add": 0,
                        "pd_op.abs": 0,
@@ -147,11 +150,11 @@ def build_ir_program(self):
                softplus_out = paddle.nn.functional.softplus(x)
                out = paddle.nn.functional.gelu(softplus_out, approximate=True)
                out = paddle.assign(out)
-                self.pass_list = [{"softplus_activation_fuse_pass": {}}]
+                self.pass_attr_list = [{'softplus_activation_fuse_pass': {}}]
                self.feeds = {"x": np.random.random((3, 2)).astype("float32")}
                self.fetch_list = [out]
                self.valid_op_map = {
-                    "onednn_op.softplus": 1,
+                    "onednn_op.fused_softplus": 1,
                    "pd_op.gelu": 0,
                }
                return [main_prog, start_prog]
@@ -189,11 +192,11 @@ def build_ir_program(self):
                softplus_out = paddle.nn.functional.softplus(x)
                out = paddle.clip(softplus_out)
                out = paddle.assign(out)
-                self.pass_list = [{"softplus_activation_fuse_pass": {}}]
+                self.pass_attr_list = [{'softplus_activation_fuse_pass': {}}]
                self.feeds = {"x": np.random.random((3, 2)).astype("float32")}
                self.fetch_list = [out]
                self.valid_op_map = {
-                    "onednn_op.softplus": 1,
+                    "onednn_op.fused_softplus": 1,
                    "pd_op.clip": 0,
                }
                return [main_prog, start_prog]