@@ -67,6 +67,19 @@ bool CanBeDeleted(pir::Value value) {
6767 return !(persist_attr && persist_attr.data ());
6868}
6969
70+ bool HasNotUser (const pir::Value& value,
71+ const std::unordered_map<pir::Value, size_t >& use_count_map,
72+ const std::unordered_map<pir::Value, pir::Value>& inplace_map) {
73+ auto current_value = value;
74+ while (use_count_map.at (current_value) == 0 ) {
75+ if (inplace_map.count (current_value) == 0 ) {
76+ return false ;
77+ }
78+ current_value = inplace_map.at (current_value);
79+ }
80+ return true ;
81+ }
82+
7083bool CanDoInplace (const std::unordered_set<pir::Value>& eager_dels,
7184 pir::Value input,
7285 pir::Value output,
@@ -295,15 +308,25 @@ GetEagerDeletionValues(const pir::Block& block) {
295308std::unordered_map<pir::Operation*, std::string> GetInplaceOps (
296309 const pir::Block& block) {
297310 const auto eager_dels = GetEagerDeletionValues (block);
311+ auto use_count_map = [](const pir::Block& block) {
312+ std::unordered_map<pir::Value, size_t > use_count_map;
313+ for (auto & op : block) {
314+ for (auto value : op.results ()) {
315+ use_count_map[value] = value.use_count ();
316+ }
317+ }
318+ return use_count_map;
319+ }(block);
320+ std::unordered_map<pir::Value, pir::Value> inplace_map;
321+
298322 std::unordered_map<pir::Operation*, std::string> inplace_ops;
299323
300324 std::unordered_set<pir::Value> visited_values;
301- std::unordered_set<pir::Value> reused_input_values;
302- std::unordered_set<pir::Value> reused_output_values;
303325
304326 for (auto & op : block) {
305327 for (size_t i = 0 ; i < op.num_operands (); ++i) {
306328 visited_values.insert (op.operand_source (i));
329+ use_count_map[op.operand_source (i)]--;
307330 }
308331
309332 if (op.dialect ()->name ().compare (paddle::dialect::KernelDialect::name ()) !=
@@ -339,11 +362,18 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
339362 if (upper_op_attrs.count (" is_inplace" ) != 0 &&
340363 upper_op_attrs.at (" is_inplace" ).dyn_cast <pir::BoolAttribute>().data ()) {
341364 VLOG (6 ) << upper_op_name << " is already an inplace op." ;
342- for (size_t i = 0 ; i < op.num_operands (); ++i) {
343- reused_input_values.insert (op.operand_source (i));
365+ auto op_info =
366+ pir::IrContext::Instance ()->GetRegisteredOpInfo (upper_op_name);
367+ auto op_yaml_interface =
368+ op_info.GetInterfaceImpl <paddle::dialect::OpYamlInfoInterface>();
369+ paddle::dialect::OpYamlInfoParser op_info_parser (
370+ op_yaml_interface->get_op_info_ (upper_op_name));
371+ for (auto [out_slot, in_slot] : op_info_parser.GetInplaceIdMap ()) {
372+ auto out_value = op.result (out_slot);
373+ auto in_value = op.operand_source (in_slot);
374+ inplace_map[out_value] = in_value;
344375 }
345376 for (auto & result : op.results ()) {
346- reused_output_values.insert (result);
347377 visited_values.insert (result);
348378 }
349379 continue ;
@@ -409,8 +439,7 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
409439 upper_op_name)) ||
410440 (visited_values.count (op.result (out_slot)) > 0 ) ||
411441 (!CanBeDeleted (op.result (out_slot))) ||
412- (reused_input_values.count (op.operand_source (in_slot)) > 0 ) ||
413- (reused_output_values.count (op.result (out_slot)) > 0 ) ||
442+ HasNotUser (op.operand_source (in_slot), use_count_map, inplace_map) ||
414443 (std::find (used_external_values.begin (),
415444 used_external_values.end (),
416445 op.operand_source (in_slot)) !=
@@ -435,19 +464,16 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
435464 << " -- result " << out_slot
436465 << " visited: " << (visited_values.count (op.result (out_slot)) > 0 );
437466 VLOG_IF (8 , in_slot < op.num_operands ())
438- << " -- operand " << in_slot << " has been reused: "
439- << (reused_input_values.count (op.operand_source (in_slot)) > 0 );
440- VLOG_IF (8 , out_slot < op.num_results ())
441- << " -- result " << out_slot << " has been reused: "
442- << (reused_output_values.count (op.result (out_slot)) > 0 );
467+ << " -- operand " << in_slot << " has not user: "
468+ << HasNotUser (
469+ op.operand_source (in_slot), use_count_map, inplace_map);
443470 break ;
444471 }
445472 }
446473 if (can_do_inplace) {
447474 inplace_ops[&op] = upper_op_name + " _" ;
448475 for (auto & kv : inplace_out_2_in) {
449- reused_input_values.insert (op.operand_source (kv.second ));
450- reused_output_values.insert (op.result (kv.first ));
476+ inplace_map[op.result (kv.first )] = op.operand_source (kv.second );
451477 }
452478 VLOG (6 ) << upper_op_name
453479 << " will change to inplace version op: " << upper_op_name + " _" ;
0 commit comments