3434
3535#include " paddle/pir/core/builtin_attribute.h"
3636#include " paddle/pir/core/builtin_op.h"
37+ #include " paddle/pir/core/builtin_type.h"
3738#include " paddle/pir/core/ir_context.h"
3839#include " paddle/pir/core/op_result.h"
3940#include " paddle/pir/core/op_trait.h"
@@ -83,17 +84,36 @@ class ConstantFoldingPattern : public pir::RewritePattern {
8384 if (!op->operand_source (i) || !op->operand_source (i).type ()) {
8485 continue ;
8586 }
86- // 2. inputs must come from parameter op or constant op
87+ // 2. inputs must come from ParameterOp/ConstantTensorOp/CombineOp
8788 auto * prev_op = pir::GetDefiningOpForInput (op, i);
8889 if (!prev_op || !(prev_op->isa <pir::ParameterOp>() ||
89- prev_op->isa <pir::ConstantTensorOp>())) {
90+ prev_op->isa <pir::ConstantTensorOp>() ||
91+ prev_op->isa <pir::CombineOp>())) {
9092 return false ;
9193 }
92- // 3. inputs must be a dense tensor type
93- if (!op->operand_source (i)
94- .type ()
95- .isa <paddle::dialect::DenseTensorType>()) {
96- return false ;
94+ if (prev_op->isa <pir::CombineOp>()) {
95+ if (prev_op->result (0 ).use_count () > 1 ) {
96+ return false ;
97+ }
98+ for (uint32_t i = 0 ; i < prev_op->num_operands (); i++) {
99+ if (!prev_op->operand_source (i) ||
100+ !prev_op->operand_source (i).type ()) {
101+ continue ;
102+ }
103+ if (!prev_op->operand_source (i)
104+ .type ()
105+ .isa <paddle::dialect::DenseTensorType>()) {
106+ return false ;
107+ }
108+ }
109+
110+ } else {
111+ // 3. inputs must be a dense tensor type
112+ if (!op->operand_source (i)
113+ .type ()
114+ .isa <paddle::dialect::DenseTensorType>()) {
115+ return false ;
116+ }
97117 }
98118 }
99119
@@ -233,7 +253,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
233253 BuildProgramFromOperation (op, &new_program, rewriter);
234254
235255 // execute program
236- for (auto output_var_name : output_var_names) {
256+ for (const auto & output_var_name : output_var_names) {
237257 exe_config_->skip_gc_vars .insert (output_var_name);
238258 }
239259 auto kernel_program =
@@ -256,42 +276,70 @@ class ConstantFoldingPattern : public pir::RewritePattern {
256276 std::vector<pir::Value> op_inputs;
257277 for (uint32_t i = 0 ; i < op->num_operands (); i++) {
258278 if (op->operand_source (i)) {
259- const auto & param_name =
260- pir::GetParameterNameFromValue (op->operand_source (i));
261- auto * param_var = scope_->FindVar (param_name);
262- PADDLE_ENFORCE_NOT_NULL (
263- param_var,
264- phi::errors::InvalidArgument (" Parameter var [%s] not in scope." ,
265- param_name));
266-
267- auto parameter_op = builder.Build <pir::ParameterOp>(
268- param_name, op->operand_source (i).type ());
269- if (op->operand_source (i).use_count () <= 1 ) {
270- deleted_vars_->push_back (param_name);
279+ auto * prev_op = pir::GetDefiningOpForInput (op, i);
280+ if (prev_op->isa <pir::CombineOp>()) {
281+ // prepare combine op inputs
282+ std::vector<pir::Value> combine_op_inputs;
283+ for (uint32_t i = 0 ; i < prev_op->num_operands (); i++) {
284+ const auto & param_name =
285+ pir::GetParameterNameFromValue (prev_op->operand_source (i));
286+ auto * param_var = scope_->FindVar (param_name);
287+ PADDLE_ENFORCE_NOT_NULL (
288+ param_var,
289+ phi::errors::InvalidArgument (" Parameter var [%s] not in scope." ,
290+ param_name));
291+
292+ auto parameter_op = builder.Build <pir::ParameterOp>(
293+ param_name, prev_op->operand_source (i).type ());
294+ if (prev_op->operand_source (i).use_count () <= 1 ) {
295+ deleted_vars_->push_back (param_name);
296+ } else {
297+ parameter_op->set_attribute (
298+ kAttrIsPersisable ,
299+ rewriter.array_attr ({rewriter.bool_attr (true )}));
300+ }
301+ combine_op_inputs.push_back (parameter_op->result (0 ));
302+ }
303+ auto combine_op = builder.Build <pir::CombineOp>(combine_op_inputs);
304+ op_inputs.push_back (combine_op->result (0 ));
271305 } else {
272- parameter_op->set_attribute (
273- kAttrIsPersisable ,
274- rewriter.array_attr ({rewriter.bool_attr (true )}));
306+ const auto & param_name =
307+ pir::GetParameterNameFromValue (op->operand_source (i));
308+ auto * param_var = scope_->FindVar (param_name);
309+ PADDLE_ENFORCE_NOT_NULL (
310+ param_var,
311+ phi::errors::InvalidArgument (" Parameter var [%s] not in scope." ,
312+ param_name));
313+
314+ auto parameter_op = builder.Build <pir::ParameterOp>(
315+ param_name, op->operand_source (i).type ());
316+ if (op->operand_source (i).use_count () <= 1 ) {
317+ deleted_vars_->push_back (param_name);
318+ } else {
319+ parameter_op->set_attribute (
320+ kAttrIsPersisable ,
321+ rewriter.array_attr ({rewriter.bool_attr (true )}));
322+ }
323+ op_inputs.push_back (parameter_op->result (0 ));
275324 }
276- op_inputs.push_back (parameter_op->result (0 ));
277325 } else {
278326 op_inputs.push_back (
279327 op->operand_source (i).dyn_cast <pir::OpResult>() /* nullptr*/ );
280328 }
281329 }
282330
283331 // prepare op outputs
284- std::vector<pir::Type> output_types ;
332+ std::vector<pir::Type> op_output_types ;
285333 for (uint32_t i = 0 ; i < op->num_results (); i++) {
286- output_types .push_back (op->result (i).type ());
334+ op_output_types .push_back (op->result (i).type ());
287335 }
288336
289- auto * temp_op =
290- builder.Build (op_inputs, op->attributes (), output_types , op->info ());
337+ auto * op_copy =
338+ builder.Build (op_inputs, op->attributes (), op_output_types , op->info ());
291339
292340 std::vector<std::string> output_var_names;
293- for (uint32_t i = 0 ; i < op ->num_results (); i++) {
294- if (!temp_op ->result (i) || !temp_op ->result (i).type ()) {
341+ for (uint32_t i = 0 ; i < op_copy ->num_results (); i++) {
342+ if (!op_copy ->result (i) || !op_copy ->result (i).type ()) {
295343 continue ;
296344 }
297345 std::stringstream ss;
@@ -301,7 +349,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
301349 std::string output_var_name =
302350 " constant_folding@_" + ss.str () + std::to_string ((*suffix_)++);
303351
304- builder.Build <pir::ShadowOutputOp>(temp_op ->result (i), output_var_name);
352+ builder.Build <pir::ShadowOutputOp>(op_copy ->result (i), output_var_name);
305353 output_var_names.push_back (output_var_name);
306354 }
307355
0 commit comments