Skip to content

Commit 370375f

Browse files
chen2016013 and zyfncg
authored and committed
[CINN] Add rules for cinn_to_pd (PaddlePaddle#64152)
* add shape_ops_fallback_to_phi_pass
* fix test
* add ops for cinn_to_pd
* Update accuracy_check_pass.cc
* fix test_sub_graph_90 bug
* update

---------

Co-authored-by: zyfncg <[email protected]>
1 parent 6e82036 commit 370375f

File tree

11 files changed

+306
-49
lines changed

11 files changed

+306
-49
lines changed

paddle/cinn/hlir/dialect/operator/ir/manual_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ void SplitOp::Build(pir::Builder& builder, // NOLINT
268268
pir::Value input,
269269
const std::vector<int>& sections,
270270
int axis) {
271-
VLOG(4) << "Start build ConcatOp";
271+
VLOG(4) << "Start build SplitOp";
272272

273273
argument.inputs.push_back(input);
274274

paddle/cinn/hlir/dialect/operator/ir/ops.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,25 @@
5858
interfaces : paddle::dialect::InferSymbolicShapeInterface
5959

6060
- op : reduce_prod
61-
args : (Tensor x, int64_t[] dim, bool keep_dim)
61+
args : (Tensor x, int64_t[] dim, bool keep_dim, bool reduce_all)
6262
output : Tensor(out)
6363
infer_meta :
6464
func : ReduceInferMeta
65+
param : [x, dim, keep_dim]
6566
kernel :
6667
func : frobenius_norm
68+
param : [x, dim, keep_dim]
6769
interfaces : paddle::dialect::InferSymbolicShapeInterface
6870

6971
- op : reduce_sum
70-
args : (Tensor x, int64_t[] dim, bool keep_dim)
72+
args : (Tensor x, int64_t[] dim, bool keep_dim, DataType dtype=DataType::UNDEFINED)
7173
output : Tensor(out)
7274
infer_meta :
7375
func : ReduceInferMeta
76+
param : [x, dim, keep_dim]
7477
kernel :
7578
func : frobenius_norm
79+
param : [x, dim, keep_dim]
7680
interfaces : paddle::dialect::InferSymbolicShapeInterface
7781

7882
- op : reshape

paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,7 @@ class AddAccuracyCheckPattern
7575
}
7676
}
7777
pir::Operation* pd_op =
78-
cinn::dialect::details::RewriteCinnOpToPdOp(op, builder);
79-
for (uint32_t i = 0; i < op->num_results(); ++i) {
80-
ir_mapping.Add(op->result(i), pd_op->result(i));
81-
}
78+
cinn::dialect::details::RewriteCinnOpToPdOp(op, ir_mapping, builder);
8279
};
8380

