From 495fa9115ddc09b31bea45ecec28dc069210b372 Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Thu, 22 Feb 2024 18:00:32 +0800
Subject: [PATCH] [PIR+CINN]Fix Convert0DTo1D Pass bug in CombineOp

---
 .../group_merge/convert_0d_to_1d_pass.cc     | 35 +++++++++++++++++++
 .../pir/cinn/sub_graphs/test_sub_graph_23.py |  8 ++---
 2 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc
index f60878a9e1d99d..325421d92abe67 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc
@@ -58,6 +58,40 @@ class FullOpPattern : public pir::OpRewritePattern<paddle::dialect::FullOp> {
   }
 };
 
+class CombineOpPattern : public pir::OpRewritePattern<pir::CombineOp> {
+ public:
+  using pir::OpRewritePattern<pir::CombineOp>::OpRewritePattern;
+
+  bool Match(pir::CombineOp op) const override {
+    auto out_type = op.result(0).type().dyn_cast<pir::VectorType>();
+    for (auto type : out_type.data()) {
+      if (HasZeroDim(type)) return true;
+    }
+    return false;
+  }
+
+  void Rewrite(pir::CombineOp op,
+               pir::PatternRewriter &rewriter) const override {
+    pir::Builder builder(rewriter.ir_context());
+
+    const std::vector<pir::Type> inputs_type = [&]() {
+      std::vector<pir::Type> types;
+      for (auto value : op->operands_source()) {
+        types.push_back(value.type());
+      }
+      return types;
+    }();
+    op.result(0).set_type(builder.vec_type(inputs_type));
+  }
+
+ private:
+  bool HasZeroDim(pir::Type type) const {
+    if (!type) return false;
+    const auto dense_tensor_type = type.dyn_cast<pir::DenseTensorType>();
+    return dense_tensor_type && (dense_tensor_type.dims().size() == 0U);
+  }
+};
+
 class Convert0DTo1DPass : public pir::PatternRewritePass {
  public:
   Convert0DTo1DPass() : pir::PatternRewritePass("convert_0D_to_1D", 1) {}
@@ -65,6 +99,7 @@ class Convert0DTo1DPass : public pir::PatternRewritePass {
   pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
     pir::RewritePatternSet ps(context);
     ps.Add<FullOpPattern>(context);
+    ps.Add<CombineOpPattern>(context);
     return ps;
   }
 
diff --git a/test/ir/pir/cinn/sub_graphs/test_sub_graph_23.py b/test/ir/pir/cinn/sub_graphs/test_sub_graph_23.py
index 0d140fda014841..5f04f7b0f9bd2f 100644
--- a/test/ir/pir/cinn/sub_graphs/test_sub_graph_23.py
+++ b/test/ir/pir/cinn/sub_graphs/test_sub_graph_23.py
@@ -31,8 +31,7 @@ def forward(
         var_0,  # (shape: [11, 24, 56, 56], dtype: paddle.float32, stop_gradient: False)
         var_1,  # (shape: [11, 24, 56, 56], dtype: paddle.float32, stop_gradient: False)
     ):
-        var_2 = paddle.tensor.attribute.shape(var_0)
-        var_3 = var_2[0]
+        var_3 = var_0.shape[0]
         var_4 = paddle.tensor.random.rand(shape=[var_3, 1, 1, 1])
         var_5 = 0.975 + var_4
         var_6 = paddle.tensor.ops.floor(var_5)
@@ -65,16 +64,15 @@ def train(self, net, to_static, with_prim=False, with_cinn=False):
             outs = net(*self.inputs)
         return outs
 
-    # NOTE prim + cinn lead to error
     def test_ast_prim_cinn(self):
         st_out = self.train(self.net, to_static=True)
         cinn_out = self.train(
-            self.net, to_static=True, with_prim=True, with_cinn=False
+            self.net, to_static=True, with_prim=True, with_cinn=True
        )
         for st, cinn in zip(
             paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
         ):
-            np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-8)
+            np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-6)
 
 
 if __name__ == '__main__':
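
Reviewer note (not part of the patch): the sketch below illustrates the kind of subgraph this fix targets; the layer, its names, and the exact ops are hypothetical simplifications of test_sub_graph_23.py, not code from the repository. Under dynamic-to-static, x.shape[0] becomes a 0-D value, and the shape list passed to paddle.rand lowers to a builtin.combine op; after the Convert0DTo1D pass turns the 0-D operand into a 1-D one, CombineOpPattern rebuilds the combine result's VectorType from the updated operand types.

# Illustrative repro sketch only: simplified from test_sub_graph_23.py; the
# layer name and exact ops are assumptions, not taken from this patch.
import numpy as np
import paddle


class DropPathLikeNet(paddle.nn.Layer):
    def forward(self, x):
        n = x.shape[0]  # becomes a 0-D (scalar) value under to_static + PIR
        # shape=[n, 1, 1, 1] lowers to a builtin.combine over the shape elements
        mask = paddle.floor(paddle.rand(shape=[n, 1, 1, 1]) + 0.975)
        return x / 0.975 * mask


if __name__ == '__main__':
    paddle.seed(123)
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.build_cinn_pass = True  # route through CINN so the pass runs
    net = paddle.jit.to_static(
        DropPathLikeNet(), build_strategy=build_strategy, full_graph=True
    )
    out = net(paddle.rand([11, 24, 56, 56], dtype='float32'))
    np.testing.assert_equal(list(out.shape), [11, 24, 56, 56])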