diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 9b1b508bc9e06b..4481b6eb0ba1a0 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -618,6 +618,8 @@ const std::vector<std::string> kPirMkldnnPasses{
     "conv2d_transpose_bias_fuse_pass",
     "conv3d_bias_fuse_pass",
     "batch_norm_act_fuse_pass",
+    "matmul_elementwise_add_fuse_pass",
+    "matmul_activation_fuse_pass",
     "conv_elementwise_add_mkldnn_fuse_pass"};
 
 const std::vector<std::string> kPirCpuPasses{};
diff --git a/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.cc b/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.cc
new file mode 100644
index 00000000000000..1db28281578d49
--- /dev/null
+++ b/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.cc
@@ -0,0 +1,704 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.h"
+
+#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
+#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
+
+#include "paddle/pir/include/pass/pass.h"
+#include "paddle/pir/include/pass/pass_registry.h"
+
+namespace {
+std::set<std::string> act_ops = {{paddle::dialect::AbsOp::name()},
+                                 {paddle::dialect::GeluOp::name()},
+                                 {paddle::dialect::HardsigmoidOp::name()},
+                                 {paddle::dialect::HardswishOp::name()},
+                                 {paddle::dialect::LeakyReluOp::name()},
+                                 {paddle::dialect::MishOp::name()},
+                                 {paddle::dialect::ReluOp::name()},
+                                 {paddle::dialect::Relu6Op::name()},
+                                 {paddle::dialect::SigmoidOp::name()},
+                                 {paddle::dialect::SqrtOp::name()},
+                                 {paddle::dialect::SwishOp::name()},
+                                 {paddle::dialect::TanhOp::name()}};
+
+std::unordered_map<std::string, std::string> activation_type = {
+    {paddle::dialect::AbsOp::name(), "abs"},
+    {paddle::dialect::GeluOp::name(), "gelu"},
+    {paddle::dialect::HardsigmoidOp::name(), "hard_sigmoid"},
+    {paddle::dialect::HardswishOp::name(), "hard_swish"},
+    {paddle::dialect::LeakyReluOp::name(), "leaky_relu"},
+    {paddle::dialect::MishOp::name(), "mish"},
+    {paddle::dialect::ReluOp::name(), "relu"},
+    {paddle::dialect::Relu6Op::name(), "relu6"},
+    {paddle::dialect::SigmoidOp::name(), "sigmoid"},
+    {paddle::dialect::SqrtOp::name(), "sqrt"},
+    {paddle::dialect::SwishOp::name(), "swish"},
+    {paddle::dialect::TanhOp::name(), "tanh"}};
+
+class MatmulActivationFusePattern : public paddle::drr::DrrPatternBase {
+ private:
+  std::string matmul_name_;
+  std::string fused_matmul_name_;
+  uint32_t benefit_;
+  std::string act_type_;
+
+ public:
+  MatmulActivationFusePattern(const std::string &matmul_name,
+                              const std::string &fused_matmul_name,
+                              uint32_t benefit,
+                              const std::string &act_type)
+      : matmul_name_(matmul_name),
+        fused_matmul_name_(fused_matmul_name),
+        benefit_(benefit),
+        act_type_(act_type) {}
+
+  std::string name() const override { return "MatmulActivationFusePattern"; }
+
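+  // Patterns with a larger benefit value are applied before lower-benefit ones
+  // when more than one activation pattern matches the same matmul.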
uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = pat.Op(matmul_name_, + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + + std::unordered_map act_attrs; + if (act_type_ == paddle::dialect::HardsigmoidOp::name()) { + act_attrs.emplace("slope", pat.Attr("fuse_alpha")); + act_attrs.emplace("offset", pat.Attr("fuse_beta")); + } else if (act_type_ == paddle::dialect::LeakyReluOp::name()) { + act_attrs.emplace("negative_slope", pat.Attr("fuse_alpha")); + } else if (act_type_ == paddle::dialect::GeluOp::name()) { + act_attrs.emplace("approximate", pat.Attr("approximate")); + } + + const auto &act = pat.Op(act_type_, act_attrs); + matmul({&pat.Tensor("X"), &pat.Tensor("Y")}, {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = act(pat.Tensor("Out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0) { + return false; + } + return true; + }); + + if (act_type_ == paddle::dialect::GeluOp::name()) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto result_gelu = match_ctx.Attr("approximate"); + if (result_gelu) return false; + return true; + }); + } + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", res.Float32Attr(1.0f)}, + {"fused_output_scale", res.Float32Attr(1.0f)}, + {"fused_reshape_x", res.VectorInt32Attr({})}, + {"fused_transpose_x", res.VectorInt32Attr({})}, + {"fused_reshape_y", res.VectorInt32Attr({})}, + {"fused_transpose_y", res.VectorInt32Attr({})}, + {"fused_reshape_out", res.VectorInt32Attr({})}, + {"fused_transpose_out", res.VectorInt32Attr({})}, + {"mkldnn_data_type", res.StrAttr("float32")}, + {"scale_x", res.Float32Attr(1.0f)}, + {"scale_y", res.Float32Attr(1.0f)}, + {"scale_in_eltwise", res.Float32Attr(0.0f)}, + {"scale_out", res.Float32Attr(1.0f)}, + {"force_fp32_output", res.BoolAttr(false)}}; + + if (act_type_ == paddle::dialect::HardswishOp::name()) { + fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f / 6.0f)); + fused_attrs.emplace("fuse_beta", res.Float32Attr(1.0f / 2.0f)); + } else if (act_type_ == paddle::dialect::HardsigmoidOp::name()) { + fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha")); + fused_attrs.emplace("fuse_beta", pat.Attr("fuse_beta")); + } else if (act_type_ == paddle::dialect::LeakyReluOp::name()) { + fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha")); + } else if (act_type_ == paddle::dialect::SwishOp::name()) { + fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f)); + } else if (act_type_ == paddle::dialect::Relu6Op::name()) { + fused_attrs.emplace("fuse_beta", res.Float32Attr(6.0f)); + } + + fused_attrs.insert(std::make_pair("fuse_activation", + res.StrAttr(activation_type[act_type_]))); + fused_attrs.insert(std::make_pair("fuse_alpha", res.Float32Attr(0.0f))); + fused_attrs.insert(std::make_pair("fuse_beta", res.Float32Attr(0.0f))); + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.InputNoneTensor()}, + {&res.Tensor("act_out")}); + } +}; + +class 
MatmulGeluTanhFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + + public: + MatmulGeluTanhFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit) {} + + std::string name() const override { return "MatmulActivationFusePattern"; } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = pat.Op(matmul_name_, + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + + const auto &act = pat.Op(paddle::dialect::GeluOp::name(), + {{"approximate", pat.Attr("approximate")}}); + matmul({&pat.Tensor("X"), &pat.Tensor("Y")}, {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = act(pat.Tensor("Out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0) { + return false; + } + return true; + }); + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto result_gelu = match_ctx.Attr("approximate"); + if (!result_gelu) return false; + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", res.Float32Attr(1.0f)}, + {"fuse_activation", res.StrAttr("gelu_tanh")}, + {"fuse_alpha", res.Float32Attr(0.0f)}, + {"fuse_beta", res.Float32Attr(0.0f)}, + {"fused_output_scale", res.Float32Attr(1.0f)}, + {"fused_reshape_x", res.VectorInt32Attr({})}, + {"fused_transpose_x", res.VectorInt32Attr({})}, + {"fused_reshape_y", res.VectorInt32Attr({})}, + {"fused_transpose_y", res.VectorInt32Attr({})}, + {"fused_reshape_out", res.VectorInt32Attr({})}, + {"fused_transpose_out", res.VectorInt32Attr({})}, + {"mkldnn_data_type", res.StrAttr("float32")}, + {"scale_x", res.Float32Attr(1.0f)}, + {"scale_y", res.Float32Attr(1.0f)}, + {"scale_in_eltwise", res.Float32Attr(0.0f)}, + {"scale_out", res.Float32Attr(1.0f)}, + {"force_fp32_output", res.BoolAttr(false)}}; + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.InputNoneTensor()}, + {&res.Tensor("act_out")}); + } +}; + +class MatmulClipFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + + public: + MatmulClipFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit) {} + + std::string name() const override { return "MatmulActivationFusePattern"; } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = pat.Op(matmul_name_, + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + + const auto &full1 = + pat.Op(paddle::dialect::FullOp::name(), + {{"shape", 
pat.Attr("shape1")}, {"value", pat.Attr("value1")}}); + const auto &full2 = + pat.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("shape2")}, {"value", pat.Attr("value2")}}); + pat.Tensor("min") = full1(); + pat.Tensor("max") = full2(); + + const auto &act = pat.Op(paddle::dialect::ClipOp::name()); + matmul({&pat.Tensor("X"), &pat.Tensor("Y")}, {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = + act(pat.Tensor("Out"), pat.Tensor("min"), pat.Tensor("max")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0) { + return false; + } + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", res.Float32Attr(1.0f)}, + {"fuse_activation", res.StrAttr("clip")}, + {"fuse_alpha", pat.Attr("value1")}, + {"fuse_beta", pat.Attr("value2")}, + {"fused_output_scale", res.Float32Attr(1.0f)}, + {"fused_reshape_x", res.VectorInt32Attr({})}, + {"fused_transpose_x", res.VectorInt32Attr({})}, + {"fused_reshape_y", res.VectorInt32Attr({})}, + {"fused_transpose_y", res.VectorInt32Attr({})}, + {"fused_reshape_out", res.VectorInt32Attr({})}, + {"fused_transpose_out", res.VectorInt32Attr({})}, + {"mkldnn_data_type", res.StrAttr("float32")}, + {"scale_x", res.Float32Attr(1.0f)}, + {"scale_y", res.Float32Attr(1.0f)}, + {"scale_in_eltwise", res.Float32Attr(0.0f)}, + {"scale_out", res.Float32Attr(1.0f)}, + {"force_fp32_output", res.BoolAttr(false)}}; + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.InputNoneTensor()}, + {&res.Tensor("act_out")}); + } +}; + +class FusedMatmulActivationFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + std::string act_type_; + + public: + FusedMatmulActivationFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit, + const std::string &act_type) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit), + act_type_(act_type) {} + + std::string name() const override { + return "FusedMatmulActivationFusePattern"; + } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = + pat.Op(matmul_name_, + {{"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + 
{"force_fp32_output", pat.Attr("force_fp32_output")}}); + + std::unordered_map act_attrs; + if (act_type_ == paddle::dialect::HardsigmoidOp::name()) { + act_attrs.emplace("slope", pat.Attr("fuse_alpha")); + act_attrs.emplace("offset", pat.Attr("fuse_beta")); + } else if (act_type_ == paddle::dialect::LeakyReluOp::name()) { + act_attrs.emplace("negative_slope", pat.Attr("fuse_alpha")); + } else if (act_type_ == paddle::dialect::GeluOp::name()) { + act_attrs.emplace("approximate", pat.Attr("approximate")); + } + + const auto &act = pat.Op(act_type_, act_attrs); + matmul({&pat.Tensor("X"), &pat.Tensor("Y"), &pat.Tensor("residual")}, + {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = act(pat.Tensor("Out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + auto act_type = match_ctx.Attr("fuse_activation"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0 || + act_type != "") { + return false; + } + return true; + }); + if (act_type_ == paddle::dialect::GeluOp::name()) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto result_gelu = match_ctx.Attr("approximate"); + if (result_gelu) return false; + return true; + }); + } + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}; + + if (act_type_ == paddle::dialect::HardswishOp::name()) { + fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f / 6.0f)); + fused_attrs.emplace("fuse_beta", res.Float32Attr(1.0f / 2.0f)); + } else if (act_type_ == paddle::dialect::HardsigmoidOp::name()) { + fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha")); + fused_attrs.emplace("fuse_beta", pat.Attr("fuse_beta")); + } else if (act_type_ == paddle::dialect::LeakyReluOp::name()) { + fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha")); + } else if (act_type_ == paddle::dialect::SwishOp::name()) { + fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f)); + } else if (act_type_ == paddle::dialect::Relu6Op::name()) { + fused_attrs.emplace("fuse_beta", res.Float32Attr(6.0f)); + } + + fused_attrs.insert(std::make_pair("fuse_activation", + res.StrAttr(activation_type[act_type_]))); + fused_attrs.insert(std::make_pair("fuse_alpha", res.Float32Attr(0.0f))); + fused_attrs.insert(std::make_pair("fuse_beta", res.Float32Attr(0.0f))); + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.Tensor("residual")}, + {&res.Tensor("act_out")}); + } +}; + +class FusedMatmulGeluTanhFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; 
+ std::string fused_matmul_name_; + uint32_t benefit_; + + public: + FusedMatmulGeluTanhFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit) {} + + std::string name() const override { + return "FusedMatmulActivationFusePattern"; + } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = + pat.Op(matmul_name_, + {{"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}); + + const auto &act = pat.Op(paddle::dialect::GeluOp::name(), + {{"approximate", pat.Attr("approximate")}}); + matmul({&pat.Tensor("X"), &pat.Tensor("Y"), &pat.Tensor("residual")}, + {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = act(pat.Tensor("Out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + auto act_type = match_ctx.Attr("fuse_activation"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0 || + act_type != "") { + return false; + } + return true; + }); + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto result_gelu = match_ctx.Attr("approximate"); + if (!result_gelu) return false; + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", res.StrAttr("gelu_tanh")}, + {"fuse_alpha", res.Float32Attr(0.0f)}, + {"fuse_beta", res.Float32Attr(0.0f)}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}; + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.Tensor("residual")}, + {&res.Tensor("act_out")}); + } +}; + +class FusedMatmulClipFusePattern : public 
paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + std::string act_type_; + + public: + FusedMatmulClipFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit) {} + + std::string name() const override { + return "FusedMatmulActivationFusePattern"; + } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = + pat.Op(matmul_name_, + {{"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}); + + const auto &full1 = + pat.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("shape1")}, {"value", pat.Attr("value1")}}); + const auto &full2 = + pat.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("shape2")}, {"value", pat.Attr("value2")}}); + pat.Tensor("min") = full1(); + pat.Tensor("max") = full2(); + + const auto &act = pat.Op(paddle::dialect::ClipOp::name()); + matmul({&pat.Tensor("X"), &pat.Tensor("Y"), &pat.Tensor("residual")}, + {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = + act(pat.Tensor("Out"), pat.Tensor("min"), pat.Tensor("max")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + auto act_type = match_ctx.Attr("fuse_activation"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0 || + act_type != "") { + return false; + } + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", res.StrAttr("clip")}, + {"fuse_alpha", pat.Attr("value1")}, + {"fuse_beta", pat.Attr("value2")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}; + + const auto 
&fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.Tensor("residual")}, + {&res.Tensor("act_out")}); + } +}; + +class MatmulActivationFusePass : public pir::PatternRewritePass { + public: + MatmulActivationFusePass() + : pir::PatternRewritePass("matmul_activation_fuse_pass", 3) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + // std::vector bool_set = {false, true}; + int benefit_idx = 1; + for (auto act_op : act_ops) { + ps.Add(paddle::drr::Create( + context, + paddle::dialect::MatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx, + act_op)); + benefit_idx++; + } + ps.Add(paddle::drr::Create( + context, + paddle::dialect::MatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx++)); + ps.Add(paddle::drr::Create( + context, + paddle::dialect::MatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx++)); + for (auto act_op : act_ops) { + ps.Add(paddle::drr::Create( + context, + paddle::onednn::dialect::FusedMatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx, + act_op)); + benefit_idx++; + } + ps.Add(paddle::drr::Create( + context, + paddle::onednn::dialect::FusedMatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx++)); + ps.Add(paddle::drr::Create( + context, + paddle::onednn::dialect::FusedMatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx++)); + return ps; + } +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateMatmulActivationFusePass() { + // pd_op.matmul + pd_op.relu -> onednn_op.fused_matmul + // pd_op.matmul + pd_op.add + pd_op.relu(act) -> onednn_op.fused_matmul + + // pd_op.relu(act) -> onednn_op.fused_matmul + return std::make_unique(); +} +} // namespace pir + +REGISTER_IR_PASS(matmul_activation_fuse_pass, MatmulActivationFusePass); diff --git a/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.h b/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.h new file mode 100644 index 00000000000000..87de94566ce910 --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/pir/include/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateMatmulActivationFusePass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/passes.h b/paddle/fluid/pir/transforms/passes.h index f267a2f2125644..d411110f2e16ee 100644 --- a/paddle/fluid/pir/transforms/passes.h +++ b/paddle/fluid/pir/transforms/passes.h @@ -44,5 +44,6 @@ USE_PIR_PASS(conv2d_bias_fuse_pass); USE_PIR_PASS(conv2d_transpose_bias_fuse_pass); USE_PIR_PASS(conv3d_bias_fuse_pass); USE_PIR_PASS(matmul_elementwise_add_fuse_pass); +USE_PIR_PASS(matmul_activation_fuse_pass); USE_PIR_PASS(conv_elementwise_add_mkldnn_fuse_pass); #endif diff --git a/test/ir/pir/fused_pass/onednn/test_matmul_activation_fuse_pass.py b/test/ir/pir/fused_pass/onednn/test_matmul_activation_fuse_pass.py new file mode 100644 index 00000000000000..ff619c8bd131aa --- /dev/null +++ b/test/ir/pir/fused_pass/onednn/test_matmul_activation_fuse_pass.py @@ -0,0 +1,994 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from pass_test import PassTest + +import paddle + +paddle.enable_static() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulActFusePatternCase1(PassTest): + r''' + x y + \ / + matmul + | + relu + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.nn.functional.relu(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.relu": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulAddFusePatternCase2(PassTest): + r''' + x y + \ / + matmul + | + swish + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 
5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.nn.functional.swish(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.swish": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulAddFusePatternCase3(PassTest): + r''' + x y + \ / + matmul + | + tanh + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.abs(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.abs": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulClipFusePatternCase4(PassTest): + r''' + x y + \ / + matmul + | + clip + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.clip(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.clip": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulAddFusePatternCase5(PassTest): + r''' + x y + \ / + matmul + | + gelu + | + out + ''' + + def is_program_valid(self, program=None): + return 
True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.nn.functional.gelu(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.gelu": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulAddFusePatternCase6(PassTest): + r''' + x y + \ / + matmul + | + hardsigmoid + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.nn.functional.hardsigmoid(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.hardsigmoid": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulAddFusePatternCase7(PassTest): + r''' + x y + \ / + matmul + | + hardswish + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.nn.functional.hardswish(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.hardswish": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def 
test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulAddFusePatternCase8(PassTest): + r''' + x y + \ / + matmul + | + leaky_relu + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.nn.functional.leaky_relu(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.leaky_relu": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulAddFusePatternCase9(PassTest): + r''' + x y + \ / + matmul + | + mish + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.nn.functional.mish(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.mish": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulAddFusePatternCase10(PassTest): + r''' + x y + \ / + matmul + | + relu6 + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.nn.functional.relu6(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 
5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.relu6": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulAddFusePatternCase11(PassTest): + r''' + x y + \ / + matmul + | + sigmoid + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.nn.functional.sigmoid(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.sigmoid": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulAddFusePatternCase12(PassTest): + r''' + x y + \ / + matmul + | + sqrt + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.sqrt(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.sqrt": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulAddFusePatternCase13(PassTest): + r''' + x y + \ / + matmul + | + tanh + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], 
dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.nn.functional.tanh(matmul_out) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.tanh": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestFusedMatmulActFusePattern(PassTest): + r''' + x y + \ / + matmul resdual(data) + \ / + add + | + relu + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + bias = paddle.static.data( + name="bias", shape=[1], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.add(matmul_out, bias) + act_out = paddle.nn.functional.relu(out) + act_out = paddle.assign(act_out) + self.pass_list = [ + 'matmul_elementwise_add_fuse_pass', + 'matmul_activation_fuse_pass', + ] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + "bias": np.random.random(1).astype("float32"), + } + self.fetch_list = [act_out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.add": 0, + "pd_op.relu": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestFusedMatmulClipFusePattern(PassTest): + r''' + x y + \ / + matmul resdual(data) + \ / + add + | + clip + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + bias = paddle.static.data( + name="bias", shape=[1], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.add(matmul_out, bias) + act_out = paddle.clip(out) + act_out = paddle.assign(act_out) + self.pass_list = [ + 'matmul_elementwise_add_fuse_pass', + 'matmul_activation_fuse_pass', + ] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + "bias": np.random.random(1).astype("float32"), + } + self.fetch_list = [act_out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.add": 0, + "pd_op.clip": 0, + } + return [main_prog, start_prog] + + def 
sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestFusedMatmulsigmoidFusePattern(PassTest): + r''' + x y + \ / + matmul resdual(data) + \ / + add + | + hardsigmoid + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + bias = paddle.static.data( + name="bias", shape=[1], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.add(matmul_out, bias) + act_out = paddle.nn.functional.hardsigmoid(out) + act_out = paddle.assign(act_out) + self.pass_list = [ + 'matmul_elementwise_add_fuse_pass', + 'matmul_activation_fuse_pass', + ] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + "bias": np.random.random(1).astype("float32"), + } + self.fetch_list = [act_out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.add": 0, + "pd_op.hardsigmoid": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +@unittest.skipIf( + not paddle.base.core.is_compiled_with_mkldnn(), + "Test case only for OneDNN pass.", +) +class TestMatmulGeluTanhFusePatternCase14(PassTest): + r''' + x y + \ / + matmul + | + gelu + | + out + ''' + + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + matmul_out = paddle.matmul(x, y) + out = paddle.nn.functional.gelu(matmul_out, approximate=True) + out = paddle.assign(out) + self.pass_list = ['matmul_activation_fuse_pass'] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.fused_matmul": 1, + "pd_op.matmul": 0, + "pd_op.gelu": 0, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +if __name__ == "__main__": + unittest.main()
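For completeness, here is a minimal C++ sketch of driving the new pass directly through a PIR pass manager, in case someone wants to exercise it outside the kPirMkldnnPasses inference list or the PassTest harness above. The PassManager usage and include paths mirror how other PIR passes are run in Paddle's C++ tests; treat them as illustrative rather than part of this change.

#include <memory>

#include "paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/pass/pass_manager.h"

// Runs only matmul_activation_fuse_pass on a PIR program; plain
// pd_op.matmul + activation and already-fused onednn_op.fused_matmul +
// activation chains are rewritten in place.
void ApplyMatmulActivationFusion(pir::Program *program) {
  pir::PassManager pm(pir::IrContext::Instance());
  pm.AddPass(pir::CreateMatmulActivationFusePass());
  pm.Run(program);
}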