Skip to content

Commit f445bd8

Browse files
authored
[DRR]Fix SegmentFault for BlockArgument while applying pass in Llama2 infer (#62283)
* [DRR]Fix SegmentFault for BlockArgument while applying pass in Llama2 infer * fix typo
1 parent cbe8810 commit f445bd8

File tree

1 file changed

+85
-52
lines changed

1 file changed

+85
-52
lines changed

paddle/fluid/pir/drr/src/rewrite_pattern.cc

Lines changed: 85 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -258,95 +258,128 @@ bool DrrRewritePattern::MatchFromOutputToInput(
258258
std::unordered_set<pir::Operation*> ir_visited;
259259
std::queue<const OpCall*> drr_q;
260260
std::queue<pir::Operation*> ir_q;
261-
bool matched = true;
262-
size_t step = 0;
263-
for (auto it = output_op_map.begin(); it != output_op_map.end(); ++it) {
264-
VLOG(6) << "match (" << it->first->name() << " @" << it->first << " : @"
265-
<< it->second << ") in source_pattern_graph ";
266-
drr_q.push(it->first);
267-
drr_visited.insert(it->first);
268-
ir_q.push(it->second);
269-
ir_visited.insert(it->second);
270-
}
271-
while (!drr_q.empty()) {
272-
if (!matched) break;
273-
auto* drr_node = drr_q.front();
274-
auto* ir_node = ir_q.front();
275-
drr_q.pop();
276-
ir_q.pop();
261+
// Initialize DRR matched queue.
262+
const auto& InitDrrQueue = [&]() -> void {
263+
for (auto it = output_op_map.begin(); it != output_op_map.end(); ++it) {
264+
VLOG(6) << "match (" << it->first->name() << " @" << it->first << " : @"
265+
<< it->second << ") in source_pattern_graph ";
266+
drr_q.push(it->first);
267+
drr_visited.insert(it->first);
268+
ir_q.push(it->second);
269+
ir_visited.insert(it->second);
270+
}
271+
};
272+
// Check whether DrrNode and Operation have the same Operands and Results
273+
// information.
274+
const auto& IsSameOperandsAndResults =
275+
[](const OpCall* drr_node, const pir::Operation* ir_node) -> bool {
277276
if (drr_node->name() != ir_node->name()) {
278-
matched = false;
279277
VLOG(8) << "Match failed: drr_node(" << drr_node->name()
280278
<< ") != pir_node(" << ir_node->name() << ").";
281-
break;
279+
return false;
282280
}
283281
const auto& drr_input_tensors = drr_node->inputs();
284282
auto ir_input_value_size = ir_node->num_operands();
285283
if (drr_input_tensors.size() != ir_input_value_size) {
286-
matched = false;
287284
VLOG(8) << drr_node->name() << " Match failed: drr input tensors("
288285
<< drr_input_tensors.size() << ") != pir input tensors("
289286
<< ir_input_value_size << ").";
290-
break;
287+
return false;
291288
}
292289
if (drr_node->outputs().size() != ir_node->num_results()) {
293-
matched = false;
294290
VLOG(8) << drr_node->name() << " Match failed: drr output tensors("
295291
<< drr_node->outputs().size() << ") != pir output tensors("
296292
<< ir_node->num_results() << ").";
293+
return false;
294+
}
295+
return true;
296+
};
297+
// Check whether source_pattern_match_ctx has visited Operation's Operands.
298+
const auto& HasVisitedOperands = [&](const Tensor* drr_input_tensor,
299+
pir::Value ir_value) -> bool {
300+
const auto& tensor_name = drr_input_tensor->name();
301+
if (ir_value.isa<pir::BlockArgument>()) {
302+
VLOG(8) << "Match Attention! Found BlockArgument as input of "
303+
<< tensor_name;
304+
}
305+
return source_pattern_match_ctx->tensor_map().count(tensor_name) != 0 &&
306+
ir_value != source_pattern_match_ctx->tensor_map().at(tensor_name);
307+
};
308+
// Update drr_q et.al information. Return false if faild.
309+
const auto& TryUpdateDrrQueue = [&](const OpCall* drr_producer_op,
310+
pir::Operation* ir_producer_op) -> bool {
311+
// still return true if both visited.
312+
if (drr_visited.count(drr_producer_op) &&
313+
ir_visited.count(ir_producer_op)) {
314+
return true;
315+
}
316+
// insert map if both not visited.
317+
if (!drr_visited.count(drr_producer_op) &&
318+
!ir_visited.count(ir_producer_op)) {
319+
drr_q.push(drr_producer_op);
320+
ir_q.push(ir_producer_op);
321+
drr_visited.insert(drr_producer_op);
322+
ir_visited.insert(ir_producer_op);
323+
return true;
324+
}
325+
return false;
326+
};
327+
328+
// Step 1: Initialize DRR matched queue.
329+
bool matched = true;
330+
size_t step = 0;
331+
InitDrrQueue();
332+
333+
while (!drr_q.empty()) {
334+
if (!matched) break;
335+
auto* drr_node = drr_q.front();
336+
auto* ir_node = ir_q.front();
337+
drr_q.pop();
338+
ir_q.pop();
339+
if (!IsSameOperandsAndResults(drr_node, ir_node)) {
340+
matched = false;
297341
break;
298342
}
343+
// Step 1: Bind Operation of current op to match_ctx.
299344
source_pattern_match_ctx->BindIrOperation(drr_node, ir_node);
300-
// binding input_tensor of current_op
345+
346+
// Step 2: Bind input_tensor of current op to match_ctx.
347+
const auto& drr_input_tensors = drr_node->inputs();
348+
auto ir_input_values = ir_node->operands_source();
301349
for (size_t i = 0; i < drr_input_tensors.size(); ++i) {
302-
if (source_pattern_match_ctx->tensor_map().count(
303-
drr_input_tensors[i]->name()) != 0 &&
304-
ir_node->operand(i).source() !=
305-
source_pattern_match_ctx->tensor_map().at(
306-
drr_input_tensors[i]->name())) {
350+
if (HasVisitedOperands(drr_input_tensors[i], ir_input_values[i])) {
307351
matched = false;
308352
VLOG(8) << " tensor_map key[" << drr_input_tensors[i]->name()
309353
<< "] already exists,but value is different!";
310354
break;
311-
} else {
312-
source_pattern_match_ctx->BindIrValue(drr_input_tensors[i]->name(),
313-
ir_node->operand(i).source());
314-
}
315-
316-
if (ir_node->operand_source(i).isa<pir::BlockArgument>()) {
317-
VLOG(8) << "Match Attention! Found BlockArgument as input of "
318-
<< drr_node->name();
319355
}
320-
356+
source_pattern_match_ctx->BindIrValue(drr_input_tensors[i]->name(),
357+
ir_input_values[i]);
358+
// Skip it while drr_producer_op is nullptr for trigger pattern boundary.
321359
auto* drr_producer_op = drr_input_tensors[i]->producer();
322360
if (drr_producer_op == nullptr) {
323361
continue;
324362
}
325-
363+
// Check whether tensor and value have the same use_count.
326364
if (drr_input_tensors[i]->consumers().size() !=
327-
ir_node->operand(i).source().use_count()) {
365+
ir_input_values[i].use_count()) {
328366
matched = false;
329367
VLOG(8) << drr_node->name() << " Match failed: consumers of drr intput["
330368
<< i << "] { " << drr_node->outputs().size()
331369
<< " } != consumers of pir intput[" << i << "] { "
332-
<< ir_node->operand(i).source().use_count() << " }.";
370+
<< ir_input_values[i].use_count() << " }.";
333371
break;
334372
}
335373

336-
auto* ir_producer_op = ir_node->operand_source(i).defining_op();
337-
// bfs producer_op of current_op
338-
if (drr_visited.count(drr_producer_op) &&
339-
ir_visited.count(ir_producer_op)) {
340-
continue;
374+
auto* ir_producer_op = ir_input_values[i].defining_op();
375+
// Tigger early stop while operand is BlockArgument with
376+
// producer_op==nullptr.
377+
if (drr_producer_op && ir_producer_op == nullptr) {
378+
matched = false;
379+
break;
341380
}
342-
343-
if (!drr_visited.count(drr_producer_op) &&
344-
!ir_visited.count(ir_producer_op)) {
345-
drr_q.push(drr_producer_op);
346-
ir_q.push(ir_producer_op);
347-
drr_visited.insert(drr_producer_op);
348-
ir_visited.insert(ir_producer_op);
349-
} else {
381+
// bfs producer_op of current_op
382+
if (!TryUpdateDrrQueue(drr_producer_op, ir_producer_op)) {
350383
matched = false;
351384
VLOG(8) << "Match failed: status of visiting for" << drr_node->name()
352385
<< " is different.";

0 commit comments

Comments
 (0)