8481
const auto& ClonePdOp = [&](::pir::Operation* op) -> void {
@@ -106,7 +103,7 @@ class AddAccuracyCheckPattern
106103

107104
class AccuarcyCheckPass : public pir::Pass {
108105
public:
109-
AccuarcyCheckPass() : pir::Pass("accuarcy_check_pass", /*opt_level=*/4) {}
106+
AccuarcyCheckPass() : pir::Pass("accuracy_check_pass", /*opt_level=*/4) {}
110107

111108
bool Initialize(pir::IrContext* context) override {
112109
pir::RewritePatternSet ps(context);

paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h"
4646
#include "paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h"
4747
#include "paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.h"
48+
#include "paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.h"
4849
#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
4950
#include "paddle/fluid/pir/transforms/build_cinn_pass.h"
5051
#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
@@ -156,6 +157,10 @@ void ApplyDivideGroupOpToFusionOpPass(
156157
pass_manager->AddPass(
157158
cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass());
158159
}
160+
161+
pass_manager->AddPass(cinn::dialect::ir::CreateSingleOpFallbackToPhiPass());
162+
pass_manager->AddPass(cinn::dialect::ir::CreateShapeOpsFallbackToPhiPass());
163+
159164
pass_manager->Run(program);
160165
}
161166

@@ -176,7 +181,6 @@ void ApplyCinnLowerPass(
176181
pass_manager->AddPass(std::move(pass.value()));
177182
}
178183

179-
pass_manager->AddPass(cinn::dialect::ir::CreateSingleOpFallbackToPhiPass());
180184
if (FLAGS_enable_cinn_accuracy_check) {
181185
VLOG(0) << "Enable CINN Accuracy Check Pass";
182186
pass_manager->AddPass(cinn::dialect::ir::CreateAccuarcyCheckPass());

paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc

Lines changed: 147 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
2424
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
2525
#include "paddle/pir/include/core/builtin_dialect.h"
26+
#include "paddle/pir/include/core/ir_mapping.h"
2627
namespace cinn::dialect::details {
2728

2829
pir::Attribute ArrayAttributeToIntArrayAttribute(
@@ -36,37 +37,147 @@ pir::Attribute ArrayAttributeToIntArrayAttribute(
3637
return attr_data;
3738
}
3839

40+
const auto& handler_reduce_sum_op =
41+
[](::pir::Operation* op,
42+
::pir::IrMapping& ir_mapping, // NOLINT
43+
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
44+
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
45+
auto attrs = op->attributes();
46+
47+
pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
48+
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
49+
attrs.insert({"axis", attr_axis});
50+
attrs.insert({"dtype", attrs["dtype"]});
51+
attrs.insert({"keepdim", attrs["keep_dim"]});
52+
attrs.erase("dim");
53+
attrs.erase("keep_dim");
54+
55+
auto pd_op = builder.Build<paddle::dialect::SumOp>(
56+
ir_mapping.Lookup(op->operand_source(0)), attrs);
57+
for (uint32_t i = 0; i < op->num_results(); ++i) {
58+
ir_mapping.Add(op->result(i), pd_op->result(i));
59+
}
60+
return pd_op;
61+
};
62+
3963
const auto& handler_reduce_max_op =
40-
[&](::pir::Operation* op,
41-
const ::pir::Builder& builder) -> ::pir::Operation* {
64+
[](::pir::Operation* op,
65+
::pir::IrMapping& ir_mapping, // NOLINT
66+
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
4267
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
43-
auto cinn_op = op->dyn_cast<cinn::dialect::ReduceMaxOp>();
44-
auto attr = cinn_op.attributes();
68+
auto attrs = op->attributes();
4569

4670
// TODO(chenxi67): 1. CINN op Dialect Normalization;2.AST Op compute
4771
// Normalization
4872
pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
49-
attr.at("dim").dyn_cast<::pir::ArrayAttribute>());
50-
attr.insert({"axis", attr_axis});
51-
attr.insert({"keepdim", attr["keep_dim"]});
52-
attr.erase("dim");
53-
attr.erase("keep_dim");
54-
55-
auto pd_op =
56-
const_cast<::pir::Builder*>(&builder)->Build<paddle::dialect::MaxOp>(
57-
cinn_op.operand_source(0), attr);
73+
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
74+
attrs.insert({"axis", attr_axis});
75+
attrs.insert({"keepdim", attrs["keep_dim"]});
76+
attrs.erase("dim");
77+
attrs.erase("keep_dim");
78+
79+
auto pd_op = builder.Build<paddle::dialect::MaxOp>(
80+
ir_mapping.Lookup(op->operand_source(0)), attrs);
81+
for (uint32_t i = 0; i < op->num_results(); ++i) {
82+
ir_mapping.Add(op->result(i), pd_op->result(i));
83+
}
84+
return pd_op;
85+
};
86+
87+
const auto& handler_reduce_min_op =
88+
[](::pir::Operation* op,
89+
::pir::IrMapping& ir_mapping, // NOLINT
90+
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
91+
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
92+
auto attrs = op->attributes();
93+
94+
pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
95+
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
96+
attrs.insert({"axis", attr_axis});
97+
attrs.insert({"keepdim", attrs["keep_dim"]});
98+
attrs.erase("dim");
99+
attrs.erase("keep_dim");
100+
101+
auto pd_op = builder.Build<paddle::dialect::MinOp>(
102+
ir_mapping.Lookup(op->operand_source(0)), attrs);
103+
for (uint32_t i = 0; i < op->num_results(); ++i) {
104+
ir_mapping.Add(op->result(i), pd_op->result(i));
105+
}
106+
return pd_op;
107+
};
108+
109+
const auto& handler_reduce_prod_op =
110+
[](::pir::Operation* op,
111+
::pir::IrMapping& ir_mapping, // NOLINT
112+
::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
113+
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
114+
auto attrs = op->attributes();
115+
116+
pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
117+
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
118+
attrs.insert({"dims", attr_axis});
119+
attrs.erase("dim");
120+
121+
auto pd_op = builder.Build<paddle::dialect::ProdOp>(
122+
ir_mapping.Lookup(op->operand_source(0)), attrs);
123+
for (uint32_t i = 0; i < op->num_results(); ++i) {
124+
ir_mapping.Add(op->result(i), pd_op->result(i));
125+
}
58126
return pd_op;
59127
};
60128

129+
::pir::Operation* ConvertSliceOp(::pir::Operation* op,
130+
::pir::IrMapping& ir_mapping, // NOLINT
131+
::pir::Builder& builder) { // NOLINT
132+
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
133+
auto attrs = op->attributes();
134+
pir::Attribute starts = ArrayAttributeToIntArrayAttribute(
135+
attrs.at("starts").dyn_cast<::pir::ArrayAttribute>());
136+
pir::Attribute ends = ArrayAttributeToIntArrayAttribute(
137+
attrs.at("ends").dyn_cast<::pir::ArrayAttribute>());
138+
attrs["starts"] = starts;
139+
attrs["ends"] = ends;
140+
auto pd_op = builder.Build<paddle::dialect::SliceOp>(
141+
ir_mapping.Lookup(op->operand_source(0)), attrs);
142+
for (uint32_t i = 0; i < op->num_results(); ++i) {
143+
ir_mapping.Add(op->result(i), pd_op->result(i));
144+
}
145+
return pd_op;
146+
}
147+
148+
::pir::Operation* ConvertConcatOp(::pir::Operation* op,
149+
::pir::IrMapping& ir_mapping, // NOLINT
150+
::pir::Builder& builder) { // NOLINT
151+
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
152+
auto attrs = op->attributes();
153+
for (auto item : attrs) {
154+
VLOG(0) << item.first;
155+
}
156+
std::vector<pir::Value> vec_inputs;
157+
for (uint32_t i = 0; i < op->num_operands(); ++i) {
158+
vec_inputs.push_back(ir_mapping.Lookup(op->operand_source(i)));
159+
}
160+
auto op_input = builder.Build<pir::CombineOp>(vec_inputs).result(0);
161+
162+
int axis = attrs.at("axis").dyn_cast<::pir::Int32Attribute>().data();
163+
164+
auto pd_op = builder.Build<paddle::dialect::ConcatOp>(op_input, axis);
165+
for (uint32_t i = 0; i < op->num_results(); ++i) {
166+
ir_mapping.Add(op->result(i), pd_op->result(i));
167+
}
168+
return pd_op;
169+
}
170+
61171
bool CanApplyOn(::pir::Operation* op) {
62172
return op->dialect()->name() == "cinn_op";
63173
}
64174

65175
::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation* op,
66-
const ::pir::Builder& builder) {
176+
::pir::IrMapping& ir_mapping, // NOLINT
177+
::pir::Builder& builder) { // NOLINT
67178
VLOG(8) << "Rewrite CinnOp to PdOp for op: " << op->name();
68179
auto& op_transformers = TransformContext::Instance();
69-
return op_transformers[op->name()](op, builder);
180+
return op_transformers[op->name()](op, ir_mapping, builder);
70181
}
71182

72183
void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
@@ -91,20 +202,37 @@ void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
91202
}
92203
::pir::Operation* new_op;
93204
if (CanApplyOn(&op)) {
94-
new_op = RewriteCinnOpToPdOp(&op, builder);
205+
new_op = RewriteCinnOpToPdOp(&op, ir_mapping, builder);
95206
new_op->MoveTo(target_block, target_block->end());
96207
} else {
97208
new_op = op.Clone(ir_mapping, clone_options);
98209
new_op->MoveTo(target_block, target_block->end());
99210
}
100-
for (uint32_t i = 0; i < op.num_results(); ++i) {
101-
ir_mapping.Add(op.result(i), new_op->result(i));
102-
}
103211
}
104212
}
105213

106214
} // namespace cinn::dialect::details
107215

