@@ -258,95 +258,128 @@ bool DrrRewritePattern::MatchFromOutputToInput(
258258 std::unordered_set<pir::Operation*> ir_visited;
259259 std::queue<const OpCall*> drr_q;
260260 std::queue<pir::Operation*> ir_q;
261- bool matched = true ;
262- size_t step = 0 ;
263- for (auto it = output_op_map.begin (); it != output_op_map.end (); ++it) {
264- VLOG (6 ) << " match (" << it->first ->name () << " @" << it->first << " : @"
265- << it->second << " ) in source_pattern_graph " ;
266- drr_q.push (it->first );
267- drr_visited.insert (it->first );
268- ir_q.push (it->second );
269- ir_visited.insert (it->second );
270- }
271- while (!drr_q.empty ()) {
272- if (!matched) break ;
273- auto * drr_node = drr_q.front ();
274- auto * ir_node = ir_q.front ();
275- drr_q.pop ();
276- ir_q.pop ();
261+ // Initialize DRR matched queue.
262+ const auto & InitDrrQueue = [&]() -> void {
263+ for (auto it = output_op_map.begin (); it != output_op_map.end (); ++it) {
264+ VLOG (6 ) << " match (" << it->first ->name () << " @" << it->first << " : @"
265+ << it->second << " ) in source_pattern_graph " ;
266+ drr_q.push (it->first );
267+ drr_visited.insert (it->first );
268+ ir_q.push (it->second );
269+ ir_visited.insert (it->second );
270+ }
271+ };
272+ // Check whether DrrNode and Operation have the same Operands and Results
273+ // information.
274+ const auto & IsSameOperandsAndResults =
275+ [](const OpCall* drr_node, const pir::Operation* ir_node) -> bool {
277276 if (drr_node->name () != ir_node->name ()) {
278- matched = false ;
279277 VLOG (8 ) << " Match failed: drr_node(" << drr_node->name ()
280278 << " ) != pir_node(" << ir_node->name () << " )." ;
281- break ;
279+ return false ;
282280 }
283281 const auto & drr_input_tensors = drr_node->inputs ();
284282 auto ir_input_value_size = ir_node->num_operands ();
285283 if (drr_input_tensors.size () != ir_input_value_size) {
286- matched = false ;
287284 VLOG (8 ) << drr_node->name () << " Match failed: drr input tensors("
288285 << drr_input_tensors.size () << " ) != pir input tensors("
289286 << ir_input_value_size << " )." ;
290- break ;
287+ return false ;
291288 }
292289 if (drr_node->outputs ().size () != ir_node->num_results ()) {
293- matched = false ;
294290 VLOG (8 ) << drr_node->name () << " Match failed: drr output tensors("
295291 << drr_node->outputs ().size () << " ) != pir output tensors("
296292 << ir_node->num_results () << " )." ;
293+ return false ;
294+ }
295+ return true ;
296+ };
297+ // Check whether source_pattern_match_ctx has visited Operation's Operands.
298+ const auto & HasVisitedOperands = [&](const Tensor* drr_input_tensor,
299+ pir::Value ir_value) -> bool {
300+ const auto & tensor_name = drr_input_tensor->name ();
301+ if (ir_value.isa <pir::BlockArgument>()) {
302+ VLOG (8 ) << " Match Attention! Found BlockArgument as input of "
303+ << tensor_name;
304+ }
305+ return source_pattern_match_ctx->tensor_map ().count (tensor_name) != 0 &&
306+ ir_value != source_pattern_match_ctx->tensor_map ().at (tensor_name);
307+ };
308+ // Update drr_q et.al information. Return false if faild.
309+ const auto & TryUpdateDrrQueue = [&](const OpCall* drr_producer_op,
310+ pir::Operation* ir_producer_op) -> bool {
311+ // still return true if both visited.
312+ if (drr_visited.count (drr_producer_op) &&
313+ ir_visited.count (ir_producer_op)) {
314+ return true ;
315+ }
316+ // insert map if both not visited.
317+ if (!drr_visited.count (drr_producer_op) &&
318+ !ir_visited.count (ir_producer_op)) {
319+ drr_q.push (drr_producer_op);
320+ ir_q.push (ir_producer_op);
321+ drr_visited.insert (drr_producer_op);
322+ ir_visited.insert (ir_producer_op);
323+ return true ;
324+ }
325+ return false ;
326+ };
327+
328+ // Step 1: Initialize DRR matched queue.
329+ bool matched = true ;
330+ size_t step = 0 ;
331+ InitDrrQueue ();
332+
333+ while (!drr_q.empty ()) {
334+ if (!matched) break ;
335+ auto * drr_node = drr_q.front ();
336+ auto * ir_node = ir_q.front ();
337+ drr_q.pop ();
338+ ir_q.pop ();
339+ if (!IsSameOperandsAndResults (drr_node, ir_node)) {
340+ matched = false ;
297341 break ;
298342 }
343+ // Step 1: Bind Operation of current op to match_ctx.
299344 source_pattern_match_ctx->BindIrOperation (drr_node, ir_node);
300- // binding input_tensor of current_op
345+
346+ // Step 2: Bind input_tensor of current op to match_ctx.
347+ const auto & drr_input_tensors = drr_node->inputs ();
348+ auto ir_input_values = ir_node->operands_source ();
301349 for (size_t i = 0 ; i < drr_input_tensors.size (); ++i) {
302- if (source_pattern_match_ctx->tensor_map ().count (
303- drr_input_tensors[i]->name ()) != 0 &&
304- ir_node->operand (i).source () !=
305- source_pattern_match_ctx->tensor_map ().at (
306- drr_input_tensors[i]->name ())) {
350+ if (HasVisitedOperands (drr_input_tensors[i], ir_input_values[i])) {
307351 matched = false ;
308352 VLOG (8 ) << " tensor_map key[" << drr_input_tensors[i]->name ()
309353 << " ] already exists,but value is different!" ;
310354 break ;
311- } else {
312- source_pattern_match_ctx->BindIrValue (drr_input_tensors[i]->name (),
313- ir_node->operand (i).source ());
314- }
315-
316- if (ir_node->operand_source (i).isa <pir::BlockArgument>()) {
317- VLOG (8 ) << " Match Attention! Found BlockArgument as input of "
318- << drr_node->name ();
319355 }
320-
356+ source_pattern_match_ctx->BindIrValue (drr_input_tensors[i]->name (),
357+ ir_input_values[i]);
358+ // Skip it while drr_producer_op is nullptr for trigger pattern boundary.
321359 auto * drr_producer_op = drr_input_tensors[i]->producer ();
322360 if (drr_producer_op == nullptr ) {
323361 continue ;
324362 }
325-
363+ // Check whether tensor and value have the same use_count.
326364 if (drr_input_tensors[i]->consumers ().size () !=
327- ir_node-> operand (i). source () .use_count ()) {
365+ ir_input_values[i] .use_count ()) {
328366 matched = false ;
329367 VLOG (8 ) << drr_node->name () << " Match failed: consumers of drr intput["
330368 << i << " ] { " << drr_node->outputs ().size ()
331369 << " } != consumers of pir intput[" << i << " ] { "
332- << ir_node-> operand (i). source () .use_count () << " }." ;
370+ << ir_input_values[i] .use_count () << " }." ;
333371 break ;
334372 }
335373
336- auto * ir_producer_op = ir_node->operand_source (i).defining_op ();
337- // bfs producer_op of current_op
338- if (drr_visited.count (drr_producer_op) &&
339- ir_visited.count (ir_producer_op)) {
340- continue ;
374+ auto * ir_producer_op = ir_input_values[i].defining_op ();
375+ // Tigger early stop while operand is BlockArgument with
376+ // producer_op==nullptr.
377+ if (drr_producer_op && ir_producer_op == nullptr ) {
378+ matched = false ;
379+ break ;
341380 }
342-
343- if (!drr_visited.count (drr_producer_op) &&
344- !ir_visited.count (ir_producer_op)) {
345- drr_q.push (drr_producer_op);
346- ir_q.push (ir_producer_op);
347- drr_visited.insert (drr_producer_op);
348- ir_visited.insert (ir_producer_op);
349- } else {
381+ // bfs producer_op of current_op
382+ if (!TryUpdateDrrQueue (drr_producer_op, ir_producer_op)) {
350383 matched = false ;
351384 VLOG (8 ) << " Match failed: status of visiting for" << drr_node->name ()
352385 << " is different." ;
0 commit comments