Skip to content

Commit f4a62d1

Browse files
committed
PIR: support attention_fuse_pass to fuse an attention subgraph into a multihead_matmul op
1 parent e6e8b60 commit f4a62d1

6 files changed

Lines changed: 433 additions & 70 deletions

File tree

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,7 @@
107107
#include "paddle/fluid/ir_adaptor/translator/translate.h"
108108
#include "paddle/fluid/pir/transforms/constant_folding_pass.h"
109109
#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
110+
#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h"
110111
#include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h"
111112
#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h"
112113
#include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h"
@@ -804,6 +805,7 @@ bool AnalysisPredictor::PrepareExecutor() {
804805
gpu_pm.AddPass(::pir::CreateConv2dBnFusePass());
805806
gpu_pm.AddPass(::pir::CreateConv2dAddActFusePass());
806807
gpu_pm.AddPass(::pir::CreateConv2dAddFusePass());
808+
gpu_pm.AddPass(::pir::CreateAttentionFusePass());
807809
gpu_pm.AddPass(::pir::CreateFcFusePass());
808810
gpu_pm.AddPass(::pir::CreateFcElementwiseLayerNormFusePass());
809811
gpu_pm.AddPass(::pir::CreateMatmulScaleFusePass());

paddle/fluid/pir/transforms/constant_folding_pass.cc

Lines changed: 79 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434

3535
#include "paddle/pir/core/builtin_attribute.h"
3636
#include "paddle/pir/core/builtin_op.h"
37+
#include "paddle/pir/core/builtin_type.h"
3738
#include "paddle/pir/core/ir_context.h"
3839
#include "paddle/pir/core/op_result.h"
3940
#include "paddle/pir/core/op_trait.h"
@@ -83,17 +84,36 @@ class ConstantFoldingPattern : public pir::RewritePattern {
8384
if (!op->operand_source(i) || !op->operand_source(i).type()) {
8485
continue;
8586
}
86-
// 2. inputs must come from parameter op or constant op
87+
// 2. inputs must come from ParameterOp/ConstantTensorOp/CombineOp
8788
auto* prev_op = pir::GetDefiningOpForInput(op, i);
8889
if (!prev_op || !(prev_op->isa<pir::ParameterOp>() ||
89-
prev_op->isa<pir::ConstantTensorOp>())) {
90+
prev_op->isa<pir::ConstantTensorOp>() ||
91+
prev_op->isa<pir::CombineOp>())) {
9092
return false;
9193
}
92-
// 3. inputs must be a dense tensor type
93-
if (!op->operand_source(i)
94-
.type()
95-
.isa<paddle::dialect::DenseTensorType>()) {
96-
return false;
94+
if (prev_op->isa<pir::CombineOp>()) {
95+
if (prev_op->result(0).use_count() > 1) {
96+
return false;
97+
}
98+
for (uint32_t i = 0; i < prev_op->num_operands(); i++) {
99+
if (!prev_op->operand_source(i) ||
100+
!prev_op->operand_source(i).type()) {
101+
continue;
102+
}
103+
if (!prev_op->operand_source(i)
104+
.type()
105+
.isa<paddle::dialect::DenseTensorType>()) {
106+
return false;
107+
}
108+
}
109+
110+
} else {
111+
// 3. inputs must be a dense tensor type
112+
if (!op->operand_source(i)
113+
.type()
114+
.isa<paddle::dialect::DenseTensorType>()) {
115+
return false;
116+
}
97117
}
98118
}
99119

@@ -233,7 +253,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
233253
BuildProgramFromOperation(op, &new_program, rewriter);
234254

235255
// execute program
236-
for (auto output_var_name : output_var_names) {
256+
for (const auto& output_var_name : output_var_names) {
237257
exe_config_->skip_gc_vars.insert(output_var_name);
238258
}
239259
auto kernel_program =
@@ -256,42 +276,70 @@ class ConstantFoldingPattern : public pir::RewritePattern {
256276
std::vector<pir::Value> op_inputs;
257277
for (uint32_t i = 0; i < op->num_operands(); i++) {
258278
if (op->operand_source(i)) {
259-
const auto& param_name =
260-
pir::GetParameterNameFromValue(op->operand_source(i));
261-
auto* param_var = scope_->FindVar(param_name);
262-
PADDLE_ENFORCE_NOT_NULL(
263-
param_var,
264-
phi::errors::InvalidArgument("Parameter var [%s] not in scope.",
265-
param_name));
266-
267-
auto parameter_op = builder.Build<pir::ParameterOp>(
268-
param_name, op->operand_source(i).type());
269-
if (op->operand_source(i).use_count() <= 1) {
270-
deleted_vars_->push_back(param_name);
279+
auto* prev_op = pir::GetDefiningOpForInput(op, i);
280+
if (prev_op->isa<pir::CombineOp>()) {
281+
// prepare combine op inputs
282+
std::vector<pir::Value> combine_op_inputs;
283+
for (uint32_t i = 0; i < prev_op->num_operands(); i++) {
284+
const auto& param_name =
285+
pir::GetParameterNameFromValue(prev_op->operand_source(i));
286+
auto* param_var = scope_->FindVar(param_name);
287+
PADDLE_ENFORCE_NOT_NULL(
288+
param_var,
289+
phi::errors::InvalidArgument("Parameter var [%s] not in scope.",
290+
param_name));
291+
292+
auto parameter_op = builder.Build<pir::ParameterOp>(
293+
param_name, prev_op->operand_source(i).type());
294+
if (prev_op->operand_source(i).use_count() <= 1) {
295+
deleted_vars_->push_back(param_name);
296+
} else {
297+
parameter_op->set_attribute(
298+
kAttrIsPersisable,
299+
rewriter.array_attr({rewriter.bool_attr(true)}));
300+
}
301+
combine_op_inputs.push_back(parameter_op->result(0));
302+
}
303+
auto combine_op = builder.Build<pir::CombineOp>(combine_op_inputs);
304+
op_inputs.push_back(combine_op->result(0));
271305
} else {
272-
parameter_op->set_attribute(
273-
kAttrIsPersisable,
274-
rewriter.array_attr({rewriter.bool_attr(true)}));
306+
const auto& param_name =
307+
pir::GetParameterNameFromValue(op->operand_source(i));
308+
auto* param_var = scope_->FindVar(param_name);
309+
PADDLE_ENFORCE_NOT_NULL(
310+
param_var,
311+
phi::errors::InvalidArgument("Parameter var [%s] not in scope.",
312+
param_name));
313+
314+
auto parameter_op = builder.Build<pir::ParameterOp>(
315+
param_name, op->operand_source(i).type());
316+
if (op->operand_source(i).use_count() <= 1) {
317+
deleted_vars_->push_back(param_name);
318+
} else {
319+
parameter_op->set_attribute(
320+
kAttrIsPersisable,
321+
rewriter.array_attr({rewriter.bool_attr(true)}));
322+
}
323+
op_inputs.push_back(parameter_op->result(0));
275324
}
276-
op_inputs.push_back(parameter_op->result(0));
277325
} else {
278326
op_inputs.push_back(
279327
op->operand_source(i).dyn_cast<pir::OpResult>() /*nullptr*/);
280328
}
281329
}
282330

283331
// prepare op outputs
284-
std::vector<pir::Type> output_types;
332+
std::vector<pir::Type> op_output_types;
285333
for (uint32_t i = 0; i < op->num_results(); i++) {
286-
output_types.push_back(op->result(i).type());
334+
op_output_types.push_back(op->result(i).type());
287335
}
288336

289-
auto* temp_op =
290-
builder.Build(op_inputs, op->attributes(), output_types, op->info());
337+
auto* op_copy =
338+
builder.Build(op_inputs, op->attributes(), op_output_types, op->info());
291339

292340
std::vector<std::string> output_var_names;
293-
for (uint32_t i = 0; i < op->num_results(); i++) {
294-
if (!temp_op->result(i) || !temp_op->result(i).type()) {
341+
for (uint32_t i = 0; i < op_copy->num_results(); i++) {
342+
if (!op_copy->result(i) || !op_copy->result(i).type()) {
295343
continue;
296344
}
297345
std::stringstream ss;
@@ -301,7 +349,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
301349
std::string output_var_name =
302350
"constant_folding@_" + ss.str() + std::to_string((*suffix_)++);
303351

304-
builder.Build<pir::ShadowOutputOp>(temp_op->result(i), output_var_name);
352+
builder.Build<pir::ShadowOutputOp>(op_copy->result(i), output_var_name);
305353
output_var_names.push_back(output_var_name);
306354
}
307355

0 commit comments

Comments (0)