diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc index 606d07fd598268..9c0a2e4501a729 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc @@ -846,7 +846,7 @@ std::vector NewOpMergeWithOp( cluster_result.end(), std::back_inserter(result), [](const cinn::fusion::PatternNodePtr node) { - return cinn::fusion::GetOpsInPattern(node->stmt_pattern_); + return cinn::fusion::GetOpsInPattern(node->stmt_pattern()); }); // Each stmts corresponds to each fusion op(cluster node). diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc index 9bc206c53a234c..910bb59f4a3ad4 100644 --- a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc +++ b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc @@ -586,7 +586,7 @@ std::vector OperationFusion( CHECK(fusion_nodes.size() == 1) << "Only support one fusion node in backend now."; - const auto& output = GetExprFromPattern(fusion_nodes[0]->stmt_pattern_); + const auto& output = GetExprFromPattern(fusion_nodes[0]->stmt_pattern()); VLOG(4) << "Fusion Result: output size is " << output.size(); for (const auto& expr : output) { VLOG(4) << expr; diff --git a/paddle/cinn/operator_fusion/group_cluster.h b/paddle/cinn/operator_fusion/group_cluster.h index aa545699a0d4d5..649a2a6a7dcf9f 100644 --- a/paddle/cinn/operator_fusion/group_cluster.h +++ b/paddle/cinn/operator_fusion/group_cluster.h @@ -82,7 +82,7 @@ inline std::vector> ClusterOps( for (const auto& node : result) { VLOG(4) << "\n" << node->DebugStr() << "\n" - << fusion::StmtPatternDebugStr(node->stmt_pattern_); + << fusion::StmtPatternDebugStr(node->stmt_pattern()); } return result; diff --git a/paddle/cinn/operator_fusion/pattern_graph.cc b/paddle/cinn/operator_fusion/pattern_graph.cc index 73008c4ec49524..a8ab68cf809b36 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.cc +++ b/paddle/cinn/operator_fusion/pattern_graph.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/cinn/operator_fusion/pattern_graph.h" +#include #include "paddle/cinn/operator_fusion/backend/pattern.h" #include "paddle/cinn/operator_fusion/backend/pattern_fuser.h" #include "paddle/cinn/operator_fusion/frontend/pattern.h" @@ -58,7 +59,7 @@ std::vector> PatternGraph::SortByTopoOrder() { std::list> topo_queue; std::map, int> degree; for (const auto& node : all_pattern_nodes_) { - degree[node] = node->upstream_.size(); + degree[node] = node->upstream().size(); if (degree[node] == 0) { topo_queue.push_back(node); } @@ -67,7 +68,7 @@ std::vector> PatternGraph::SortByTopoOrder() { PatternNodePtr node = topo_queue.front(); topo_queue.pop_front(); res.push_back(node); - for (const auto& downstream_op : node->downstream_) { + for (const auto& downstream_op : node->downstream()) { degree[downstream_op] = degree[downstream_op] - 1; if (degree[downstream_op] == 0) { topo_queue.push_back(downstream_op); @@ -145,7 +146,6 @@ PatternGraph::PatternGraph(const std::vector>& contents, PatternNodePtr node = std::make_shared>(content); op_to_node_map[content.op] = node; all_pattern_nodes_.emplace(node); - node->sink_op_ = content.op; } for (const auto& content : contents) { @@ -156,7 +156,7 @@ PatternGraph::PatternGraph(const std::vector>& contents, ::pir::Operation* input_op = content.op->operand_source(i).defining_op(); if (op_to_node_map.find(input_op) != op_to_node_map.end()) { PatternNodePtr upstream_node = op_to_node_map[input_op]; - cur_node->upstream_.push_back(upstream_node); + cur_node->AddNodeToUpstream(upstream_node); } } @@ -169,15 +169,15 @@ PatternGraph::PatternGraph(const std::vector>& contents, ::pir::Operation* output_op = consumer_it->owner(); if (op_to_node_map.find(output_op) != op_to_node_map.end()) { PatternNodePtr downstream_node = op_to_node_map[output_op]; - cur_node->downstream_.push_back(downstream_node); + cur_node->AddNodeToDownstream(downstream_node); } } } // unique all upstream / downstream node. // c = a + a ; then add will have 2 same upstream. - cur_node->downstream_ = UniqueVectorBySet(cur_node->downstream_); - cur_node->upstream_ = UniqueVectorBySet(cur_node->upstream_); + cur_node->UniqueUpstream(); + cur_node->UniqueDownstream(); } VLOG(4) << "PatternGraph Created, pattern node size: " @@ -192,12 +192,12 @@ void PatternGraph::RemoveNode(const PatternNodePtr& node) { all_pattern_nodes_.erase(node); } - for (PatternNodePtr& upstream : node->upstream_) { - RemoveFromVector(&upstream->downstream_, node); + for (const PatternNodePtr& upstream : node->upstream()) { + upstream->RemoveNodeFromDownstream(node); } - for (PatternNodePtr& downstream : node->downstream_) { - RemoveFromVector(&downstream->upstream_, node); + for (const PatternNodePtr& downstream : node->downstream()) { + downstream->RemoveNodeFromUpstream(node); } } @@ -220,28 +220,22 @@ std::string PatternGraph::GraphInfo() const { template PatternNodePtr PatternGraph::MergeNode( - const PatternNodePtr& upstream, const PatternNodePtr& downstream) { + const PatternNodePtr& upstream, + const PatternNodePtr& downstream, + MergePatternFn merge_pattern_fn) { PatternNodePtr merged_node = - std::make_shared>(upstream, downstream); + std::make_shared>(upstream, downstream, merge_pattern_fn); - // deal with the reference. - ExtendVector(&merged_node->upstream_, upstream->upstream_); - ExtendVector(&merged_node->upstream_, downstream->upstream_); - RemoveFromVector(&merged_node->upstream_, upstream); - - ExtendVector(&merged_node->downstream_, upstream->downstream_); - ExtendVector(&merged_node->downstream_, downstream->downstream_); - RemoveFromVector(&merged_node->downstream_, downstream); - - for (const auto& upstream_node : merged_node->upstream_) { - upstream_node->downstream_.push_back(merged_node); - RemoveFromVector(&upstream_node->downstream_, upstream); - RemoveFromVector(&upstream_node->downstream_, downstream); + // Update upstream and downstream nodes. + for (const auto& upstream_node : merged_node->upstream()) { + upstream_node->AddNodeToDownstream(merged_node); + upstream_node->RemoveNodeFromDownstream(upstream); + upstream_node->RemoveNodeFromDownstream(downstream); } - for (const auto& downstream_node : merged_node->downstream_) { - downstream_node->upstream_.push_back(merged_node); - RemoveFromVector(&downstream_node->downstream_, upstream); - RemoveFromVector(&downstream_node->downstream_, downstream); + for (const auto& downstream_node : merged_node->downstream()) { + downstream_node->AddNodeToUpstream(merged_node); + downstream_node->RemoveNodeFromDownstream(upstream); + downstream_node->RemoveNodeFromDownstream(downstream); } const auto vec_unique = [](const std::vector>& vec) { @@ -249,8 +243,16 @@ PatternNodePtr PatternGraph::MergeNode( return set.size() == vec.size(); }; - CHECK(vec_unique(merged_node->upstream_)); - CHECK(vec_unique(merged_node->downstream_)); + PADDLE_ENFORCE_EQ( + vec_unique(merged_node->upstream()), + true, + phi::errors::PreconditionNotMet( + "The upstream nodes of the merged node are not unique.")); + PADDLE_ENFORCE_EQ( + vec_unique(merged_node->downstream()), + true, + phi::errors::PreconditionNotMet( + "The downstream nodes of the merged node are not unique.")); // deal with the graph storage. AppendNode(merged_node); diff --git a/paddle/cinn/operator_fusion/pattern_graph.h b/paddle/cinn/operator_fusion/pattern_graph.h index 589235d8d76a8c..e6ba1342623491 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.h +++ b/paddle/cinn/operator_fusion/pattern_graph.h @@ -17,11 +17,15 @@ #include "paddle/cinn/operator_fusion/policy/policy_manager.h" #include "paddle/cinn/operator_fusion/policy/relative_judge_policy.h" #include "paddle/cinn/operator_fusion/utils.h" +#include "paddle/common/enforce.h" namespace cinn::fusion { template using PatternNodePtrSet = std::unordered_set>; +template +using MergePatternFn = + std::function(const StmtPattern&, const StmtPattern&)>; template class PatternGraph { @@ -43,10 +47,18 @@ class PatternGraph { void AppendNode(const PatternNodePtr& node); std::string GraphInfo() const; PatternNodePtr MergeNode(const PatternNodePtr& upstream, - const PatternNodePtr& downstream); + const PatternNodePtr& downstream, + MergePatternFn merge_pattern_fn); std::vector> SortByTopoOrder(); - public: + const PatternNodePtrSet& all_pattern_nodes() const { + return all_pattern_nodes_; + } + const std::vector& outputs() const { return outputs_; } + const PolicyManager& policy_manager() const { return policy_manager_; } + const PolicyManager& topo_manager() const { return topo_manager_; } + + private: PatternNodePtrSet all_pattern_nodes_; std::vector outputs_; PolicyManager policy_manager_; @@ -79,7 +91,7 @@ struct SearchAlgorithm { } PatternNodePtr FindMatchedNode() { - for (PatternNodePtr iter_node : graph_->all_pattern_nodes_) { + for (PatternNodePtr iter_node : graph_->all_pattern_nodes()) { if (GraphMatcher()(*graph_, iter_node) && !visited_nodes.count(iter_node)) { visited_nodes.insert(iter_node); @@ -113,8 +125,8 @@ struct SearchAlgorithm { } std::optional, PatternNodePtr>> FindMatchedPair() { - for (PatternNodePtr i : graph_->all_pattern_nodes_) { - for (PatternNodePtr j : graph_->all_pattern_nodes_) { + for (PatternNodePtr i : graph_->all_pattern_nodes()) { + for (PatternNodePtr j : graph_->all_pattern_nodes()) { if (i == j) continue; const auto& pair = std::make_pair(i, j); if (GraphMatcher()(*graph_, i, j) && !visited_node_pair.count(pair)) { @@ -142,9 +154,14 @@ struct SearchAlgorithm { struct MergeReduceTreeOperation { template void operator()(PatternGraph* graph, PatternNodePtr node) { - CHECK_EQ(node->downstream_.size(), 1); - auto downstream = node->downstream_.at(0); - auto merged_node = graph->MergeNode(node, downstream); + PADDLE_ENFORCE_EQ( + node->downstream().size(), + 1, + phi::errors::PreconditionNotMet( + "The downstream of the ReduceTree node should be 1, but got %d.", + node->downstream().size())); + auto downstream = node->downstream().at(0); + auto merged_node = graph->MergeNode(node, downstream, MergePattern); graph->RemoveNode(downstream); graph->RemoveNode(node); VLOG(4) << "MergeReduceTreeOperation: \nupstream " << node->DebugStr() @@ -156,13 +173,25 @@ struct MergeReduceTreeOperation { struct MergeReduceTreeAndTrivialOperation { template void operator()(PatternGraph* graph, PatternNodePtr node) { - CHECK_EQ(node->downstream_.size(), 1); - auto downstream = node->downstream_.at(0); + PADDLE_ENFORCE_EQ( + node->downstream().size(), + 1, + phi::errors::PreconditionNotMet( + "The downstream of the ReduceTree node should be 1, but got %d.", + node->downstream().size())); + auto downstream = node->downstream().at(0); auto fake_reduce_iter_idx = - graph->policy_manager_.GetFakeReduceIterIdx(node, downstream); - PatternNodePtr merged_node = graph->MergeNode(node, downstream); - std::get>(merged_node->stmt_pattern_) - .fake_reduce_iter_idx = fake_reduce_iter_idx; + graph->policy_manager().GetFakeReduceIterIdx(node, downstream); + const auto merge_pattern_fn = [&fake_reduce_iter_idx]( + const StmtPattern& first, + const StmtPattern& secend) { + auto rt_pattern = std::get>( + MergePattern(first, secend)); + rt_pattern.fake_reduce_iter_idx = fake_reduce_iter_idx; + return rt_pattern; + }; + PatternNodePtr merged_node = + graph->MergeNode(node, downstream, merge_pattern_fn); graph->RemoveNode(downstream); graph->RemoveNode(node); VLOG(4) << "MergeReduceTreeAndTrivialOperation: \nupstream " @@ -174,8 +203,8 @@ struct MergeReduceTreeAndTrivialOperation { struct LiftReduceToReduceTreeOperation { template void operator()(PatternGraph* graph, PatternNodePtr node) { - const auto& reduce_pattern = ToReducePattern(node->stmt_pattern_); - node->stmt_pattern_ = ReduceTreePattern({}, reduce_pattern); + const auto& reduce_pattern = ToReducePattern(node->stmt_pattern()); + node->set_stmt_pattern(ReduceTreePattern({}, reduce_pattern)); VLOG(4) << "LiftReduceToReduceTreeOperation: \nnode " << node->DebugStr(); } }; @@ -185,24 +214,25 @@ struct MergeTrivialPatternOperation { void operator()(PatternGraph* graph, PatternNodePtr upstream) { std::vector> fusion_candidate = - upstream->downstream_; - upstream->downstream_.clear(); + upstream->downstream(); + upstream->ClearDownstream(); for (const auto& downstream : fusion_candidate) { if (std::holds_alternative>( - downstream->stmt_pattern_) || + downstream->stmt_pattern()) || std::holds_alternative>( - downstream->stmt_pattern_)) { - auto merged_node = graph->MergeNode(upstream, downstream); + downstream->stmt_pattern())) { + auto merged_node = + graph->MergeNode(upstream, downstream, MergePattern); graph->RemoveNode(downstream); VLOG(4) << "MergeTrivialPatternOperation: \nupstream " << upstream->DebugStr() << "\ndownstream " << downstream->DebugStr() << "\nmerged " << merged_node->DebugStr(); } else { - upstream->downstream_.push_back(downstream); + upstream->AddNodeToDownstream(downstream); } } - if (upstream->downstream_.empty()) { + if (upstream->downstream().empty()) { graph->RemoveNode(upstream); } } @@ -210,8 +240,34 @@ struct MergeTrivialPatternOperation { struct LiftToHorizontalFusionPatternOperation { template - void operator()(PatternGraph* graph, PatternNodePtr i) { - i->stmt_pattern_ = HorizontalFusionPattern({i->stmt_pattern_}); + void operator()(PatternGraph* graph, PatternNodePtr node) { + node->set_stmt_pattern( + HorizontalFusionPattern({node->stmt_pattern()})); + } +}; + +struct HorizontalFusionOperation { + template + void operator()(PatternGraph* graph, + const PatternNodePtr& i, + const PatternNodePtr& j) { + PADDLE_ENFORCE_EQ( + GetPatternName(i->stmt_pattern()), + HorizontalFusionPattern::name(), + phi::errors::PreconditionNotMet( + "The pattern of the first node should be HorizontalFusionPattern, " + "but got %s.", + GetPatternName(i->stmt_pattern()))); + PADDLE_ENFORCE_EQ( + GetPatternName(j->stmt_pattern()), + HorizontalFusionPattern::name(), + phi::errors::PreconditionNotMet( + "The pattern of the second node should be HorizontalFusionPattern, " + "but got %s.", + GetPatternName(j->stmt_pattern()))); + graph->MergeNode(i, j, MergePattern); + graph->RemoveNode(i); + graph->RemoveNode(j); } }; @@ -229,17 +285,18 @@ template struct StmtPatternGraphMatcher { template bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - return GetPatternName(node->stmt_pattern_) == StmtPattern::name(); + return GetPatternName(node->stmt_pattern()) == StmtPattern::name(); } }; struct CanFuseRxTMatcher { template bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - return (std::holds_alternative>(node->stmt_pattern_) && - !node->downstream_.empty() && - std::holds_alternative>( - node->downstream_.at(0)->stmt_pattern_)); + return ( + std::holds_alternative>(node->stmt_pattern()) && + !node->downstream().empty() && + std::holds_alternative>( + node->downstream().at(0)->stmt_pattern())); } }; @@ -247,10 +304,10 @@ struct CanFuseReduceTreeMatcher { template bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { return StmtPatternGraphMatcher>()(graph, node) && - !node->downstream_.empty() && + !node->downstream().empty() && std::holds_alternative>( - node->downstream_.at(0)->stmt_pattern_) && - graph.policy_manager_.CanFuse(node, node->downstream_.at(0)); + node->downstream().at(0)->stmt_pattern()) && + graph.policy_manager().CanFuse(node, node->downstream().at(0)); } }; @@ -258,10 +315,10 @@ struct CanFuseReduceTreeAndTrivialMatcher { template bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { return StmtPatternGraphMatcher>()(graph, node) && - !node->downstream_.empty() && + !node->downstream().empty() && std::holds_alternative>( - node->downstream_.at(0)->stmt_pattern_) && - graph.policy_manager_.CanFuse(node, node->downstream_.at(0)); + node->downstream().at(0)->stmt_pattern()) && + graph.policy_manager().CanFuse(node, node->downstream().at(0)); } }; @@ -276,45 +333,32 @@ struct HorizontalFusionConstrain { if (!StmtPatternGraphMatcher>()(graph, second)) { return false; } - const auto& first_dim = first->sink_op_->result(0) + const auto& first_dim = first->sink_op() + ->result(0) .type() .template dyn_cast() .dims(); - const auto& second_dim = second->sink_op_->result(0) + const auto& second_dim = second->sink_op() + ->result(0) .type() .template dyn_cast() .dims(); - return graph.topo_manager_.CanFuse(first, second) && + return graph.topo_manager().CanFuse(first, second) && first_dim == second_dim; } }; -struct HorizontalFusionOperation { - template - void operator()(PatternGraph* graph, - const PatternNodePtr& i, - const PatternNodePtr& j) { - CHECK(GetPatternName(i->stmt_pattern_) == - HorizontalFusionPattern::name()); - CHECK(GetPatternName(j->stmt_pattern_) == - HorizontalFusionPattern::name()); - graph->MergeNode(i, j); - graph->RemoveNode(i); - graph->RemoveNode(j); - } -}; - struct NonSinkNodeMatcher { template bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - return !node->downstream_.empty(); + return !node->downstream().empty(); } }; struct IsOutputNodeMatcher { template bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - bool res = IsAnyFirstInSecond(node->sink_op_->results(), graph.outputs_); + bool res = IsAnyFirstInSecond(node->sink_op()->results(), graph.outputs()); return res; } }; @@ -331,7 +375,7 @@ template struct DownstreamSmallerThan { template bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - return node->downstream_.size() < N; + return node->downstream().size() < N; } }; diff --git a/paddle/cinn/operator_fusion/pattern_node.h b/paddle/cinn/operator_fusion/pattern_node.h index d6c9f8202669ef..459522b8341de2 100644 --- a/paddle/cinn/operator_fusion/pattern_node.h +++ b/paddle/cinn/operator_fusion/pattern_node.h @@ -23,15 +23,27 @@ namespace cinn::fusion { template struct PatternNode { using PatternNodePtr = std::shared_ptr>; + using MergePatternFn = std::function(const StmtPattern&, + const StmtPattern&)>; explicit PatternNode(const PatternContent& content) : sink_op_(content.op), stmt_pattern_(ConvertToStmtPattern(content)) {} explicit PatternNode(PatternNodePtr fused_up_node, - PatternNodePtr fused_down_node) + PatternNodePtr fused_down_node, + MergePatternFn merge_pattern_fn) : sink_op_(fused_down_node->sink_op_), - stmt_pattern_(MergePattern(fused_up_node->stmt_pattern_, - fused_down_node->stmt_pattern_)) {} + stmt_pattern_(merge_pattern_fn(fused_up_node->stmt_pattern_, + fused_down_node->stmt_pattern_)) { + // Update the upstream & downstream + ExtendVector(&upstream_, fused_up_node->upstream()); + ExtendVector(&upstream_, fused_down_node->upstream()); + RemoveFromVector(&upstream_, fused_up_node); + + ExtendVector(&downstream_, fused_up_node->downstream()); + ExtendVector(&downstream_, fused_down_node->downstream()); + RemoveFromVector(&downstream_, fused_down_node); + } std::string DebugStr() const { std::stringstream ss; @@ -47,6 +59,27 @@ struct PatternNode { return ss.str(); } + pir::Operation* sink_op() const { return sink_op_; } + const StmtPattern& stmt_pattern() const { return stmt_pattern_; } + void set_stmt_pattern(const StmtPattern& pattern) { + stmt_pattern_ = pattern; + } + const std::vector& upstream() const { return upstream_; } + const std::vector& downstream() const { return downstream_; } + void AddNodeToUpstream(PatternNodePtr node) { upstream_.push_back(node); } + void AddNodeToDownstream(PatternNodePtr node) { downstream_.push_back(node); } + void RemoveNodeFromUpstream(PatternNodePtr node) { + RemoveFromVector(&upstream_, node); + } + void RemoveNodeFromDownstream(PatternNodePtr node) { + RemoveFromVector(&downstream_, node); + } + void ClearUpstream() { upstream_.clear(); } + void ClearDownstream() { downstream_.clear(); } + void UniqueUpstream() { upstream_ = UniqueVectorBySet(upstream_); } + void UniqueDownstream() { downstream_ = UniqueVectorBySet(downstream_); } + + private: StmtPattern stmt_pattern_; pir::Operation* sink_op_; diff --git a/paddle/cinn/operator_fusion/policy/general_topo_policy.cc b/paddle/cinn/operator_fusion/policy/general_topo_policy.cc index 53d54b8fa0f65e..e4cca9804a79f2 100644 --- a/paddle/cinn/operator_fusion/policy/general_topo_policy.cc +++ b/paddle/cinn/operator_fusion/policy/general_topo_policy.cc @@ -24,7 +24,7 @@ template bool IsDownstreamNode(const PatternNodePtr start, const PatternNodePtr target) { if (start == target) return true; - for (const auto& down_node : start->downstream_) { + for (const auto& down_node : start->downstream()) { if (IsDownstreamNode(down_node, target)) return true; } return false; @@ -33,7 +33,7 @@ bool IsDownstreamNode(const PatternNodePtr start, template bool IsIndirectDownstreamNode(const PatternNodePtr start, const PatternNodePtr target) { - for (const auto& node : start->downstream_) { + for (const auto& node : start->downstream()) { if (node == target) continue; if (IsDownstreamNode(node, target)) return true; } diff --git a/paddle/cinn/operator_fusion/policy/relative_judge_policy.cc b/paddle/cinn/operator_fusion/policy/relative_judge_policy.cc index 954593778a7b7f..626f54c215b6ea 100644 --- a/paddle/cinn/operator_fusion/policy/relative_judge_policy.cc +++ b/paddle/cinn/operator_fusion/policy/relative_judge_policy.cc @@ -110,12 +110,12 @@ template bool RelativeJudgePolicy::ReduceTreeGrownCanMerge( const PatternNodePtr& upstream, const PatternNodePtr& downstream) { const auto& upstream_tree = - std::get>(upstream->stmt_pattern_); - VLOG(4) << "upstream->stmt_pattern_:" + std::get>(upstream->stmt_pattern()); + VLOG(4) << "upstream->stmt_pattern():" << OpsDebugStr(GetOpsInPattern(upstream_tree)); const auto& downstream_tree = - std::get>(downstream->stmt_pattern_); - VLOG(4) << "downstream->stmt_pattern_" + std::get>(downstream->stmt_pattern()); + VLOG(4) << "downstream->stmt_pattern()" << OpsDebugStr(GetOpsInPattern(downstream_tree)); const auto& maybe_downstream_op = GetDownstreamFromCandidate( upstream_tree.GetRootPattern(), downstream_tree.FlattenReducePattern()); @@ -202,7 +202,7 @@ std::vector RelativeJudgePolicy::getUpstreamReduceDims( ShardableAxesInfoManager& axes_info) { // NOLINT const auto& split_reduce_input_dims_result = SplitReduceInputDimsIfRelatedWithNonReduceAxis( - axes_info.GetSignature(upstream->sink_op_), upstream->sink_op_); + axes_info.GetSignature(upstream->sink_op()), upstream->sink_op()); return split_reduce_input_dims_result.non_related; } @@ -213,11 +213,11 @@ std::vector RelativeJudgePolicy::getDownstreamUnrelatedDims( ShardableAxesInfoManager& axes_info) { // NOLINT const auto& split_reduce_output_dims_result = SplitReduceOutputDimsIfRelatedWithNonReduceAxis( - axes_info.GetSignature(upstream->sink_op_), upstream->sink_op_); + axes_info.GetSignature(upstream->sink_op()), upstream->sink_op()); const auto& upstream_non_reduce_dims = split_reduce_output_dims_result.related; const auto& split_trivial_dims_result = SplitDimsWithRelationship( - GetAllValueDimFromValue(downstream->sink_op_->result(0)), + GetAllValueDimFromValue(downstream->sink_op()->result(0)), upstream_non_reduce_dims); VLOG(4) << split_trivial_dims_result.DebugStr(); return split_trivial_dims_result.non_related; @@ -237,8 +237,10 @@ bool RelativeJudgePolicy::ReducePlusTrivialCanMerge( return res; } -static std::vector GatherDimsExcept( - const std::vector& dims, const std::vector& except) { +namespace { + +std::vector GatherDimsExcept(const std::vector& dims, + const std::vector& except) { std::vector result; for (size_t i = 0; i < dims.size(); i++) { if (std::find(except.begin(), except.end(), i) == except.end()) { @@ -248,7 +250,7 @@ static std::vector GatherDimsExcept( return result; } -static symbol::DimExpr GetProductDimExprForValueDims( +symbol::DimExpr GetProductDimExprForValueDims( const std::vector& dims) { if (dims.empty()) { return 0; @@ -262,8 +264,8 @@ static symbol::DimExpr GetProductDimExprForValueDims( return shape_analysis.GetProductDimExpr(dims[0].v_, dim_idx); } -static bool IsProductSmallerOrEqual(const std::vector& first, - const std::vector& second) { +bool IsProductSmallerOrEqual(const std::vector& first, + const std::vector& second) { if (first.empty()) return true; const auto& first_product = GetProductDimExprForValueDims(first); const auto& second_product = GetProductDimExprForValueDims(second); @@ -279,15 +281,17 @@ static bool IsProductSmallerOrEqual(const std::vector& first, return shape_analysis.IsEqual(first_product, second_product); } +} // namespace + template bool RelativeJudgePolicy::IsFlattenDimSmaller( const PatternNodePtr& upstream, const PatternNodePtr& downstream) { const auto& fakes = GetFakeReduceIterIdx(upstream, downstream); VLOG(4) << "IsFlattenDimSmaller: fake is " << utils::Join(fakes, ","); const auto& downstream_free_dims = GatherDimsExcept( - GetAllValueDimFromValue(downstream->sink_op_->result(0)), fakes); + GetAllValueDimFromValue(downstream->sink_op()->result(0)), fakes); const auto& upstream_free_dims = - GetAllValueDimFromValue(upstream->sink_op_->result(0)); + GetAllValueDimFromValue(upstream->sink_op()->result(0)); bool res = IsProductSmallerOrEqual(downstream_free_dims, upstream_free_dims); VLOG(4) << "IsFlattenDimSmaller: " << res; @@ -297,12 +301,13 @@ bool RelativeJudgePolicy::IsFlattenDimSmaller( template bool RelativeJudgePolicy::CanFuse(const PatternNodePtr& upstream, const PatternNodePtr& downstream) { - if (std::holds_alternative>(upstream->stmt_pattern_) && - std::holds_alternative>(downstream->stmt_pattern_)) { + if (std::holds_alternative>(upstream->stmt_pattern()) && + std::holds_alternative>(downstream->stmt_pattern())) { return ReducePlusTrivialCanMerge(upstream, downstream); } - if (std::holds_alternative>(upstream->stmt_pattern_) && - std::holds_alternative>(downstream->stmt_pattern_)) { + if (std::holds_alternative>(upstream->stmt_pattern()) && + std::holds_alternative>( + downstream->stmt_pattern())) { return ReduceTreeGrownCanMerge(upstream, downstream); } return true; // other case. @@ -311,15 +316,15 @@ bool RelativeJudgePolicy::CanFuse(const PatternNodePtr& upstream, template std::vector RelativeJudgePolicy::GetFakeReduceIterIdx( const PatternNodePtr& upstream, const PatternNodePtr& downstream) { - if (!std::holds_alternative>(upstream->stmt_pattern_) && - !std::holds_alternative>(downstream->stmt_pattern_)) { + if (!std::holds_alternative>(upstream->stmt_pattern()) && + !std::holds_alternative>(downstream->stmt_pattern())) { PADDLE_THROW("Illegal Call GetFakeReduceIterIdx"); } // TODO(xiongkun): replace after fix bug in relation that if has multi path in // graph const auto& split_reduce_dims_result = // SplitReduceInputDimsIfRelatedWithNonReduceAxis( - // axes_info_.GetSignature(upstream->sink_op_), upstream->sink_op_); + // axes_info_.GetSignature(upstream->sink_op()), upstream->sink_op()); // const auto& upstream_reduce_dims = split_reduce_dims_result.non_related; // const auto& upstream_non_reduce_dims = split_reduce_dims_result.related; @@ -327,12 +332,12 @@ std::vector RelativeJudgePolicy::GetFakeReduceIterIdx( const auto& split_reduce_input_dims_result = SplitReduceInputDimsIfRelatedWithNonReduceAxis( - axes_info_.GetSignature(upstream->sink_op_), upstream->sink_op_); + axes_info_.GetSignature(upstream->sink_op()), upstream->sink_op()); VLOG(4) << split_reduce_input_dims_result.DebugStr(); const auto& upstream_reduce_dims = split_reduce_input_dims_result.non_related; const auto& split_reduce_output_dims_result = SplitReduceOutputDimsIfRelatedWithNonReduceAxis( - axes_info_.GetSignature(upstream->sink_op_), upstream->sink_op_); + axes_info_.GetSignature(upstream->sink_op()), upstream->sink_op()); VLOG(4) << split_reduce_input_dims_result.DebugStr(); const auto& upstream_non_reduce_dims = split_reduce_output_dims_result.related; @@ -340,7 +345,7 @@ std::vector RelativeJudgePolicy::GetFakeReduceIterIdx( // ======================= const auto& split_trivial_dims_result = SplitDimsWithRelationship( - GetAllValueDimFromValue(downstream->sink_op_->result(0)), + GetAllValueDimFromValue(downstream->sink_op()->result(0)), upstream_non_reduce_dims); const auto& trivial_reorder_dims = split_trivial_dims_result.non_related; diff --git a/paddle/cinn/operator_fusion/policy/shardable_axes_policy.cc b/paddle/cinn/operator_fusion/policy/shardable_axes_policy.cc index 24ffa6d862c863..4b8b758f449cd5 100644 --- a/paddle/cinn/operator_fusion/policy/shardable_axes_policy.cc +++ b/paddle/cinn/operator_fusion/policy/shardable_axes_policy.cc @@ -62,9 +62,9 @@ bool ShardableAxesRRFusePolicy::ReduceTreeGrownCanMerge( return false; } const auto& upstream_tree = - std::get(upstream->stmt_pattern_); + std::get(upstream->stmt_pattern()); const auto& downstream_tree = - std::get(downstream->stmt_pattern_); + std::get(downstream->stmt_pattern()); const auto& maybe_downstream_op = GetDownstreamFromCandidate( upstream_tree.GetRootPattern(), downstream_tree.reduce_patterns_); if (!maybe_downstream_op.has_value()) {