From 823299001157bb751ba5a1f1ee69d1288928d3f5 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Sun, 18 Feb 2024 11:38:19 +0000 Subject: [PATCH] [PIR+CINN]Fix cinn_op.pool2d Attribute and open subgraph0 UT --- .../cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc | 2 -- test/ir/pir/cinn/sub_graphs/test_sub_graph_0.py | 5 ++++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 83c334cf285e77..56092ebfe50c65 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -273,10 +273,8 @@ class Pool2dOpPattern pir::ArrayAttribute::get(pir::IrContext::Instance(), kernel_size); attrs["stride_size"] = attrs.at("strides"); attrs["padding_size"] = attrs.at("paddings"); - attrs["pool_type"] = attrs.at("pooling_type"); attrs.erase("strides"); attrs.erase("paddings"); - attrs.erase("pooling_type"); auto cinn_reshape = rewriter.Build(op->operand_source(0), attrs); diff --git a/test/ir/pir/cinn/sub_graphs/test_sub_graph_0.py b/test/ir/pir/cinn/sub_graphs/test_sub_graph_0.py index 73d552d485ab59..b8652111d9a10a 100644 --- a/test/ir/pir/cinn/sub_graphs/test_sub_graph_0.py +++ b/test/ir/pir/cinn/sub_graphs/test_sub_graph_0.py @@ -96,8 +96,11 @@ def train(self, net, to_static, with_prim=False, with_cinn=False): # NOTE prim + cinn lead to error def test_ast_prim_cinn(self): st_out = self.train(self.net, to_static=True) + # NOTE(Aurelius84): cinn_op.pool2d only support pool_type='avg' under adaptive=True + paddle.set_flags({"FLAGS_deny_cinn_ops": "pool2d"}) + # TODO(Aurelius84): Fix LoopAligment eror under with_prim=True cinn_out = self.train( - self.net, to_static=True, with_prim=True, with_cinn=False + self.net, to_static=True, with_prim=False, with_cinn=True ) for st, cinn in zip( paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)