diff --git a/paddle/cinn/operator_fusion/pattern_graph.cc b/paddle/cinn/operator_fusion/pattern_graph.cc
index d3a2a92f6e940c..1e3de851ecd46a 100644
--- a/paddle/cinn/operator_fusion/pattern_graph.cc
+++ b/paddle/cinn/operator_fusion/pattern_graph.cc
@@ -107,10 +107,11 @@ void PatternGraph<T>::HorizontalFusion() {
                        StmtPatternGraphMatcher<HorizontalFusionPattern<T>>>,
                    LiftToHorizontalFusionPatternOperation>(this);
 
-  GraphTransformer<NodePairPattern,
-                   T,
-                   HorizontalFusionConstrain<T>,
-                   HorizontalFusionOperation>(this);
+  GraphTransformer<
+      NodePairPattern,
+      T,
+      And<HorizontalFusionConstrain<T>, HorizontalCheckMiddleOutputVar<T>>,
+      HorizontalFusionOperation>(this);
 }
 
 template
diff --git a/paddle/cinn/operator_fusion/pattern_graph.h b/paddle/cinn/operator_fusion/pattern_graph.h
index e113af683b53d9..0015b31da7ac9d 100644
--- a/paddle/cinn/operator_fusion/pattern_graph.h
+++ b/paddle/cinn/operator_fusion/pattern_graph.h
@@ -328,6 +328,21 @@ struct CanFuseReduceTreeAndTrivialMatcher {
   }
 };
 
+template <typename T>
+struct HorizontalCheckMiddleOutputVar {
+  bool operator()(const PatternGraph<T>& graph,
+                  const PatternNodePtr<T>& lhs,
+                  const PatternNodePtr<T>& rhs) {
+    for (const auto& i : lhs->downstream()) {
+      if (i == rhs) return false;
+    }
+    for (const auto& i : lhs->upstream()) {
+      if (i == rhs) return false;
+    }
+    return true;
+  }
+};
+
 template <typename T>
 struct HorizontalFusionConstrain {
   bool operator()(const PatternGraph<T>& graph,
@@ -381,11 +396,34 @@ struct DownstreamSmallerThan {
   }
 };
 
-template <typename A, typename B>
-struct And {
+template <typename... Args>
+struct And {};
+
+template <typename A>
+struct And<A> {
+  template <typename T>
+  bool operator()(const PatternGraph<T>& graph, const PatternNodePtr<T>& node) {
+    return A()(graph, node);
+  }
+  template <typename T>
+  bool operator()(const PatternGraph<T>& graph,
+                  const PatternNodePtr<T>& lhs,
+                  const PatternNodePtr<T>& rhs) {
+    return A()(graph, lhs, rhs);
+  }
+};
+
+template <typename A, typename... Rest>
+struct And<A, Rest...> {
   template <typename T>
   bool operator()(const PatternGraph<T>& graph, const PatternNodePtr<T>& node) {
-    return A()(graph, node) && B()(graph, node);
+    return A()(graph, node) && And<Rest...>()(graph, node);
+  }
+  template <typename T>
+  bool operator()(const PatternGraph<T>& graph,
+                  const PatternNodePtr<T>& lhs,
+                  const PatternNodePtr<T>& rhs) {
+    return A()(graph, lhs, rhs) && And<Rest...>()(graph, lhs, rhs);
   }
 };
 
diff --git a/test/ir/pir/cinn/inference/test_llama_postprocess.py b/test/ir/pir/cinn/inference/test_llama_postprocess.py
index 1600a3a794409f..cfff921719f955 100644
--- a/test/ir/pir/cinn/inference/test_llama_postprocess.py
+++ b/test/ir/pir/cinn/inference/test_llama_postprocess.py
@@ -90,8 +90,8 @@ def prepare_data(self):
         self.input_ids = paddle.randint(0, 512, [1, 32], dtype="int64")
 
     def check_jit_kernel_info(self, static_fn):
-        utils.check_jit_kernel_number(static_fn, 4)
-        utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 4})
+        utils.check_jit_kernel_number(static_fn, 7)
+        utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 7})
 
     def eval(self, use_cinn):
         paddle.seed(2024)
diff --git a/test/ir/pir/cinn/test_horizontal_fusion.py b/test/ir/pir/cinn/test_horizontal_fusion.py
index 38a073e55323a0..e681c1bb781236 100644
--- a/test/ir/pir/cinn/test_horizontal_fusion.py
+++ b/test/ir/pir/cinn/test_horizontal_fusion.py
@@ -39,8 +39,8 @@ def prepare_data(self):
         self.x.stop_gradient = True
 
     def check_jit_kernel_info(self, static_fn):
-        utils.check_jit_kernel_number(static_fn, 1)
-        utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})
+        utils.check_jit_kernel_number(static_fn, 2)
+        utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 2})
 
     def eval(self, use_cinn):
         net = HorizontalSubGraph()