Skip to content

Commit 4bedcd0

Browse files
authored
[PIR] Mark ShareData as an inplace OP and fix inplace pass (#64195)
1 parent 8a612f3 commit 4bedcd0

11 files changed

Lines changed: 135 additions & 28 deletions

File tree

paddle/fluid/ir_adaptor/translator/op_translator.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2976,12 +2976,12 @@ struct FusedFeedForwardOpTranscriber : public OpTranscriber {
29762976
struct ShareBufferOpTranscriber : public OpTranscriber {
29772977
pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
29782978
const OpDesc& op_desc) override {
2979-
std::string target_op_name = dialect::ShareDataOp::name();
2979+
std::string target_op_name = dialect::ShareData_Op::name();
29802980
const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
29812981
if (!op_info) {
29822982
PADDLE_THROW(phi::errors::InvalidArgument(
29832983
"Op share_buffer should have corresponding OpInfo "
2984-
"pd_op.share_data"));
2984+
"pd_op.share_data_"));
29852985
}
29862986

29872987
return op_info;

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,7 @@
109109
'print',
110110
'number_count',
111111
'assign_value',
112-
'share_data',
112+
'share_data_',
113113
'onednn_to_paddle_layout',
114114
'lrn',
115115
'multi_gru',

paddle/fluid/pir/dialect/operator/ir/ops.yaml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1627,7 +1627,7 @@
16271627
func: shadow_feed_tensors
16281628
param: [x]
16291629

1630-
- op : share_data
1630+
- op : share_data_
16311631
args : (Tensor x)
16321632
output : Tensor(out)
16331633
infer_meta:
@@ -1636,6 +1636,7 @@
16361636
kernel:
16371637
func: share_data
16381638
param: [x]
1639+
inplace : (x -> out)
16391640

16401641
- op : shuffle_batch
16411642
args : (Tensor x, Tensor seed, int startup_seed=0)

paddle/fluid/pir/dialect/operator/utils/utils.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ const std::unordered_set<std::string> LegacyOpList = {
6767
CSplitOp::name(),
6868
PushDenseOp::name(),
6969
SeedOp::name(),
70-
ShareDataOp::name(),
70+
ShareData_Op::name(),
7171
SparseMomentumOp::name(),
7272
GetTensorFromSelectedRowsOp::name(),
7373
RankAttentionOp::name(),

paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -571,8 +571,8 @@ class AutoMixedPrecisionPass : public pir::Pass {
571571
return;
572572
}
573573

574-
// Rewrite ShareDataOp
575-
if (op->isa<paddle::dialect::ShareDataOp>() && OpRunLowPrecision(op)) {
574+
// Rewrite ShareData_Op
575+
if (op->isa<paddle::dialect::ShareData_Op>() && OpRunLowPrecision(op)) {
576576
SetResultDataType(op->result(0), precision_mode_, builder.ir_context());
577577
return;
578578
}

paddle/fluid/pir/transforms/general/inplace_pass.cc

Lines changed: 40 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,19 @@ bool CanBeDeleted(pir::Value value) {
6767
return !(persist_attr && persist_attr.data());
6868
}
6969

70+
// Follows the inplace chain that starts at `value` and reports whether any
// value on that chain still has a non-zero entry in `use_count_map`.
//
// Starting from `value`, while the current value's count is zero the walk
// moves on to the value that `inplace_map` associates with it.
//
// Returns:
//   true  - the walk reached a value whose count in `use_count_map` is
//           non-zero. NOTE(review): despite the function name, `true` here
//           means a count is still outstanding; callers in this diff treat
//           `true` as a reason to refuse the inplace rewrite.
//   false - the walk reached a value with a zero count that has no entry in
//           `inplace_map`.
//
// Throws std::out_of_range (from `at`) if a visited value has no entry in
// `use_count_map`.
bool HasNotUser(const pir::Value& value,
                const std::unordered_map<pir::Value, size_t>& use_count_map,
                const std::unordered_map<pir::Value, pir::Value>& inplace_map) {
  pir::Value cursor = value;
  for (;;) {
    if (use_count_map.at(cursor) != 0) {
      return true;
    }
    // Count is zero: continue with the value this one is mapped to, if any.
    const auto next = inplace_map.find(cursor);
    if (next == inplace_map.end()) {
      return false;
    }
    cursor = next->second;
  }
}
82+
7083
bool CanDoInplace(const std::unordered_set<pir::Value>& eager_dels,
7184
pir::Value input,
7285
pir::Value output,
@@ -295,15 +308,25 @@ GetEagerDeletionValues(const pir::Block& block) {
295308
std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
296309
const pir::Block& block) {
297310
const auto eager_dels = GetEagerDeletionValues(block);
311+
auto use_count_map = [](const pir::Block& block) {
312+
std::unordered_map<pir::Value, size_t> use_count_map;
313+
for (auto& op : block) {
314+
for (auto value : op.results()) {
315+
use_count_map[value] = value.use_count();
316+
}
317+
}
318+
return use_count_map;
319+
}(block);
320+
std::unordered_map<pir::Value, pir::Value> inplace_map;
321+
298322
std::unordered_map<pir::Operation*, std::string> inplace_ops;
299323

300324
std::unordered_set<pir::Value> visited_values;
301-
std::unordered_set<pir::Value> reused_input_values;
302-
std::unordered_set<pir::Value> reused_output_values;
303325

304326
for (auto& op : block) {
305327
for (size_t i = 0; i < op.num_operands(); ++i) {
306328
visited_values.insert(op.operand_source(i));
329+
use_count_map[op.operand_source(i)]--;
307330
}
308331

309332
if (op.dialect()->name().compare(paddle::dialect::KernelDialect::name()) !=
@@ -339,11 +362,18 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
339362
if (upper_op_attrs.count("is_inplace") != 0 &&
340363
upper_op_attrs.at("is_inplace").dyn_cast<pir::BoolAttribute>().data()) {
341364
VLOG(6) << upper_op_name << " is already an inplace op.";
342-
for (size_t i = 0; i < op.num_operands(); ++i) {
343-
reused_input_values.insert(op.operand_source(i));
365+
auto op_info =
366+
pir::IrContext::Instance()->GetRegisteredOpInfo(upper_op_name);
367+
auto op_yaml_interface =
368+
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
369+
paddle::dialect::OpYamlInfoParser op_info_parser(
370+
op_yaml_interface->get_op_info_(upper_op_name));
371+
for (auto [out_slot, in_slot] : op_info_parser.GetInplaceIdMap()) {
372+
auto out_value = op.result(out_slot);
373+
auto in_value = op.operand_source(in_slot);
374+
inplace_map[out_value] = in_value;
344375
}
345376
for (auto& result : op.results()) {
346-
reused_output_values.insert(result);
347377
visited_values.insert(result);
348378
}
349379
continue;
@@ -409,8 +439,7 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
409439
upper_op_name)) ||
410440
(visited_values.count(op.result(out_slot)) > 0) ||
411441
(!CanBeDeleted(op.result(out_slot))) ||
412-
(reused_input_values.count(op.operand_source(in_slot)) > 0) ||
413-
(reused_output_values.count(op.result(out_slot)) > 0) ||
442+
HasNotUser(op.operand_source(in_slot), use_count_map, inplace_map) ||
414443
(std::find(used_external_values.begin(),
415444
used_external_values.end(),
416445
op.operand_source(in_slot)) !=
@@ -435,19 +464,16 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
435464
<< " -- result " << out_slot
436465
<< " visited: " << (visited_values.count(op.result(out_slot)) > 0);
437466
VLOG_IF(8, in_slot < op.num_operands())
438-
<< " -- operand " << in_slot << " has been reused: "
439-
<< (reused_input_values.count(op.operand_source(in_slot)) > 0);
440-
VLOG_IF(8, out_slot < op.num_results())
441-
<< " -- result " << out_slot << " has been reused: "
442-
<< (reused_output_values.count(op.result(out_slot)) > 0);
467+
<< " -- operand " << in_slot << " has not user: "
468+
<< HasNotUser(
469+
op.operand_source(in_slot), use_count_map, inplace_map);
443470
break;
444471
}
445472
}
446473
if (can_do_inplace) {
447474
inplace_ops[&op] = upper_op_name + "_";
448475
for (auto& kv : inplace_out_2_in) {
449-
reused_input_values.insert(op.operand_source(kv.second));
450-
reused_output_values.insert(op.result(kv.first));
476+
inplace_map[op.result(kv.first)] = op.operand_source(kv.second);
451477
}
452478
VLOG(6) << upper_op_name
453479
<< " will change to inplace version op: " << upper_op_name + "_";

paddle/fluid/pybind/pir.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1156,7 +1156,7 @@ void BindValue(py::module *m) {
11561156
auto share_data_op =
11571157
ApiBuilder::Instance()
11581158
.GetBuilder()
1159-
->Build<paddle::dialect::ShareDataOp>(self);
1159+
->Build<paddle::dialect::ShareData_Op>(self);
11601160
auto out = share_data_op.out();
11611161
out.set_attribute(
11621162
kAttrStopGradients,

paddle/phi/api/yaml/op_compat.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4179,7 +4179,7 @@
41794179
data_type : int64_t
41804180
tensors_name : StepsTensorList
41814181

4182-
- op: share_data
4182+
- op: share_data_ (share_data)
41834183
inputs :
41844184
x : X
41854185
outputs :

test/deprecated/ir/pir/test_special_op_translator.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -515,8 +515,8 @@ def test_program(self):
515515
)
516516
l = pir.translate_to_pir(main_program.desc)
517517
assert (
518-
l.global_block().ops[2].name() == "pd_op.share_data"
519-
), "share_buffer should be translated to share_data"
518+
l.global_block().ops[2].name() == "pd_op.share_data_"
519+
), "share_buffer should be translated to share_data_"
520520

521521

522522
class TestDataOp(unittest.TestCase):
Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
from dygraph_to_static_utils import (
19+
Dy2StTestBase,
20+
test_legacy_and_pt_and_pir,
21+
)
22+
23+
import paddle
24+
25+
26+
def detach_fn(x, y):
    """Add ``x`` and ``y``, then return ``(detached_sum + 1, sum)``.

    The sum is returned as-is in the second slot; the first slot is computed
    from the result of calling ``.detach()`` on that sum.
    (Presumably ``x`` and ``y`` are paddle tensors -- confirm against callers.)
    """
    total = x + y
    detached = total.detach()
    return detached + 1, total
32+
33+
34+
class TestDetach(Dy2StTestBase):
    """Checks that ``detach_fn`` gives matching outputs through
    ``paddle.jit.to_static`` and when called directly."""

    @test_legacy_and_pt_and_pir
    def test_detach(self):
        x = paddle.ones([], 'float32')
        y = paddle.ones([], 'float32')

        # Run the converted function first, then the plain function.
        static_res = paddle.jit.to_static(detach_fn)(x, y)
        dygraph_res = detach_fn(x, y)

        # Compare both returned values position by position.
        for idx in (0, 1):
            np.testing.assert_allclose(
                static_res[idx].numpy(), dygraph_res[idx].numpy()
            )
48+
49+
50+
if __name__ == '__main__':
51+
unittest.main()

0 commit comments

Comments
 (0)