216+
REGISTER_TRANSFORM_RULES(reduce_sum_op,
217+
cinn::dialect::ReduceSumOp::name(),
218+
cinn::dialect::details::handler_reduce_sum_op);
219+
108220
REGISTER_TRANSFORM_RULES(reduce_max_op,
109221
cinn::dialect::ReduceMaxOp::name(),
110222
cinn::dialect::details::handler_reduce_max_op);
223+
224+
REGISTER_TRANSFORM_RULES(reduce_min_op,
225+
cinn::dialect::ReduceMinOp::name(),
226+
cinn::dialect::details::handler_reduce_min_op);
227+
228+
REGISTER_TRANSFORM_RULES(reduce_prod_op,
229+
cinn::dialect::ReduceProdOp::name(),
230+
cinn::dialect::details::handler_reduce_prod_op);
231+
232+
REGISTER_TRANSFORM_RULES(slice_op,
233+
cinn::dialect::SliceOp::name(),
234+
cinn::dialect::details::ConvertSliceOp);
235+
236+
REGISTER_TRANSFORM_RULES(concat_op,
237+
cinn::dialect::ConcatOp::name(),
238+
cinn::dialect::details::ConvertConcatOp);

paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@
2020
#include "paddle/common/enforce.h"
2121

2222
namespace pir {
23+
class IrMapping;
2324
class Block;
2425
class Operation;
2526
class Builder;
27+
class IrMapping;
2628
} // namespace pir
2729

