From 8d82e0729b60d8e2de642be927b3f9892c0d852d Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 23 May 2024 11:09:50 +0000 Subject: [PATCH 1/3] fix horizontal bugs and restore --- paddle/cinn/operator_fusion/pattern_graph.cc | 9 ++-- paddle/cinn/operator_fusion/pattern_graph.h | 54 ++++++++++++++++++-- 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/paddle/cinn/operator_fusion/pattern_graph.cc b/paddle/cinn/operator_fusion/pattern_graph.cc index d3a2a92f6e940c..1e3de851ecd46a 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.cc +++ b/paddle/cinn/operator_fusion/pattern_graph.cc @@ -107,10 +107,11 @@ void PatternGraph::HorizontalFusion() { StmtPatternGraphMatcher>>, LiftToHorizontalFusionPatternOperation>(this); - GraphTransformer, - HorizontalFusionOperation>(this); + GraphTransformer< + NodePairPattern, + T, + And, HorizontalCheckMiddleOutputVar>, + HorizontalFusionOperation>(this); } template diff --git a/paddle/cinn/operator_fusion/pattern_graph.h b/paddle/cinn/operator_fusion/pattern_graph.h index e113af683b53d9..22ac927cb0c5d0 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.h +++ b/paddle/cinn/operator_fusion/pattern_graph.h @@ -328,6 +328,31 @@ struct CanFuseReduceTreeAndTrivialMatcher { } }; +template +struct HorizontalCheckMiddleOutputVar { + bool IsAnyOpUseOutput(const std::vector& ops, + const std::vector& output_value) { + std::unordered_set set(output_value.begin(), + output_value.end()); + for (const auto& op : ops) { + for (const auto& var : op->operands()) { + if (set.count(var.source())) { + return true; + } + } + } + return false; + } + bool operator()(const PatternGraph& graph, + const PatternNodePtr& lhs, + const PatternNodePtr& rhs) { + const auto& output_value = graph.outputs(); + const auto& ops = ConcatVector(GetOpsInPattern(lhs->stmt_pattern()), + GetOpsInPattern(rhs->stmt_pattern())); + return !IsAnyOpUseOutput(ops, output_value); + } +}; + template struct HorizontalFusionConstrain { bool operator()(const PatternGraph& graph, @@ -381,11 +406,34 @@ struct DownstreamSmallerThan { } }; -template -struct And { +template +struct And {}; + +template +struct And { template bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - return A()(graph, node) && B()(graph, node); + return A()(graph, node); + } + template + bool operator()(const PatternGraph& graph, + const PatternNodePtr& lhs, + const PatternNodePtr& rhs) { + return A()(graph, lhs, rhs); + } +}; + +template +struct And { + template + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + return A()(graph, node) && And()(graph, node); + } + template + bool operator()(const PatternGraph& graph, + const PatternNodePtr& lhs, + const PatternNodePtr& rhs) { + return A()(graph, lhs, rhs) && And()(graph, lhs, rhs); } }; From 33762e8cb91aeb636f24a4080dfab164043fc0af Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 24 May 2024 05:13:19 +0000 Subject: [PATCH 2/3] fix --- paddle/cinn/operator_fusion/pattern_graph.h | 24 ++++++------------- .../cinn/inference/test_llama_postprocess.py | 4 ++-- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/paddle/cinn/operator_fusion/pattern_graph.h b/paddle/cinn/operator_fusion/pattern_graph.h index 22ac927cb0c5d0..0015b31da7ac9d 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.h +++ b/paddle/cinn/operator_fusion/pattern_graph.h @@ -330,26 +330,16 @@ struct CanFuseReduceTreeAndTrivialMatcher { template struct HorizontalCheckMiddleOutputVar { - bool IsAnyOpUseOutput(const std::vector& ops, - const std::vector& output_value) { - std::unordered_set set(output_value.begin(), - output_value.end()); - for (const auto& op : ops) { - for (const auto& var : op->operands()) { - if (set.count(var.source())) { - return true; - } - } - } - return false; - } bool operator()(const PatternGraph& graph, const PatternNodePtr& lhs, const PatternNodePtr& rhs) { - const auto& output_value = graph.outputs(); - const auto& ops = ConcatVector(GetOpsInPattern(lhs->stmt_pattern()), - GetOpsInPattern(rhs->stmt_pattern())); - return !IsAnyOpUseOutput(ops, output_value); + for (const auto& i : lhs->downstream()) { + if (i == rhs) return false; + } + for (const auto& i : lhs->upstream()) { + if (i == rhs) return false; + } + return true; } }; diff --git a/test/ir/pir/cinn/inference/test_llama_postprocess.py b/test/ir/pir/cinn/inference/test_llama_postprocess.py index 1600a3a794409f..cfff921719f955 100644 --- a/test/ir/pir/cinn/inference/test_llama_postprocess.py +++ b/test/ir/pir/cinn/inference/test_llama_postprocess.py @@ -90,8 +90,8 @@ def prepare_data(self): self.input_ids = paddle.randint(0, 512, [1, 32], dtype="int64") def check_jit_kernel_info(self, static_fn): - utils.check_jit_kernel_number(static_fn, 4) - utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 4}) + utils.check_jit_kernel_number(static_fn, 7) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 7}) def eval(self, use_cinn): paddle.seed(2024) From 88be07b7e0012d56fed02fc6e419ec2d48b614fa Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 24 May 2024 08:37:16 +0000 Subject: [PATCH 3/3] fix --- test/ir/pir/cinn/test_horizontal_fusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ir/pir/cinn/test_horizontal_fusion.py b/test/ir/pir/cinn/test_horizontal_fusion.py index 38a073e55323a0..e681c1bb781236 100644 --- a/test/ir/pir/cinn/test_horizontal_fusion.py +++ b/test/ir/pir/cinn/test_horizontal_fusion.py @@ -39,8 +39,8 @@ def prepare_data(self): self.x.stop_gradient = True def check_jit_kernel_info(self, static_fn): - utils.check_jit_kernel_number(static_fn, 1) - utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1}) + utils.check_jit_kernel_number(static_fn, 2) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 2}) def eval(self, use_cinn): net = HorizontalSubGraph()