From ad067f133edb2d3d14d9f2258fbcb5b3fbf9068f Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 26 Jan 2024 02:18:55 +0000 Subject: [PATCH 1/7] support injective group fusion --- .../group_merge/op_with_group_merge_pass.cc | 12 ++- .../group_merge/op_with_group_merge_util.h | 4 + test/ir/pir/cinn/test_rope.py | 95 +++++++++++++++++++ 3 files changed, 107 insertions(+), 4 deletions(-) create mode 100644 test/ir/pir/cinn/test_rope.py diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc index dfd67445a35026..8677e889ba092d 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc @@ -609,13 +609,17 @@ class OpFusionPassHelper { hlir::framework::pir::CompatibleInfo::OpKind(*consumer))) { auto& consumer_group = fusion_groups_[consumer]; // second step: check producer can be fused into consumer group - VLOG(3) << "Call ConditionFunction, Producer Op Pattern : " + VLOG(3) << "Call ConditionFunction, Producer Op: [" << producer->name() + << "] Pattern : " << hlir::framework::pir::CompatibleInfo::OpKind(*producer) - << " , Consumer Group Pattern : " - << consumer_group->op_pattern_kind; + << " , Consumer Group [" << consumer->name() + << "] Pattern : " << consumer_group->op_pattern_kind; - return relation.fusion_op_kind[consumer_group->op_pattern_kind]( + bool result = relation.fusion_op_kind[consumer_group->op_pattern_kind]( producer, fusion_groups_[consumer], shape_analysis); + VLOG(3) << " CanFuse: " << result; + + return result; } return false; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h index 79262fe2cbd4d7..b117d6b7dc6304 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h @@ -302,7 +302,11 @@ inline bool horizontal_or_can_inline( return false; } } + // vertical relation: 1.can compute inline + if (producer->result(0).use_count() == 1) { + return true; + } // if (helper->GetNodeData(producer)->outlinks().size() == 1 && // helper->output_ops_set_.count(producer) == 0) { // return true; diff --git a/test/ir/pir/cinn/test_rope.py b/test/ir/pir/cinn/test_rope.py new file mode 100644 index 00000000000000..b0324bbb5e78da --- /dev/null +++ b/test/ir/pir/cinn/test_rope.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +import paddle +from paddle import nn + + +def apply_to_static(net, use_cinn, input_spec=None): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = use_cinn + return paddle.jit.to_static( + net, + input_spec=input_spec, + build_strategy=build_strategy, + full_graph=True, + ) + + +class RotaryPosEmb(nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, q, k, cos, sin, position_ids): + cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] + sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] + + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (self.rotate_half(q) * sin) + k_embed = (k * cos) + (self.rotate_half(k) * sin) + return q_embed, k_embed + + def rotate_half(self, x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) # shape is the same as x + + +class TestRotaryPosEmb(unittest.TestCase): + def setUp(self): + paddle.seed(2022) + self.prepare_data() + + def prepare_data(self): + self.q = paddle.randn([1, 2048, 8, 96], dtype="float32") + self.q.stop_gradient = False + + self.k = paddle.randn([1, 2048, 8, 96], dtype="float32") + self.k.stop_gradient = False + + self.cos = paddle.randn([1, 2048, 1, 96], dtype="float32") + self.cos.stop_gradient = False + + self.sin = paddle.randn([1, 2048, 1, 96], dtype="float32") + self.sin.stop_gradient = False + + self.position_ids = paddle.arange(end=2048, dtype="int64").unsqueeze(0) + self.position_ids.stop_gradient = False + + def eval(self, use_cinn): + paddle.seed(2022) + net = RotaryPosEmb() + net.eval() + if use_cinn: + net = apply_to_static(net, use_cinn) + + out = net(self.q, self.k, self.cos, self.sin, self.position_ids) + return out + + def test_eval(self): + cinn_outs = self.eval(use_cinn=True) + # dy_outs = self.eval(use_cinn=False) + + # TODO(phlrain): Need to check result + # for cinn_out, dy_out in zip(cinn_outs, dy_outs): + # np.testing.assert_allclose( + # cinn_out.numpy(), dy_out.numpy(), atol=1e-8 + # ) + + +if __name__ == '__main__': + unittest.main() From ffee52774a7cb363cd4708c9ec3bca5893d6b676 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 26 Jan 2024 11:07:17 +0000 Subject: [PATCH 2/7] [CINN+PIR]Fix IsSupportCINN Logic --- paddle/cinn/hlir/framework/pir/utils.cc | 68 +++++++++++++++++-------- 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/paddle/cinn/hlir/framework/pir/utils.cc b/paddle/cinn/hlir/framework/pir/utils.cc index b5e64f77cc2edd..39d63cc1519e3d 100644 --- a/paddle/cinn/hlir/framework/pir/utils.cc +++ b/paddle/cinn/hlir/framework/pir/utils.cc @@ -56,6 +56,7 @@ const std::unordered_map CompatibleInfo::OP_NAMES = { {"pd_op.split_with_num", "split"}, {"pd_op.reshape", "reshape"}, {"pd_op.expand", "broadcast_to"}, + {"pd_op.concat", "concat"}, {"cinn_op.generate_shape", "generate_shape"}, {"cinn_op.reshape", "reshape"}, {"cinn_op.scale", "scale"}, @@ -133,32 +134,62 @@ bool UnimplementOps(const ::pir::Operation& op) { bool HaveZeroDimInput(const ::pir::Operation& op) { bool have_zero_dim = false; for (size_t i = 0; i < op.num_operands(); ++i) { - auto in = op.operand_source(i); - if (in) { - if (auto tensor_type = - in.type().dyn_cast()) { + auto value = op.operand_source(i); + if (value && value.type()) { + auto value_type = value.type(); + if (value_type.isa<::pir::VectorType>() && + 
value_type.dyn_cast<::pir::VectorType>().size() > 0U) { + auto types = value_type.dyn_cast<::pir::VectorType>().data(); + for (auto& type : types) { + if (auto tensor_type = + value_type.dyn_cast<::pir::DenseTensorType>()) { + if (tensor_type.dims().size() == 0) { + have_zero_dim = true; + break; + } + } + } + } else if (auto tensor_type = + value_type.dyn_cast<::pir::DenseTensorType>()) { if (tensor_type.dims().size() == 0) { have_zero_dim = true; break; } + } else { + // do nothing } } } + VLOG(4) << "HaveZeroDimInput: " << have_zero_dim; + return have_zero_dim; } bool AllInputDenseTensor(const ::pir::Operation& op) { bool all_denese_tensor = true; for (size_t i = 0; i < op.num_operands(); ++i) { - auto in = op.operand_source(i); - if (in) { - if (!(in.type().isa())) { + auto value = op.operand_source(i); + if (value && value.type()) { + auto value_type = value.type(); + if (value_type.isa<::pir::VectorType>() && + value_type.dyn_cast<::pir::VectorType>().size() > 0U) { + auto types = value_type.dyn_cast<::pir::VectorType>().data(); + for (auto& type : types) { + if (!type.isa<::pir::DenseTensorType>()) { + all_denese_tensor = false; + break; + } + } + } else if (!(value_type.isa<::pir::DenseTensorType>())) { all_denese_tensor = false; break; + } else { + // do nothing } } } + VLOG(4) << "AllInputDenseTensor: " << all_denese_tensor; return all_denese_tensor; } @@ -168,22 +199,14 @@ bool IsRegisteredInCINN(const ::pir::Operation& op) { CompatibleInfo::OP_NAMES.end()) { return true; } - // After PdToCinnPass, if pd_op.reshape still exists, return false. - std::string black_op_name = - std::string(cinn::dialect::OperatorDialect::name()) + "." + - CompatibleInfo::OpName(op); - if (CompatibleInfo::OP_NAMES.find(black_op_name) != - CompatibleInfo::OP_NAMES.end()) { - VLOG(4) << "Found black op after PdToCinnPass, because it has Attribute " - "Tensor: " - << op.name(); - return false; - } return OpRegistry::Global()->Find(CompatibleInfo::OpName(op)) != nullptr; } bool IsSupportForCinn(const ::pir::Operation& op) { if (!AllInputDenseTensor(op) || HaveZeroDimInput(op) || UnimplementOps(op)) { + VLOG(4) << "Found " << op.name() + << " HaveZeroDimInput or UnimplementOps or NotAllInputDenseTensor. " + << "So mark IsSupportForCinn: " << false; return false; } auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); @@ -196,8 +219,8 @@ bool IsSupportForCinn(const ::pir::Operation& op) { OpTransInfo trans_info; bool is_support = IsRegisteredInCINN(op) && !trans_info.default_deny_ops().count(op_name); - VLOG(4) << op_name << " is_support: " << is_support << " " - << IsRegisteredInCINN(op); + VLOG(4) << op_name << " is_support: " << is_support + << " IsRegisteredInCINN: " << IsRegisteredInCINN(op); // if the op type is registered in CINN and allow_ops is not empty, return // true only when it is in allow_ops if (!allow_ops.empty()) { @@ -221,7 +244,10 @@ bool IsSupportForCinn(const ::pir::Operation& op) { // Such as cinn_op.reshape, except pd_op.reshape; // 3. 
otherwise, it should be registered in OpRegistry; bool CompatibleInfo::IsSupportCinn(const ::pir::Operation& op) { - return IsSupportForCinn(op); + bool flag = IsSupportForCinn(op); + VLOG(4) << " CompatibleInfo::IsSupportCinn of " << op.name() + << " is: " << flag; + return flag; } std::string CompatibleInfo::OpName(const ::pir::Operation& op) { From ce45ad3c835c2dea1d41e90faf64f7a9e5c07110 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 26 Jan 2024 12:08:56 +0000 Subject: [PATCH 3/7] fix comment --- paddle/cinn/hlir/framework/pir/utils.cc | 100 ++++++++++-------------- 1 file changed, 40 insertions(+), 60 deletions(-) diff --git a/paddle/cinn/hlir/framework/pir/utils.cc b/paddle/cinn/hlir/framework/pir/utils.cc index 39d63cc1519e3d..7dfc0a0c7c8398 100644 --- a/paddle/cinn/hlir/framework/pir/utils.cc +++ b/paddle/cinn/hlir/framework/pir/utils.cc @@ -54,17 +54,9 @@ const std::unordered_map CompatibleInfo::OP_NAMES = { {"pd_op.maximum", "max"}, {"pd_op.minimum", "min"}, {"pd_op.split_with_num", "split"}, - {"pd_op.reshape", "reshape"}, {"pd_op.expand", "broadcast_to"}, - {"pd_op.concat", "concat"}, {"cinn_op.generate_shape", "generate_shape"}, - {"cinn_op.reshape", "reshape"}, - {"cinn_op.scale", "scale"}, - {"cinn_op.broadcast", "broadcast_to"}, - // The following should implement OpPattern in pd_to_cinn_pass, - // otherwise, it will be block in BuildCinnPass. - {"cinn_op.squeeze", ""}, - {"cinn_op.unsqueeze", ""}}; + {"cinn_op.broadcast", "broadcast_to"}}; namespace { using GroupOpsVec = std::vector<::pir::Operation*>; @@ -120,8 +112,8 @@ bool IsSupportForCinn(const ::pir::Operation& op); // implement OpPattern in pd_to_cinn_pass. Otherwise, we mark them // as unimplement ops. bool UnimplementOps(const ::pir::Operation& op) { - // cinn not support uniform, the FullOp of max and min support NOT generate by - // CINN + // cinn not support uniform, the FullOp of max and min support + // NOT generate by CINN if (op.isa()) { auto out = op.result(0); if (out.use_count() > 0) { @@ -132,66 +124,54 @@ bool UnimplementOps(const ::pir::Operation& op) { } bool HaveZeroDimInput(const ::pir::Operation& op) { - bool have_zero_dim = false; + auto HasZeroDim = [](const ::pir::Type& type) { + auto tensor_type = type.dyn_cast<::pir::DenseTensorType>(); + return tensor_type && tensor_type.dims().size() == 0U; + }; + // Judge for vector + auto HasZeroDimInVT = [&](const std::vector<::pir::Type>& types) { + for (auto& type : types) { + if (HasZeroDim(type)) return true; + } + return false; + }; + for (size_t i = 0; i < op.num_operands(); ++i) { auto value = op.operand_source(i); - if (value && value.type()) { - auto value_type = value.type(); - if (value_type.isa<::pir::VectorType>() && - value_type.dyn_cast<::pir::VectorType>().size() > 0U) { - auto types = value_type.dyn_cast<::pir::VectorType>().data(); - for (auto& type : types) { - if (auto tensor_type = - value_type.dyn_cast<::pir::DenseTensorType>()) { - if (tensor_type.dims().size() == 0) { - have_zero_dim = true; - break; - } - } - } - } else if (auto tensor_type = - value_type.dyn_cast<::pir::DenseTensorType>()) { - if (tensor_type.dims().size() == 0) { - have_zero_dim = true; - break; - } - } else { - // do nothing - } + if (!value || !value.type()) continue; + if (auto vector_type = value.type().dyn_cast<::pir::VectorType>()) { + if (HasZeroDimInVT(vector_type.data())) return true; + } else if (HasZeroDim(value.type())) { + return true; } } - - VLOG(4) << "HaveZeroDimInput: " << have_zero_dim; - - return have_zero_dim; + return false; } bool 
AllInputDenseTensor(const ::pir::Operation& op) { - bool all_denese_tensor = true; + auto IsDenseTensor = [](const ::pir::Type& type) { + return type.isa<::pir::DenseTensorType>(); + }; + + // Judge for vector + auto IsAllDenseTensor = [&](const std::vector<::pir::Type>& types) { + for (auto& type : types) { + if (!IsDenseTensor(type)) return false; + } + return true; + }; + for (size_t i = 0; i < op.num_operands(); ++i) { auto value = op.operand_source(i); - if (value && value.type()) { - auto value_type = value.type(); - if (value_type.isa<::pir::VectorType>() && - value_type.dyn_cast<::pir::VectorType>().size() > 0U) { - auto types = value_type.dyn_cast<::pir::VectorType>().data(); - for (auto& type : types) { - if (!type.isa<::pir::DenseTensorType>()) { - all_denese_tensor = false; - break; - } - } - } else if (!(value_type.isa<::pir::DenseTensorType>())) { - all_denese_tensor = false; - break; - } else { - // do nothing - } + if (!value || !value.type()) continue; + if (auto vector_type = value.type().dyn_cast<::pir::VectorType>()) { + if (!IsAllDenseTensor(vector_type.data())) return false; + } else if (!IsDenseTensor(value.type())) { + return false; } } - VLOG(4) << "AllInputDenseTensor: " << all_denese_tensor; - return all_denese_tensor; + return true; } bool IsRegisteredInCINN(const ::pir::Operation& op) { @@ -245,7 +225,7 @@ bool IsSupportForCinn(const ::pir::Operation& op) { // 3. otherwise, it should be registered in OpRegistry; bool CompatibleInfo::IsSupportCinn(const ::pir::Operation& op) { bool flag = IsSupportForCinn(op); - VLOG(4) << " CompatibleInfo::IsSupportCinn of " << op.name() + VLOG(4) << "CompatibleInfo::IsSupportCinn of " << op.name() << " is: " << flag; return flag; } From 6a14bf7730b35c655e09cb4625a5034bc7319b2b Mon Sep 17 00:00:00 2001 From: zyfncg Date: Sat, 27 Jan 2024 06:33:27 +0000 Subject: [PATCH 4/7] merge gather_nd --- .../group_merge/group_with_group_merge_util.h | 11 ++++++++--- .../group_merge/op_with_group_merge_pass.cc | 4 ++-- paddle/cinn/hlir/op/contrib/gather_nd.cc | 2 ++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h index 4381ca7234b013..f0c222d494d547 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h @@ -396,10 +396,15 @@ inline bool horizontal_with_injective( return true; } - if (!is_same_size(first, second)) { - return false; + if (horizontal_relation(first, second, OpPatternKind::kInjective)) { + return true; } - return horizontal_relation(first, second, OpPatternKind::kInjective); + + if (is_same_size(first, second)) { + return true; + } + + return true; } inline bool injective_horizontal_with_reduce( diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc index 8677e889ba092d..2aa05359cbe6d0 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc @@ -659,7 +659,7 @@ GroupList OpFusionPassInternal( OpFusionPassHelper(op_list, output_op_list, shape_analysis); auto res = op_fusion_helper(); - if (VLOG_IS_ON(6)) { + if (VLOG_IS_ON(3)) { std::stringstream 
ss; ::pir::IrPrinter printer(ss); for (size_t i = 0; i < res.size(); ++i) { @@ -672,7 +672,7 @@ GroupList OpFusionPassInternal( ss << "\n"; } } - VLOG(6) << ss.str(); + VLOG(1) << ss.str(); } VLOG(3) << "OpFusionPass Finish...!"; diff --git a/paddle/cinn/hlir/op/contrib/gather_nd.cc b/paddle/cinn/hlir/op/contrib/gather_nd.cc index 3b9528a3a47074..23fa324aa13f70 100644 --- a/paddle/cinn/hlir/op/contrib/gather_nd.cc +++ b/paddle/cinn/hlir/op/contrib/gather_nd.cc @@ -249,6 +249,8 @@ CINN_REGISTER_HELPER(gather_nd_ops) { MakeOpFunction(cinn::hlir::op::InferShapeForGatherNd)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForGatherNd)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) .set_support_level(4); return true; From 51c809b592b33e2145c64d8e0375401ee1d04793 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Sun, 28 Jan 2024 10:25:27 +0000 Subject: [PATCH 5/7] refactor some trick code --- .../group_with_group_merge_pass.cc | 50 +++++++++++- .../group_merge/group_with_group_merge_util.h | 6 +- .../operator/transforms/group_merge/op_node.h | 2 + .../group_merge/special_ops_fusion_rule.cc | 49 ++++++++++++ .../group_merge/special_ops_fusion_rule.h | 78 +++++++++++++++++++ 5 files changed, 178 insertions(+), 7 deletions(-) create mode 100644 paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.cc create mode 100644 paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.h diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc index cf1cf1131eea45..219099d2279c90 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc @@ -26,6 +26,7 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.h" #include "paddle/phi/core/flags.h" #include "paddle/cinn/common/is_reachable_predicator.h" @@ -63,6 +64,9 @@ class FuseHelper { virtual bool HorizontalWithInjective(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; + virtual bool InjectiveFuseInjective(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + virtual bool ElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; @@ -78,6 +82,9 @@ class FuseHelper { virtual bool ReduceFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; + virtual bool ReduceFuseInjective(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + virtual bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; @@ -121,12 +128,18 @@ class GraphGroupFuseHelper final : public FuseHelper { bool InjectiveHorizontalWithReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const override; + bool InjectiveFuseInjective(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + bool ReduceFuseElementwise(const OpGroupPtr& src, const OpGroupPtr& dst) const override; bool ReduceFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const override; + bool ReduceFuseInjective(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const 
override; @@ -357,6 +370,12 @@ bool GraphGroupFuseHelper::HorizontalWithInjective( return horizontal_with_injective(src.GetGroup(), dst.GetGroup()); } +template +bool GraphGroupFuseHelper::InjectiveFuseInjective( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return true; +} + template bool GraphGroupFuseHelper::ElementwiseFuseReduce( const OpGroupPtr& src, const OpGroupPtr& dst) const { @@ -387,6 +406,21 @@ bool GraphGroupFuseHelper::ReduceFuseBroadcast( return reduce_fuse_broadcast(src.GetGroup(), dst.GetGroup()); } +template +bool GraphGroupFuseHelper::ReduceFuseInjective( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + bool can_all_special_ops_fused = false; + src.WalkOpNodes([&](const OpNode& op) { + can_all_special_ops_fused = + can_all_special_ops_fused && + SpecialOpsFusionRule::GetInstance().ConsumerOpAllowsFusion( + op.node(), OpPatternKind::kReduction); + }); + + return can_all_special_ops_fused && + horizontal_with_injective(src.GetGroup(), dst.GetGroup()); +} + template bool GraphGroupFuseHelper::ReduceFuseReduce( const OpGroupPtr& src, const OpGroupPtr& dst) const { @@ -790,7 +824,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { {{OpPatternKind::kInjective, OpPatternKind::kBroadcast}, &DefaultVerticalFusePass::IsSameSize}, {{OpPatternKind::kInjective, OpPatternKind::kInjective}, - &DefaultVerticalFusePass::HorizontalWithInjective}, + &DefaultVerticalFusePass::InjectiveFuseInjective}, {{OpPatternKind::kInjective, OpPatternKind::kReduction}, &DefaultVerticalFusePass::InjectiveHorizontalWithReduce}, @@ -799,7 +833,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { {{OpPatternKind::kReduction, OpPatternKind::kBroadcast}, &DefaultVerticalFusePass::ReduceFuseBroadcast}, {{OpPatternKind::kReduction, OpPatternKind::kInjective}, - &DefaultVerticalFusePass::HorizontalWithInjective}, + &DefaultVerticalFusePass::ReduceFuseInjective}, {{OpPatternKind::kReduction, OpPatternKind::kReduction}, &DefaultVerticalFusePass::ReduceFuseReduce}, }; @@ -823,6 +857,12 @@ class DefaultVerticalFusePass final : public VerticalFusePass { return ctx->fuse_helper().HorizontalWithInjective(src, dst); } + static bool InjectiveFuseInjective(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().InjectiveFuseInjective(src, dst); + } + static bool ElementwiseFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { @@ -853,6 +893,12 @@ class DefaultVerticalFusePass final : public VerticalFusePass { return ctx->fuse_helper().ReduceFuseBroadcast(src, dst); } + static bool ReduceFuseInjective(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseInjective(src, dst); + } + static bool ReduceFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h index f0c222d494d547..033093724b776a 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h @@ -400,11 +400,7 @@ inline bool horizontal_with_injective( return true; } - if (is_same_size(first, second)) { - return true; - } - - return true; + return false; } inline bool injective_horizontal_with_reduce( diff --git 
a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_node.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_node.h index 949309bb881ee2..051d70c6a3a0ee 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_node.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_node.h @@ -143,6 +143,8 @@ class OpNode { return paddle::get(attr); } + ::pir::Operation* node() const { return node_; } + private: friend struct std::hash; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.cc new file mode 100644 index 00000000000000..04b61fe8877230 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.h" + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" + +namespace cinn { +namespace dialect { +namespace ir { + +bool GatherNdFusionFule(const ::pir::Operation* consumer, + OpPatternKind producer_group_pattern) { + if (producer_group_pattern == OpPatternKind::kReduction) { + return false; + } + return true; +} + +bool SliceFusionFule(const ::pir::Operation* consumer, + OpPatternKind producer_group_pattern) { + if (producer_group_pattern == OpPatternKind::kReduction) { + return false; + } + return true; +} + +void SpecialOpsFusionRule::Init() { + RegisterConsumerOpRule(paddle::dialect::GatherNdOp::name(), + &GatherNdFusionFule); + RegisterConsumerOpRule(cinn::dialect::SliceOp::name(), &SliceFusionFule); +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.h new file mode 100644 index 00000000000000..2d378dbe1e764e --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.h @@ -0,0 +1,78 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include "paddle/cinn/hlir/framework/op.h"
+#include "paddle/pir/core/operation.h"
+
+namespace cinn {
+namespace dialect {
+namespace ir {
+
+using OpPatternKind = hlir::framework::OpPatternKind;
+
+class SpecialOpsFusionRule {
+ public:
+  typedef bool (*RuleFunc)(const ::pir::Operation*, OpPatternKind);
+
+  static const SpecialOpsFusionRule& GetInstance() {
+    thread_local static SpecialOpsFusionRule instance;
+    return instance;
+  }
+
+  bool ProducerOpAllowsFusion(const ::pir::Operation* producer,
+                              OpPatternKind consumer_group_pattern) const {
+    auto iter = producer_op_rules_.find(producer->name());
+    if (iter != producer_op_rules_.end()) {
+      return iter->second(producer, consumer_group_pattern);
+    }
+    return true;
+  }
+
+  bool ConsumerOpAllowsFusion(const ::pir::Operation* consumer,
+                              OpPatternKind producer_group_pattern) const {
+    auto iter = consumer_op_rules_.find(consumer->name());
+    if (iter != consumer_op_rules_.end()) {
+      return iter->second(consumer, producer_group_pattern);
+    }
+    return true;
+  }
+
+ private:
+  SpecialOpsFusionRule() { Init(); }
+
+  SpecialOpsFusionRule(const SpecialOpsFusionRule&) = delete;
+  SpecialOpsFusionRule(const SpecialOpsFusionRule&&) = delete;
+  SpecialOpsFusionRule& operator=(const SpecialOpsFusionRule&) = delete;
+
+  void Init();
+
+  void RegisterProducerOpRule(const std::string& producer_op_name,
+                              RuleFunc rule) {
+    producer_op_rules_[producer_op_name] = rule;
+  }
+
+  void RegisterConsumerOpRule(const std::string& consumer_op_name,
+                              RuleFunc rule) {
+    consumer_op_rules_[consumer_op_name] = rule;
+  }
+
+  std::map<std::string, RuleFunc> producer_op_rules_;
+  std::map<std::string, RuleFunc> consumer_op_rules_;
+};
+
+}  // namespace ir
+}  // namespace dialect
+}  // namespace cinn

From 2889e53f082b4f5adcdbcfed4aafc3b9beb09832 Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Sun, 28 Jan 2024 10:28:03 +0000
Subject: [PATCH 6/7] revert some code

---
 .../transforms/group_merge/group_with_group_merge_util.h | 7 +++----
 .../transforms/group_merge/op_with_group_merge_pass.cc | 4 ++--
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h
index 033093724b776a..4381ca7234b013 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h
+++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h
@@ -396,11 +396,10 @@ inline bool horizontal_with_injective(
     return true;
   }
 
-  if (horizontal_relation(first, second, OpPatternKind::kInjective)) {
-    return true;
+  if (!is_same_size(first, second)) {
+    return false;
   }
-
-  return false;
+  return horizontal_relation(first, second, OpPatternKind::kInjective);
 }
 
 inline bool injective_horizontal_with_reduce(
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc
index 2aa05359cbe6d0..8677e889ba092d 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc
@@ -659,7 +659,7 @@ GroupList OpFusionPassInternal(
       OpFusionPassHelper(op_list, output_op_list, shape_analysis);
   auto res = op_fusion_helper();
 
-  if (VLOG_IS_ON(3)) {
+  if (VLOG_IS_ON(6)) {
     std::stringstream ss;
     ::pir::IrPrinter printer(ss);
     for (size_t i = 0; i < res.size(); ++i) {
@@ -672,7 +672,7 
@@ GroupList OpFusionPassInternal(
         ss << "\n";
       }
     }
-    VLOG(1) << ss.str();
+    VLOG(6) << ss.str();
   }
 
   VLOG(3) << "OpFusionPass Finish...!";

From fb05023cec6525af365128ba3c34b4dbb69c640e Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Sun, 28 Jan 2024 12:33:28 +0000
Subject: [PATCH 7/7] fix bug

---
 .../transforms/group_merge/group_with_group_merge_pass.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc
index 219099d2279c90..7bb25062d7e094 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc
@@ -410,7 +410,7 @@ template <typename FusePassCtxT>
 bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseInjective(
     const OpGroupPtr& src, const OpGroupPtr& dst) const {
   bool can_all_special_ops_fused = false;
-  src.WalkOpNodes([&](const OpNode& op) {
+  dst.WalkOpNodes([&](const OpNode& op) {
     can_all_special_ops_fused =
         can_all_special_ops_fused &&
         SpecialOpsFusionRule::GetInstance().ConsumerOpAllowsFusion(