From 98aaf13ddfb1e16fa5c976f7632fe64a51fc7608 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 22 Apr 2024 17:09:17 +0000 Subject: [PATCH 1/3] [Dy2St][PIR] Re-create ShadowOutput OP in split forward-backward --- paddle/fluid/pybind/pir.cc | 41 +++++++++++++++----------------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index cce5f045a722d5..01e206f783278d 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -1099,15 +1099,26 @@ std::list::const_iterator list_offset(const Block *block, return it; } -template -void range_block_do(const Block *block, std::vector range, F fn) { +template +void range_block_do(const Block *block, + std::vector range, + F fn, + S skip_fn) { for (auto it = list_offset(block, range[0]); it != list_offset(block, range[1]); ++it) { + if (skip_fn(*it)) { + continue; + } fn(*it); } } +template +void range_block_do(const Block *block, std::vector range, F fn) { + range_block_do(block, range, fn, [](Operation *op) { return false; }); +} + template bool ExistsInMapValues(const std::map &m, V value) { for (const auto &[k, v] : m) { @@ -1461,7 +1472,9 @@ SplitedResult SplitForwardBackward( [&forward_mapper, &forward_program, &clone_options](Operation *op) { auto *cloned_op = op->Clone(forward_mapper, clone_options); forward_program->block()->push_back(cloned_op); - }); + }, + // Skip the ShadowOutputOp. + /*skip_fn=*/[](Operation *op) { return op->isa(); }); auto &forward_value_map = forward_mapper.GetMutableMap(); // backward program construct. @@ -1500,30 +1513,8 @@ SplitedResult SplitForwardBackward( forward_params.end()) { return; } - // NOTE(Aurelius84): we should skip insert ShadowOutputOp repeatedly by - // calling SplitForwardBackward multi-times. std::string shadow_output_name = std::string("output_") + std::to_string(counter); - std::unordered_set inserted_value; - for (auto it = forward_program->block()->rbegin(); - it != forward_program->block()->rend(); - ++it) { - if (it->isa()) { - auto out_name = - it->attribute("output_name").AsString(); - if (out_name == shadow_output_name) { - VLOG(4) << out_name - << " has been inserted ShadowOutputOp, skip it now."; - return; - } - - inserted_value.insert(it->operand_source(0)); - } - } - - if (inserted_value.count(forward_value_map[v])) { - return; - } auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name()); pir::AttributeMap attribute_map = { {"output_name", pir::StrAttribute::get(ctx, shadow_output_name)}, From a6ae483f0c15ad4976db99b961d74e317b30a7b0 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 22 Apr 2024 17:17:56 +0000 Subject: [PATCH 2/3] open the ut --- test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py b/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py index 2e266168892cfd..b8748500821e3b 100644 --- a/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py +++ b/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py @@ -101,7 +101,7 @@ def train(self, net, to_static, with_prim=False, with_cinn=False): def test_ast_prim_cinn(self): st_out = self.train(self.net, to_static=True) cinn_out = self.train( - self.net, to_static=True, with_prim=False, with_cinn=False + self.net, to_static=True, with_prim=True, with_cinn=False ) for st, cinn in zip( paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out) From b3090eb827ac327366830db3cc1d186d79eaf4bd Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 22 Apr 2024 19:10:22 +0000 Subject: [PATCH 3/3] skip fp only --- paddle/fluid/pybind/pir.cc | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 01e206f783278d..07e497eb6bb3d1 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -1506,11 +1506,9 @@ SplitedResult SplitForwardBackward( if (v.impl() == nullptr) { return; } - // Skip the value that already in forward_inputs or forward_params. - if (std::find(forward_inputs.begin(), forward_inputs.end(), v) != - forward_inputs.end() || - std::find(forward_params.begin(), forward_params.end(), v) != - forward_params.end()) { + // Skip the value that already in forward_params. + if (std::find(forward_params.begin(), forward_params.end(), v) != + forward_params.end()) { return; } std::string shadow_output_name =