2830
namespace cinn::dialect::details {
2931

30-
using TRule =
31-
std::function<::pir::Operation*(::pir::Operation*, const ::pir::Builder&)>;
32+
using TRule = std::function<::pir::Operation*(
33+
::pir::Operation*, ::pir::IrMapping&, ::pir::Builder&)>;
3234

3335
class TransformContext {
3436
private:
@@ -86,6 +88,8 @@ class TransformRegistrar {
8688

8789
void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
8890
::pir::Block* target_block);
89-
::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation*, const ::pir::Builder&);
91+
::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation*,
92+
::pir::IrMapping&, // NOLINT
93+
::pir::Builder&); // NOLINT
9094

9195
} // namespace cinn::dialect::details

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class SumOpPattern : public paddle::drr::DrrPatternBase {
5656
const auto &cinn_reduce_sum =
5757
res.Op(cinn::dialect::ReduceSumOp::name(),
5858
{{"dim", pattern.Attr("axis_info")},
59+
{"dtype", pattern.Attr("dtype")},
5960
{"keep_dim", pattern.Attr("keep_dim")}});
6061
res.Tensor("ret") = cinn_reduce_sum(res.Tensor("arg0"));
6162
}
@@ -128,16 +129,19 @@ class ProdOpPattern : public paddle::drr::DrrPatternBase {
128129
{"dtype", pattern.Attr("dtype_2")},
129130
{"place", pattern.Attr("place_2")}});
130131

131-
const auto &pd_max = pattern.Op(paddle::dialect::ProdOp::name(),
132-
{{"keep_dim", pattern.Attr("keep_dim")}});
132+
const auto &pd_max =
133+
pattern.Op(paddle::dialect::ProdOp::name(),
134+
{{"keep_dim", pattern.Attr("keep_dim")},
135+
{"reduce_all", pattern.Attr("reduce_all")}});
133136
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());
134137

135138
// Result patterns
136139
paddle::drr::ResultPattern res = pattern.ResultPattern();
137140
const auto &cinn_reduce_max =
138141
res.Op(cinn::dialect::ReduceProdOp::name(),
139142
{{"dim", pattern.Attr("axis_info")},
140-
{"keep_dim", pattern.Attr("keep_dim")}});
143+
{"keep_dim", pattern.Attr("keep_dim")},
144+
{"reduce_all", pattern.Attr("reduce_all")}});
141145
res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0"));
142146
}
143147
};

0 commit comments

Comments
 (0)