diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 5e73327170b125..730aa2b4316f69 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -633,6 +633,7 @@ const std::vector<std::string> kPirMkldnnPasses{
     "fc_fuse_pass",
     "fc_onednn_enable_pass",
     "softplus_activation_fuse_pass",
+    "operator_reshape_onednn_fuse_pass",
     "conv_elementwise_add_onednn_fuse_pass",
     "conv_activation_onednn_fuse_pass",
     "conv_concat_activation_onednn_fuse_pass",
diff --git a/paddle/fluid/pir/transforms/onednn/operator_reshape_onednn_fuse_pass.cc b/paddle/fluid/pir/transforms/onednn/operator_reshape_onednn_fuse_pass.cc
new file mode 100644
index 00000000000000..96f153c0e2853e
--- /dev/null
+++ b/paddle/fluid/pir/transforms/onednn/operator_reshape_onednn_fuse_pass.cc
@@ -0,0 +1,243 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "paddle/fluid/pir/transforms/onednn/operator_reshape_onednn_fuse_pass.h"
+#include <unordered_map>
+
+#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
+#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
+
+#include "paddle/pir/include/pass/pass.h"
+#include "paddle/pir/include/pass/pass_registry.h"
+
+namespace {
+class OperatorReshapeFusePattern : public paddle::drr::DrrPatternBase {
+ private:
+  std::string fusable_ops_;
+  std::string fused_ops_name_;
+  uint32_t benefit_;
+
+ public:
+  OperatorReshapeFusePattern(const std::string &fusable_ops,
+                             const std::string &fused_ops_name,
+                             uint32_t benefit)
+      : fusable_ops_(fusable_ops),
+        fused_ops_name_(fused_ops_name),
+        benefit_(benefit) {}
+
+  std::string name() const override {
+    return fusable_ops_ + "ReshapeFusePattern";
+  }
+
+  uint32_t benefit() const override { return benefit_; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
+
+    std::unordered_map<std::string, paddle::drr::Attribute> op_attrs;
+    if (fusable_ops_ == paddle::onednn::dialect::FcOp::name()) {
+      op_attrs.emplace("in_num_col_dims", pat.Attr("in_num_col_dims"));
+      op_attrs.emplace("activation_type", pat.Attr("activation_type"));
+      op_attrs.emplace("padding_weights", pat.Attr("padding_weights"));
+      op_attrs.emplace("use_quantizer", pat.Attr("use_quantizer"));
+      op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type"));
+      op_attrs.emplace("scale_in", pat.Attr("scale_in"));
+      op_attrs.emplace("scale_weights", pat.Attr("scale_weights"));
+      op_attrs.emplace("scale_out", pat.Attr("scale_out"));
+      op_attrs.emplace("force_fp32_output", pat.Attr("force_fp32_output"));
+      op_attrs.emplace("fuse_activation", pat.Attr("fuse_activation"));
+      op_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha"));
+      op_attrs.emplace("fuse_beta", pat.Attr("fuse_beta"));
+      op_attrs.emplace("fused_output_scale", pat.Attr("fused_output_scale"));
+      op_attrs.emplace("fused_reshape2_shape",
+                       pat.Attr("fused_reshape2_shape"));
+    } else if (fusable_ops_ ==
+               paddle::onednn::dialect::FusedTransposeOp::name()) {
+      op_attrs.emplace("axis", pat.Attr("axis"));
+      op_attrs.emplace("fused_squeeze2_axes", pat.Attr("fused_squeeze2_axes"));
+      op_attrs.emplace("fused_unsqueeze2_axes",
+                       pat.Attr("fused_unsqueeze2_axes"));
+      op_attrs.emplace("fused_reshape2_shape",
+                       pat.Attr("fused_reshape2_shape"));
+      op_attrs.emplace("scale", pat.Attr("scale"));
+      op_attrs.emplace("shift", pat.Attr("shift"));
+      op_attrs.emplace("output_data_type", pat.Attr("output_data_type"));
+      op_attrs.emplace("data_format", pat.Attr("data_format"));
+      op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type"));
+    } else if (fusable_ops_ == paddle::dialect::TransposeOp::name()) {
+      op_attrs.emplace("perm", pat.Attr("perm"));
+    }
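+    // Source pattern: the fusable op (onednn fc, onednn fused_transpose, or
+    // plain transpose) followed by a reshape whose target shape comes from a
+    // full_int_array op. The attributes captured above are replayed on the
+    // fused op in the result pattern.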
pat.Attr("fused_reshape2_shape")); + + } else if (fusable_ops_ == + paddle::onednn::dialect::FusedTransposeOp::name()) { + op_attrs.emplace("axis", pat.Attr("axis")); + op_attrs.emplace("fused_squeeze2_axes", pat.Attr("fused_squeeze2_axes")); + op_attrs.emplace("fused_unsqueeze2_axes", + pat.Attr("fused_unsqueeze2_axes")); + op_attrs.emplace("fused_reshape2_shape", + pat.Attr("fused_reshape2_shape")); + op_attrs.emplace("scale", pat.Attr("scale")); + op_attrs.emplace("shift", pat.Attr("shift")); + op_attrs.emplace("output_data_type", pat.Attr("output_data_type")); + op_attrs.emplace("data_format", pat.Attr("data_format")); + op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type")); + } else if (fusable_ops_ == paddle::dialect::TransposeOp::name()) { + op_attrs.emplace("perm", pat.Attr("perm")); + } + + const auto &op = pat.Op(fusable_ops_, op_attrs); + + const auto &full_1 = pat.Op(paddle::dialect::FullIntArrayOp::name(), + {{"value", pat.Attr("full_1_value")}}); + + const auto &reshape = pat.Op(paddle::dialect::ReshapeOp::name()); + + if (fusable_ops_ == paddle::onednn::dialect::FcOp::name()) { + op({&pat.Tensor("X"), &pat.Tensor("Y"), &pat.Tensor("Input3")}, + {&pat.Tensor("Out")}); + } else { + op({&pat.Tensor("X")}, {&pat.Tensor("Out")}); + } + + reshape({&pat.Tensor("Out"), &full_1()}, + {&pat.Tensor("ShapeOut"), &pat.Tensor("XShape")}); + + pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) { + int num_of_minus_ones = 0; + auto reshape2_shape = + match_ctx.Attr>("full_1_value"); + for (auto item : reshape2_shape) { + if (item == 0) { + VLOG(4) << "OneDNN op+reshape2 fuse pass does not support zero dims, " + "skipping"; + return false; + } else if (item == -1) { + ++num_of_minus_ones; + } + } + if (num_of_minus_ones > 1) { + VLOG(4) + << "Number of -1 values inside of reshape2 shouldn't be greater " + "than one in op+reshape2 oneDNN fuse pass, skipping"; + return false; + } + return true; + }); + + if (fusable_ops_ == paddle::onednn::dialect::FusedTransposeOp::name()) { + pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) { + auto fused_unsqueeze2_axes = + match_ctx.Attr>("fused_unsqueeze2_axes"); + if (fused_unsqueeze2_axes.size() > 0) { + VLOG(4) << "Cannot do " << fusable_ops_ << " + reshape fuse, because " + << fusable_ops_ << " is already fused with unsqueeze!"; + return false; + } + return true; + }); + } + + paddle::drr::ResultPattern res = pat.ResultPattern(); + std::unordered_map fused_op_attrs{}; + + const auto &fused_reshape2_shape = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { + std::vector int_array_value; + auto shape = match_ctx.Attr>("full_1_value"); + for (auto i : shape) { + int_array_value.emplace_back(static_cast(i)); + } + return int_array_value; + }); + + if (fusable_ops_ == paddle::onednn::dialect::FcOp::name()) { + fused_op_attrs.emplace("in_num_col_dims", pat.Attr("in_num_col_dims")); + fused_op_attrs.emplace("activation_type", pat.Attr("activation_type")); + fused_op_attrs.emplace("padding_weights", pat.Attr("padding_weights")); + fused_op_attrs.emplace("use_quantizer", pat.Attr("use_quantizer")); + fused_op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type")); + fused_op_attrs.emplace("scale_in", pat.Attr("scale_in")); + fused_op_attrs.emplace("scale_weights", pat.Attr("scale_weights")); + fused_op_attrs.emplace("scale_out", pat.Attr("scale_out")); + fused_op_attrs.emplace("force_fp32_output", + pat.Attr("force_fp32_output")); + fused_op_attrs.emplace("fuse_activation", 
pat.Attr("fuse_activation")); + fused_op_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha")); + fused_op_attrs.emplace("fuse_beta", pat.Attr("fuse_beta")); + fused_op_attrs.emplace("fused_output_scale", + pat.Attr("fused_output_scale")); + fused_op_attrs.emplace("fused_reshape2_shape", fused_reshape2_shape); + + } else if (fusable_ops_ == + paddle::onednn::dialect::FusedTransposeOp::name()) { + fused_op_attrs.emplace("axis", pat.Attr("axis")); + fused_op_attrs.emplace("fused_squeeze2_axes", + pat.Attr("fused_squeeze2_axes")); + fused_op_attrs.emplace("fused_unsqueeze2_axes", + pat.Attr("fused_unsqueeze2_axes")); + fused_op_attrs.emplace("fused_reshape2_shape", fused_reshape2_shape); + fused_op_attrs.emplace("scale", pat.Attr("scale")); + fused_op_attrs.emplace("shift", pat.Attr("shift")); + fused_op_attrs.emplace("output_data_type", pat.Attr("output_data_type")); + fused_op_attrs.emplace("data_format", pat.Attr("data_format")); + fused_op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type")); + + } else if (fusable_ops_ == paddle::dialect::TransposeOp::name()) { + fused_op_attrs.emplace("axis", pat.Attr("perm")); + fused_op_attrs.emplace("fused_squeeze2_axes", res.VectorInt32Attr({})); + fused_op_attrs.emplace("fused_unsqueeze2_axes", res.VectorInt32Attr({})); + fused_op_attrs.emplace("fused_reshape2_shape", fused_reshape2_shape); + fused_op_attrs.emplace("scale", res.Float32Attr(1.0f)); + fused_op_attrs.emplace("shift", res.Float32Attr(0.0f)); + fused_op_attrs.emplace("output_data_type", res.StrAttr("fp32")); + fused_op_attrs.emplace("data_format", res.StrAttr("AnyLayout")); + fused_op_attrs.emplace("mkldnn_data_type", res.StrAttr("float32")); + } + + const auto &fused_op = res.Op(fused_ops_name_, fused_op_attrs); + + if (fusable_ops_ == paddle::onednn::dialect::FcOp::name()) { + fused_op({&res.Tensor("X"), &res.Tensor("Y"), &res.Tensor("Input3")}, + {&res.Tensor("ShapeOut")}); + } else { + fused_op({&res.Tensor("X")}, {&res.Tensor("ShapeOut")}); + } + } +}; + +class OperatorReshapePass : public pir::PatternRewritePass { + public: + OperatorReshapePass() + : pir::PatternRewritePass("operator_reshape_onednn_fuse_pass", 3) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + const std::vector fusable_ops{ + paddle::onednn::dialect::FcOp::name(), + paddle::onednn::dialect::FusedTransposeOp::name(), + paddle::dialect::TransposeOp::name(), + }; + + const std::vector fused_ops{ + paddle::onednn::dialect::FcOp::name(), + paddle::onednn::dialect::FusedTransposeOp::name(), + paddle::onednn::dialect::FusedTransposeOp::name(), + }; + int benefit_idx = 1; + int fused = 0; + for (auto op : fusable_ops) { + ps.Add(paddle::drr::Create( + context, op, fused_ops[fused++], benefit_idx)); + benefit_idx++; + } + + return ps; + } +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateOperatorReshapeOneDNNPass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(operator_reshape_onednn_fuse_pass, OperatorReshapePass); diff --git a/paddle/fluid/pir/transforms/onednn/operator_reshape_onednn_fuse_pass.h b/paddle/fluid/pir/transforms/onednn/operator_reshape_onednn_fuse_pass.h new file mode 100644 index 00000000000000..475594096ead1c --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/operator_reshape_onednn_fuse_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+    paddle::drr::ResultPattern res = pat.ResultPattern();
+    std::unordered_map<std::string, paddle::drr::Attribute> fused_op_attrs{};
+
+    const auto &fused_reshape2_shape = res.ComputeAttr(
+        [](const paddle::drr::MatchContext &match_ctx) -> std::vector<int> {
+          std::vector<int> int_array_value;
+          auto shape = match_ctx.Attr<std::vector<int64_t>>("full_1_value");
+          for (auto i : shape) {
+            int_array_value.emplace_back(static_cast<int>(i));
+          }
+          return int_array_value;
+        });
+
+    if (fusable_ops_ == paddle::onednn::dialect::FcOp::name()) {
+      fused_op_attrs.emplace("in_num_col_dims", pat.Attr("in_num_col_dims"));
+      fused_op_attrs.emplace("activation_type", pat.Attr("activation_type"));
+      fused_op_attrs.emplace("padding_weights", pat.Attr("padding_weights"));
+      fused_op_attrs.emplace("use_quantizer", pat.Attr("use_quantizer"));
+      fused_op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type"));
+      fused_op_attrs.emplace("scale_in", pat.Attr("scale_in"));
+      fused_op_attrs.emplace("scale_weights", pat.Attr("scale_weights"));
+      fused_op_attrs.emplace("scale_out", pat.Attr("scale_out"));
+      fused_op_attrs.emplace("force_fp32_output",
+                             pat.Attr("force_fp32_output"));
+      fused_op_attrs.emplace("fuse_activation", pat.Attr("fuse_activation"));
+      fused_op_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha"));
+      fused_op_attrs.emplace("fuse_beta", pat.Attr("fuse_beta"));
+      fused_op_attrs.emplace("fused_output_scale",
+                             pat.Attr("fused_output_scale"));
+      fused_op_attrs.emplace("fused_reshape2_shape", fused_reshape2_shape);
+    } else if (fusable_ops_ ==
+               paddle::onednn::dialect::FusedTransposeOp::name()) {
+      fused_op_attrs.emplace("axis", pat.Attr("axis"));
+      fused_op_attrs.emplace("fused_squeeze2_axes",
+                             pat.Attr("fused_squeeze2_axes"));
+      fused_op_attrs.emplace("fused_unsqueeze2_axes",
+                             pat.Attr("fused_unsqueeze2_axes"));
+      fused_op_attrs.emplace("fused_reshape2_shape", fused_reshape2_shape);
+      fused_op_attrs.emplace("scale", pat.Attr("scale"));
+      fused_op_attrs.emplace("shift", pat.Attr("shift"));
+      fused_op_attrs.emplace("output_data_type", pat.Attr("output_data_type"));
+      fused_op_attrs.emplace("data_format", pat.Attr("data_format"));
+      fused_op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type"));
+    } else if (fusable_ops_ == paddle::dialect::TransposeOp::name()) {
+      fused_op_attrs.emplace("axis", pat.Attr("perm"));
+      fused_op_attrs.emplace("fused_squeeze2_axes", res.VectorInt32Attr({}));
+      fused_op_attrs.emplace("fused_unsqueeze2_axes", res.VectorInt32Attr({}));
+      fused_op_attrs.emplace("fused_reshape2_shape", fused_reshape2_shape);
+      fused_op_attrs.emplace("scale", res.Float32Attr(1.0f));
+      fused_op_attrs.emplace("shift", res.Float32Attr(0.0f));
+      fused_op_attrs.emplace("output_data_type", res.StrAttr("fp32"));
+      fused_op_attrs.emplace("data_format", res.StrAttr("AnyLayout"));
+      fused_op_attrs.emplace("mkldnn_data_type", res.StrAttr("float32"));
+    }
+
+    const auto &fused_op = res.Op(fused_ops_name_, fused_op_attrs);
+
+    if (fusable_ops_ == paddle::onednn::dialect::FcOp::name()) {
+      fused_op({&res.Tensor("X"), &res.Tensor("Y"), &res.Tensor("Input3")},
+               {&res.Tensor("ShapeOut")});
+    } else {
+      fused_op({&res.Tensor("X")}, {&res.Tensor("ShapeOut")});
+    }
+  }
+};
+
+class OperatorReshapePass : public pir::PatternRewritePass {
+ public:
+  OperatorReshapePass()
+      : pir::PatternRewritePass("operator_reshape_onednn_fuse_pass", 3) {}
+
+  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
+    pir::RewritePatternSet ps(context);
+    const std::vector<std::string> fusable_ops{
+        paddle::onednn::dialect::FcOp::name(),
+        paddle::onednn::dialect::FusedTransposeOp::name(),
+        paddle::dialect::TransposeOp::name(),
+    };
+
+    const std::vector<std::string> fused_ops{
+        paddle::onednn::dialect::FcOp::name(),
+        paddle::onednn::dialect::FusedTransposeOp::name(),
+        paddle::onednn::dialect::FusedTransposeOp::name(),
+    };
+    int benefit_idx = 1;
+    int fused = 0;
+    for (auto op : fusable_ops) {
+      ps.Add(paddle::drr::Create<OperatorReshapeFusePattern>(
+          context, op, fused_ops[fused++], benefit_idx));
+      benefit_idx++;
+    }
+
+    return ps;
+  }
+};
+
+}  // namespace
+
+namespace pir {
+
+std::unique_ptr<Pass> CreateOperatorReshapeOneDNNPass() {
+  return std::make_unique<OperatorReshapePass>();
+}
+
+}  // namespace pir
+
+REGISTER_IR_PASS(operator_reshape_onednn_fuse_pass, OperatorReshapePass);
diff --git a/paddle/fluid/pir/transforms/onednn/operator_reshape_onednn_fuse_pass.h b/paddle/fluid/pir/transforms/onednn/operator_reshape_onednn_fuse_pass.h
new file mode 100644
index 00000000000000..475594096ead1c
--- /dev/null
+++ b/paddle/fluid/pir/transforms/onednn/operator_reshape_onednn_fuse_pass.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include "paddle/pir/include/core/dll_decl.h"
+
+namespace pir {
+
+class Pass;
+
+IR_API std::unique_ptr<Pass> CreateOperatorReshapeOneDNNPass();
+
+}  // namespace pir
diff --git a/paddle/fluid/pir/transforms/passes.h b/paddle/fluid/pir/transforms/passes.h
index 8605dd0d9ad085..ef63849e469c82 100644
--- a/paddle/fluid/pir/transforms/passes.h
+++ b/paddle/fluid/pir/transforms/passes.h
@@ -55,6 +55,7 @@ USE_PIR_PASS(matmul_elementwise_add_fuse_pass);
 USE_PIR_PASS(matmul_activation_fuse_pass);
 USE_PIR_PASS(fc_onednn_enable_pass);
 USE_PIR_PASS(softplus_activation_fuse_pass);
+USE_PIR_PASS(operator_reshape_onednn_fuse_pass);
 USE_PIR_PASS(conv_elementwise_add_onednn_fuse_pass);
 USE_PIR_PASS(conv_activation_onednn_fuse_pass);
 USE_PIR_PASS(conv_concat_activation_onednn_fuse_pass);
diff --git a/test/ir/pir/fused_pass/onednn/test_operator_reshape_onednn_fuse_pass.py b/test/ir/pir/fused_pass/onednn/test_operator_reshape_onednn_fuse_pass.py
new file mode 100644
index 00000000000000..f92a2e566a54dd
--- /dev/null
+++ b/test/ir/pir/fused_pass/onednn/test_operator_reshape_onednn_fuse_pass.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from pass_test import PassTest
+
+import paddle
+
+paddle.enable_static()
+
+
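+# Each test below builds a small PIR program, applies the passes listed in
+# pass_attr_list, and checks both the numerical output and the expected op
+# counts in valid_op_map after fusion.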
+class TestFusedTransposeReshapeFusePass(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        with paddle.pir_utils.IrGuard():
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.pir.core.program_guard(main_prog, start_prog):
+                x = paddle.static.data(
+                    name='x', shape=[4, 16, 1, 32], dtype='float32'
+                )
+
+                squeeze_out = paddle.squeeze(x, axis=[2])
+                transpose = paddle.transpose(squeeze_out, [0, 1, 2])
+                out = paddle.reshape(
+                    x=transpose, shape=[np.prod(transpose.shape)]
+                )
+                out = paddle.assign(out)
+                self.pass_attr_list = [
+                    {'squeeze_transpose_onednn_fuse_pass': {}},
+                    {'operator_reshape_onednn_fuse_pass': {}},
+                ]
+                self.feeds = {
+                    "x": np.random.random((4, 16, 1, 32)).astype("float32"),
+                }
+                self.fetch_list = [out]
+                self.valid_op_map = {
+                    "pd_op.transpose": 0,
+                    "pd_op.reshape": 0,
+                    "onednn_op.fused_transpose": 1,
+                }
+                return [main_prog, start_prog]
+
+    def sample_program(self):
+        yield self.build_ir_program(), False
+
+    def setUp(self):
+        self.places.append(paddle.CPUPlace())
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+class TestTransposeReshapeFusePass(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        with paddle.pir_utils.IrGuard():
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.pir.core.program_guard(main_prog, start_prog):
+                x = paddle.static.data(
+                    name='x', shape=[2, 6, 4], dtype='float32'
+                )
+
+                transpose = paddle.transpose(x, [0, 2, 1])
+                out = paddle.reshape(
+                    x=transpose, shape=[np.prod(transpose.shape)]
+                )
+                out = paddle.assign(out)
+                self.pass_attr_list = [
+                    {'operator_reshape_onednn_fuse_pass': {}}
+                ]
+                self.feeds = {
+                    "x": np.random.random((2, 6, 4)).astype("float32"),
+                }
+                self.fetch_list = [out]
+                self.valid_op_map = {
+                    "pd_op.transpose": 0,
+                    "pd_op.reshape": 0,
+                    "onednn_op.fused_transpose": 1,
+                }
+                return [main_prog, start_prog]
+
+    def sample_program(self):
+        yield self.build_ir_program(), False
+
+    def setUp(self):
+        self.places.append(paddle.CPUPlace())
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
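+# fc + reshape: matmul+add is first fused into fc (fc_fuse_pass), switched to
+# the oneDNN fc (fc_onednn_enable_pass), and the trailing reshape is then
+# folded into fc's fused_reshape2_shape attribute.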
"pd_op.scale": 0, + "onednn_op.fc": 1, + } + + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct() + + +if __name__ == "__main__": + unittest.main()