2323#include " paddle/fluid/pir/dialect/operator/ir/op_type.h"
2424#include " paddle/fluid/pir/dialect/operator/ir/pd_op.h"
2525#include " paddle/pir/include/core/builtin_dialect.h"
26+ #include " paddle/pir/include/core/ir_mapping.h"
2627namespace cinn ::dialect::details {
2728
2829pir::Attribute ArrayAttributeToIntArrayAttribute (
@@ -36,37 +37,147 @@ pir::Attribute ArrayAttributeToIntArrayAttribute(
3637 return attr_data;
3738}
3839
40+ const auto & handler_reduce_sum_op =
41+ [](::pir::Operation* op,
42+ ::pir::IrMapping& ir_mapping, // NOLINT
43+ ::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
44+ VLOG (6 ) << " transform " << op->name () << " from cinn_op to pd_op" ;
45+ auto attrs = op->attributes ();
46+
47+ pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute (
48+ attrs.at (" dim" ).dyn_cast <::pir::ArrayAttribute>());
49+ attrs.insert ({" axis" , attr_axis});
50+ attrs.insert ({" dtype" , attrs[" dtype" ]});
51+ attrs.insert ({" keepdim" , attrs[" keep_dim" ]});
52+ attrs.erase (" dim" );
53+ attrs.erase (" keep_dim" );
54+
55+ auto pd_op = builder.Build <paddle::dialect::SumOp>(
56+ ir_mapping.Lookup (op->operand_source (0 )), attrs);
57+ for (uint32_t i = 0 ; i < op->num_results (); ++i) {
58+ ir_mapping.Add (op->result (i), pd_op->result (i));
59+ }
60+ return pd_op;
61+ };
62+
3963const auto & handler_reduce_max_op =
40- [&](::pir::Operation* op,
41- const ::pir::Builder& builder) -> ::pir::Operation* {
64+ [](::pir::Operation* op,
65+ ::pir::IrMapping& ir_mapping, // NOLINT
66+ ::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
4267 VLOG (6 ) << " transform " << op->name () << " from cinn_op to pd_op" ;
43- auto cinn_op = op->dyn_cast <cinn::dialect::ReduceMaxOp>();
44- auto attr = cinn_op.attributes ();
68+ auto attrs = op->attributes ();
4569
4670 // TODO(chenxi67): 1. CINN op Dialect Normalization;2.AST Op compute
4771 // Normalization
4872 pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute (
49- attr.at (" dim" ).dyn_cast <::pir::ArrayAttribute>());
50- attr.insert ({" axis" , attr_axis});
51- attr.insert ({" keepdim" , attr[" keep_dim" ]});
52- attr.erase (" dim" );
53- attr.erase (" keep_dim" );
54-
55- auto pd_op =
56- const_cast <::pir::Builder*>(&builder)->Build <paddle::dialect::MaxOp>(
57- cinn_op.operand_source (0 ), attr);
73+ attrs.at (" dim" ).dyn_cast <::pir::ArrayAttribute>());
74+ attrs.insert ({" axis" , attr_axis});
75+ attrs.insert ({" keepdim" , attrs[" keep_dim" ]});
76+ attrs.erase (" dim" );
77+ attrs.erase (" keep_dim" );
78+
79+ auto pd_op = builder.Build <paddle::dialect::MaxOp>(
80+ ir_mapping.Lookup (op->operand_source (0 )), attrs);
81+ for (uint32_t i = 0 ; i < op->num_results (); ++i) {
82+ ir_mapping.Add (op->result (i), pd_op->result (i));
83+ }
84+ return pd_op;
85+ };
86+
87+ const auto & handler_reduce_min_op =
88+ [](::pir::Operation* op,
89+ ::pir::IrMapping& ir_mapping, // NOLINT
90+ ::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
91+ VLOG (6 ) << " transform " << op->name () << " from cinn_op to pd_op" ;
92+ auto attrs = op->attributes ();
93+
94+ pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute (
95+ attrs.at (" dim" ).dyn_cast <::pir::ArrayAttribute>());
96+ attrs.insert ({" axis" , attr_axis});
97+ attrs.insert ({" keepdim" , attrs[" keep_dim" ]});
98+ attrs.erase (" dim" );
99+ attrs.erase (" keep_dim" );
100+
101+ auto pd_op = builder.Build <paddle::dialect::MinOp>(
102+ ir_mapping.Lookup (op->operand_source (0 )), attrs);
103+ for (uint32_t i = 0 ; i < op->num_results (); ++i) {
104+ ir_mapping.Add (op->result (i), pd_op->result (i));
105+ }
106+ return pd_op;
107+ };
108+
109+ const auto & handler_reduce_prod_op =
110+ [](::pir::Operation* op,
111+ ::pir::IrMapping& ir_mapping, // NOLINT
112+ ::pir::Builder& builder) -> ::pir::Operation* { // NOLINT
113+ VLOG (6 ) << " transform " << op->name () << " from cinn_op to pd_op" ;
114+ auto attrs = op->attributes ();
115+
116+ pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute (
117+ attrs.at (" dim" ).dyn_cast <::pir::ArrayAttribute>());
118+ attrs.insert ({" dims" , attr_axis});
119+ attrs.erase (" dim" );
120+
121+ auto pd_op = builder.Build <paddle::dialect::ProdOp>(
122+ ir_mapping.Lookup (op->operand_source (0 )), attrs);
123+ for (uint32_t i = 0 ; i < op->num_results (); ++i) {
124+ ir_mapping.Add (op->result (i), pd_op->result (i));
125+ }
58126 return pd_op;
59127};
60128
129+ ::pir::Operation* ConvertSliceOp (::pir::Operation* op,
130+ ::pir::IrMapping& ir_mapping, // NOLINT
131+ ::pir::Builder& builder) { // NOLINT
132+ VLOG (6 ) << " transform " << op->name () << " from cinn_op to pd_op" ;
133+ auto attrs = op->attributes ();
134+ pir::Attribute starts = ArrayAttributeToIntArrayAttribute (
135+ attrs.at (" starts" ).dyn_cast <::pir::ArrayAttribute>());
136+ pir::Attribute ends = ArrayAttributeToIntArrayAttribute (
137+ attrs.at (" ends" ).dyn_cast <::pir::ArrayAttribute>());
138+ attrs[" starts" ] = starts;
139+ attrs[" ends" ] = ends;
140+ auto pd_op = builder.Build <paddle::dialect::SliceOp>(
141+ ir_mapping.Lookup (op->operand_source (0 )), attrs);
142+ for (uint32_t i = 0 ; i < op->num_results (); ++i) {
143+ ir_mapping.Add (op->result (i), pd_op->result (i));
144+ }
145+ return pd_op;
146+ }
147+
148+ ::pir::Operation* ConvertConcatOp (::pir::Operation* op,
149+ ::pir::IrMapping& ir_mapping, // NOLINT
150+ ::pir::Builder& builder) { // NOLINT
151+ VLOG (6 ) << " transform " << op->name () << " from cinn_op to pd_op" ;
152+ auto attrs = op->attributes ();
153+ for (auto item : attrs) {
154+ VLOG (0 ) << item.first ;
155+ }
156+ std::vector<pir::Value> vec_inputs;
157+ for (uint32_t i = 0 ; i < op->num_operands (); ++i) {
158+ vec_inputs.push_back (ir_mapping.Lookup (op->operand_source (i)));
159+ }
160+ auto op_input = builder.Build <pir::CombineOp>(vec_inputs).result (0 );
161+
162+ int axis = attrs.at (" axis" ).dyn_cast <::pir::Int32Attribute>().data ();
163+
164+ auto pd_op = builder.Build <paddle::dialect::ConcatOp>(op_input, axis);
165+ for (uint32_t i = 0 ; i < op->num_results (); ++i) {
166+ ir_mapping.Add (op->result (i), pd_op->result (i));
167+ }
168+ return pd_op;
169+ }
170+
61171bool CanApplyOn (::pir::Operation* op) {
62172 return op->dialect ()->name () == " cinn_op" ;
63173}
64174
65175::pir::Operation* RewriteCinnOpToPdOp (::pir::Operation* op,
66- const ::pir::Builder& builder) {
176+ ::pir::IrMapping& ir_mapping, // NOLINT
177+ ::pir::Builder& builder) { // NOLINT
67178 VLOG (8 ) << " Rewrite CinnOp to PdOp for op: " << op->name ();
68179 auto & op_transformers = TransformContext::Instance ();
69- return op_transformers[op->name ()](op, builder);
180+ return op_transformers[op->name ()](op, ir_mapping, builder);
70181}
71182
72183void RewriteCinnOpToPdOp (const ::pir::Block& src_block,
@@ -91,20 +202,37 @@ void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
91202 }
92203 ::pir::Operation* new_op;
93204 if (CanApplyOn (&op)) {
94- new_op = RewriteCinnOpToPdOp (&op, builder);
205+ new_op = RewriteCinnOpToPdOp (&op, ir_mapping, builder);
95206 new_op->MoveTo (target_block, target_block->end ());
96207 } else {
97208 new_op = op.Clone (ir_mapping, clone_options);
98209 new_op->MoveTo (target_block, target_block->end ());
99210 }
100- for (uint32_t i = 0 ; i < op.num_results (); ++i) {
101- ir_mapping.Add (op.result (i), new_op->result (i));
102- }
103211 }
104212}
105213
106214} // namespace cinn::dialect::details
107215
216+ REGISTER_TRANSFORM_RULES (reduce_sum_op,
217+ cinn::dialect::ReduceSumOp::name (),
218+ cinn::dialect::details::handler_reduce_sum_op);
219+
108220REGISTER_TRANSFORM_RULES (reduce_max_op,
109221 cinn::dialect::ReduceMaxOp::name (),
110222 cinn::dialect::details::handler_reduce_max_op);
223+
224+ REGISTER_TRANSFORM_RULES (reduce_min_op,
225+ cinn::dialect::ReduceMinOp::name (),
226+ cinn::dialect::details::handler_reduce_min_op);
227+
228+ REGISTER_TRANSFORM_RULES (reduce_prod_op,
229+ cinn::dialect::ReduceProdOp::name (),
230+ cinn::dialect::details::handler_reduce_prod_op);
231+
232+ REGISTER_TRANSFORM_RULES (slice_op,
233+ cinn::dialect::SliceOp::name (),
234+ cinn::dialect::details::ConvertSliceOp);
235+
236+ REGISTER_TRANSFORM_RULES (concat_op,
237+ cinn::dialect::ConcatOp::name (),
238+ cinn::dialect::details::ConvertConcatOp);
0 commit comments