From 001d799b26bcf436613132944fc0dda77508eab7 Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Fri, 8 Dec 2023 09:23:14 +0000 Subject: [PATCH 01/21] optimize backward --- python/paddle/autograd/backward_utils.py | 12 +++++++- python/paddle/autograd/ir_backward.py | 38 ++++++++---------------- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py index b423f6ed5e4bec..d4e88257e29f21 100644 --- a/python/paddle/autograd/backward_utils.py +++ b/python/paddle/autograd/backward_utils.py @@ -18,7 +18,8 @@ class State: """ record relationship of forward op/value and backward op/value - one state must be bining with a program + one state must be bining with a block, if block has parent block, + state will include parent block info. """ @@ -35,6 +36,10 @@ def __init__(self, block): self.sumvaluegrad_to_value = collections.defaultdict(list) # operation -> list(operation) self.opgrad_to_op = collections.defaultdict(list) + # only for controlflow + # inside_value is sub block value, which will yield to parent block, + # parant block value is outside_value + self.inside_value_to_outside_value_map = {} def turn_map(self) -> None: self.valuegrad_to_value = collections.defaultdict(list) @@ -67,4 +72,9 @@ def copy(self, new_block): # operation -> list(operation) state.opgrad_to_op = self.opgrad_to_op.copy() + # only for controlflow + state.inside_value_to_outside_value_map = ( + self.inside_value_to_outside_value_map.copy() + ) + return state diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index bf9c14845be9f1..6b4440808c514e 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -239,9 +239,7 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): total_ops[i] for i in range(len(total_ops)) if intersection_op_flags[i] ] uneffective_ops = [ - total_ops[i] - for i in reversed(range(len(total_ops))) - if not union_op_flags[i] + total_ops[i] for i in range(len(total_ops)) if not union_op_flags[i] ] return effective_ops, uneffective_ops @@ -337,7 +335,6 @@ def append_backward_ops( no_grad_set, backward_ops, state, - inside_value_to_outside_value_map, ): ''' add grad_op in order of topological inverse sort @@ -351,7 +348,7 @@ def append_backward_ops( v2_g = call_vjp(op3, [[v2]], [[v3]],[[v3_g]], [[v2_stopgradient]]) - special pattern 1: + special pattern: v11 -> combine_op -> v1 -> op -> v3 v12 -> v2 -> @@ -359,7 +356,7 @@ def append_backward_ops( v1 is inside python api, we don't describe it in backward process(state) so v1_grad is inside vjp, we don't describe it in backward process(state) - [[v11_g, v12_g], v2_g] = call_vjp(combine_op, [[v11, v12]], [[v3]],[[v3_g]], [[v11_stopgradient, v12_stopgradient], v2_stop_gradient]) + [[v11_g, v12_g], v2_g] = call_vjp(op, [[v11, v12]], [[v3]],[[v3_g]], [[v11_stopgradient, v12_stopgradient], v2_stop_gradient]) op_vjp is: @@ -380,7 +377,7 @@ def append_backward_ops( ''' def append_add_n(value): - # one value is input of more than one fwd_op, + # value is input of more than one fwd_op, # so more than one bwd_op create input_grad, # need add sum op to accumulate gradient add_n_value = paddle.add_n( @@ -406,8 +403,8 @@ def make_output_with_output_grad(op): if value in control_flow_value_to_copyvalue_map else [value] ) - while value in inside_value_to_outside_value_map: - value = inside_value_to_outside_value_map[value] + while value in state.inside_value_to_outside_value_map: 
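+                # an inside (sub-block) value can itself be the inside value of a
+                # deeper block, so follow the map up to the outermost value,
+                # which is the one registered in state.value_to_valuegrad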
+ value = state.inside_value_to_outside_value_map[value] if ( value in state.value_to_valuegrad @@ -425,7 +422,7 @@ def make_output_with_output_grad(op): ): # pattern case: # this fwd_op's output is vectorType, it will split to - # Type by builtin.split op, so need get from split op's ouput + # Type by builtin_split op, so need get from split op's outputs. ( split_zero_flag, split_outputs, @@ -556,8 +553,8 @@ def append_yield(block, base_inputs, base_inputs_grad): if value_grad is None: continue - while value in inside_value_to_outside_value_map: - value = inside_value_to_outside_value_map[value] + while value in state.inside_value_to_outside_value_map: + value = state.inside_value_to_outside_value_map[value] if value in state.value_to_valuegrad: if len(state.value_to_valuegrad[value]) > 1: @@ -579,8 +576,6 @@ def append_yield(block, base_inputs, base_inputs_grad): # -----------------only for control flow-----------------# # tuple_push value to pop value control_flow_value_to_copyvalue_map = {} - # tuple_push value to pop value - control_flow_copyvalue_to_value_map = {} if ( len(effective_forward_ops) > 1 @@ -590,7 +585,9 @@ def append_yield(block, base_inputs, base_inputs_grad): for outside_output, inside_output in zip( base_op.results(), yield_op.operands_source() ): - inside_value_to_outside_value_map[inside_output] = outside_output + state.inside_value_to_outside_value_map[ + inside_output + ] = outside_output forward_ops = effective_forward_ops[:-1] else: forward_ops = effective_forward_ops @@ -628,9 +625,6 @@ def append_yield(block, base_inputs, base_inputs_grad): control_flow_value_to_copyvalue_map[ output[0] ] = copy_output[0] - control_flow_copyvalue_to_value_map[ - copy_output[0] - ] = output[0] else: # all(zero_flag) support this op has no contribution for grad @@ -656,9 +650,6 @@ def append_yield(block, base_inputs, base_inputs_grad): op.blocks(), grad_op.blocks() ): sub_state = state.copy(sub_fwd_block) - sub_inside_value_to_outside_value_map = ( - inside_value_to_outside_value_map.copy() - ) sub_backward_ops = [] append_backward_ops( op, @@ -670,7 +661,6 @@ def append_yield(block, base_inputs, base_inputs_grad): no_grad_set, sub_backward_ops, sub_state, - sub_inside_value_to_outside_value_map, ) # update input_grad map update_input_grad_map(op, input_grads, origin_inputs) @@ -809,9 +799,6 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): inputs, complete_outputs ) - # sub_block op output to parent_block op output - inside_value_to_outside_value_map = {} - append_backward_ops( None, None, @@ -822,7 +809,6 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): no_grad_set, backward_ops, state, - inside_value_to_outside_value_map, ) # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue) outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set( From 88548964906a8442b5709e653c2eb75b344a74b2 Mon Sep 17 00:00:00 2001 From: winter-wang <1030748926@qq.com> Date: Tue, 12 Dec 2023 08:56:31 +0000 Subject: [PATCH 02/21] [PIR] add vjp interface for while op --- .../dialect/operator/ir/control_flow_op.cc | 88 ++++++++++++++++++- .../pir/dialect/operator/ir/control_flow_op.h | 10 ++- paddle/fluid/pybind/pir.cc | 11 ++- paddle/pir/core/op_base.h | 3 + paddle/pir/core/operation.h | 1 + paddle/pir/dialect/control_flow/ir/cf_type.h | 1 + test/ir/pir/test_while_api.py | 36 ++++++++ 7 files changed, 143 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc 
b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index dbb7c7c248dd48..12f5dc9148e832 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -20,6 +20,8 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp, paddle::dialect::HasElementsOp #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/phi/core/enforce.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" @@ -289,6 +291,84 @@ void WhileOp::Print(pir::IrPrinter &printer) { os << "\n }"; } +std::vector> WhileOp::Vjp( + pir::Operation *op, + const std::vector> &inputs, + const std::vector> &outputs, + const std::vector> &out_grads, + const std::vector> &stop_gradients) { + auto fwd_op = WhileOp::dyn_cast(op); + PADDLE_ENFORCE_NE( + fwd_op, + nullptr, + phi::errors::InvalidArgument("The input op used to called WhileOp::vjp " + "must be non-nullptr while_op")); + TuplePushOp push_op; + for (auto iter = fwd_op.body().rbegin(); iter != fwd_op.body().rend(); + ++iter) { + if (iter->isa()) { + push_op = iter->dyn_cast(); + PADDLE_ENFORCE_EQ(push_op.container().use_empty(), + true, + phi::errors::InvalidArgument( + "The last container in foward while op must used " + "empty while construct while_grad op")); + break; + } + } + PADDLE_ENFORCE_NE(push_op, + nullptr, + phi::errors::InvalidArgument( + "The forward WhileOp must include TuplePushOp, denying " + "that we can't construct a reverse loop condition.")); + + PADDLE_ENFORCE_GT(inputs.size(), + outputs.size(), + phi::errors::InvalidArgument( + "while op's inputs' size should greater than " + "outputs' size, Now the inputs's size is %d ." + "the outputs size is %d.", + inputs.size(), + outputs.size())); + PADDLE_ENFORCE_EQ(stop_gradients[0][0], + true, + phi::errors::InvalidArgument( + "The stop_gradient of condition input must be true.")); + + auto &builder = *ApiBuilder::Instance().GetBuilder(); + auto cond_val = builder.Build(push_op.container()).out(); + + std::vector output_types; + std::vector loop_vars; + size_t index = 0; + + for (; index < outputs.size(); ++index) { + if (!stop_gradients[index + 1][0]) { + loop_vars.push_back(out_grads[index][0]); + } + } + for (++index; index < inputs.size(); ++index) { + if (!stop_gradients[index][0]) { + auto fwd_type = inputs[index][0].type().dyn_cast(); + PADDLE_ENFORCE_NE( + fwd_type, + pir::Type(), + phi::errors::InvalidArgument( + "The forward value type must be dense tensor type.")); + auto shape = vectorize(fwd_type.dims()); + auto dtype = TransToPhiDataType(fwd_type.dtype()); + auto full_op = builder.Build(shape, 0.0, dtype, phi::CPUPlace()); + loop_vars.push_back(full_op.out()); + } + } + auto while_grad = builder.Build(cond_val, loop_vars); + + std::vector> res(inputs.size()); + for (size_t i = 0, j = 0; i < inputs.size(); ++i) { + res[i].push_back(stop_gradients[i][0] ? 
nullptr : while_grad.result(j++)); + } + return res; +} std::vector> TuplePushOpVjpInterfaceModel::Vjp( pir::Operation *op, const std::vector> &inputs, @@ -318,8 +398,8 @@ std::vector> TuplePushOpVjpInterfaceModel::Vjp( void HasElementsOp::Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT - pir::Value stack) { - argument.AddInput(stack); + pir::Value container) { + argument.AddInput(container); argument.AddOutput( DenseTensorType::get(builder.ir_context(), builder.bool_type(), {1})); } @@ -327,8 +407,8 @@ void HasElementsOp::VerifySig() { VLOG(4) << "Verifying inputs, outputs ,attributes for: HasElementsOp."; // Verify inputs: IR_ENFORCE(num_operands() == 1u, "The size of inputs must equal to 1."); - IR_ENFORCE(operand_source(0).type().isa(), - "The first input of cf.has_elements must be stack_type."); + IR_ENFORCE(operand_type(0).isa(), + "The first input of cf.has_elements must be container type."); // No attributes should be verify. diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h index 5c8354f06ffe53..231da23e289095 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -65,7 +65,7 @@ class IfOp : public pir::Op { /// cond, outputs = body(outputs) /// } /// -class WhileOp : public pir::Op { +class WhileOp : public pir::Op { public: using Op::Op; static const char *name() { return "pd_op.while"; } @@ -81,6 +81,12 @@ class WhileOp : public pir::Op { void Print(pir::IrPrinter &printer); // NOLINT void VerifySig() {} void VerifyRegion() {} + static std::vector> Vjp( + pir::Operation *op, + const std::vector> &inputs_, + const std::vector> &outputs, + const std::vector> &out_grads, + const std::vector> &stop_gradients); }; struct TuplePushOpVjpInterfaceModel : public VjpInterface::Concept { @@ -114,7 +120,7 @@ class HasElementsOp : public pir::Op { static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT - pir::Value stack); + pir::Value container); void VerifySig(); pir::Value input() { return operand_source(0); } pir::Value out() { return result(0); } diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index feb37821e58ef7..d0214ff71a9cb4 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -79,6 +79,7 @@ using paddle::dialect::DenseTensorArrayType; using paddle::dialect::DenseTensorType; using paddle::dialect::IfOp; using paddle::dialect::SelectedRowsType; +using paddle::dialect::WhileOp; using pir::Attribute; using pir::Block; @@ -497,7 +498,15 @@ void BindOperation(py::module *m) { self.ReplaceAllUsesWith(op_results); }) .def("as_if_op", - [](Operation &self) { return PyIfOp(self.dyn_cast()); }); + [](Operation &self) { return PyIfOp(self.dyn_cast()); }) + .def("as_while_op", [](Operation &self) -> WhileOp { + auto while_op = self.dyn_cast(); + if (!while_op) { + PADDLE_THROW(phi::errors::InvalidArgument( + "Can't cast non-while type Operation to WhileOp.")); + } + return while_op; + }); py::class_ block_container( *m, "Operation_BlockContainer", R"DOC( The Operation_BlockContainer only use to walk all blocks in the operation. 
diff --git a/paddle/pir/core/op_base.h b/paddle/pir/core/op_base.h index f0b0451f86461d..5b9d76aa9255ca 100644 --- a/paddle/pir/core/op_base.h +++ b/paddle/pir/core/op_base.h @@ -62,6 +62,9 @@ class IR_API OpBase { Value operand_source(uint32_t index) const { return operation()->operand_source(index); } + Type operand_type(uint32_t index) const { + return operation()->operand_type(index); + } OpResult result(uint32_t index) const { return operation()->result(index); } diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index 7f3c9e28932cd5..81d96174531438 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -100,6 +100,7 @@ class IR_API alignas(8) Operation final std::vector operands(); Value operand_source(uint32_t index) const; std::vector operands_source() const; + Type operand_type(uint32_t index) const { return operand(index).type(); } /// /// \brief op successor related public interfaces diff --git a/paddle/pir/dialect/control_flow/ir/cf_type.h b/paddle/pir/dialect/control_flow/ir/cf_type.h index 15e3b14280e272..47f52e4d2c4039 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_type.h +++ b/paddle/pir/dialect/control_flow/ir/cf_type.h @@ -21,6 +21,7 @@ namespace pir { class IR_API ContainerType : public Type { + public: using Type::Type; static bool classof(Type); }; diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index 71d7f55573f9f3..6218d90950bbf4 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -15,6 +15,11 @@ import unittest import paddle +from paddle.base.core import call_vjp, has_vjp +from paddle.base.libpaddle.pir import ( + build_pipe_for_block, + get_used_external_value, +) paddle.enable_static() @@ -51,6 +56,37 @@ def test_while_base(self): self.assertEqual(last_op.name(), "pd_op.while") self.assertEqual(len(out), 2) + def test_while_op_vjp_interface(self): + main_program = self.construct_program_with_while() + while_op = main_program.global_block().ops[-1] + self.assertEqual(while_op.name(), "pd_op.while") + build_pipe_for_block(while_op.as_while_op().body()) + with paddle.pir.core.program_guard(main_program): + out_grad = paddle.full(shape=[6, 1], dtype='float32', fill_value=3) + # check vjp interface for while_op + while_input = [ + [input] for input in get_used_external_value(while_op) + ] + self.assertEqual(len(while_input), 4) + while_input_stop_graditents = [[True], [False], [True], [True]] + while_output = [while_op.results()] + while_output_grad = [[out_grad, out_grad]] + self.assertEqual(has_vjp(while_op), True) + grad_outs = call_vjp( + while_op, + while_input, + while_output, + while_output_grad, + while_input_stop_graditents, + ) + + self.assertEqual(grad_outs[0][0], None) + + while_grad_op = grad_outs[1][0].get_defining_op() + self.assertEqual(while_grad_op.name(), "pd_op.while") + while_grad_output = while_grad_op.results() + self.assertEqual(len(while_grad_output), 2) + if __name__ == "__main__": unittest.main() From 7e177f6032da0de499ae6da697fca2fb685827a8 Mon Sep 17 00:00:00 2001 From: winter-wang <1030748926@qq.com> Date: Wed, 13 Dec 2023 11:27:24 +0000 Subject: [PATCH 03/21] [PIR] fix ci error. 
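
The while_grad op only creates results for loop variables whose stop_gradient
is False. The vjp test above marks a single loop variable as requiring a
gradient, so the expected number of while_grad results is 1 rather than 2.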
--- test/ir/pir/test_while_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index 6218d90950bbf4..fc119eeafc5c33 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -85,7 +85,7 @@ def test_while_op_vjp_interface(self): while_grad_op = grad_outs[1][0].get_defining_op() self.assertEqual(while_grad_op.name(), "pd_op.while") while_grad_output = while_grad_op.results() - self.assertEqual(len(while_grad_output), 2) + self.assertEqual(len(while_grad_output), 1) if __name__ == "__main__": From 11c8656d15819365c51157128a2adc67c51d7141 Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Thu, 14 Dec 2023 02:08:06 +0000 Subject: [PATCH 04/21] modify while stopgradient --- .../pir/dialect/operator/ir/control_flow_op.cc | 9 +++++++++ test/ir/pir/test_while_api.py | 16 +++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index dbb7c7c248dd48..74a4c99f21f3fe 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -251,10 +251,19 @@ void WhileOp::Build(pir::Builder &builder, // NOLINT argument.AddInput(cond); argument.AddInputs(inputs); auto &body = argument.AddRegion().emplace_back(); + std::vector outs_stop_gradient; for (auto val : inputs) { argument.AddOutput(val.type()); body.AddArgument(val.type()); + auto bool_attr = val.attribute(kStopGradientAttrName); + outs_stop_gradient.push_back(bool_attr ? bool_attr + : builder.bool_attr(false)); } + argument.AddAttribute( + kStopGradientAttrName, + pir::ArrayAttribute::get(builder.ir_context(), outs_stop_gradient)); + + cond.set_attribute(kStopGradientAttrName, builder.bool_attr(true)); } pir::Block &WhileOp::body() { pir::Region &body_region = (*this)->region(0); diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index 71d7f55573f9f3..51d48a6bdbde7f 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -15,6 +15,7 @@ import unittest import paddle +from paddle.autograd.ir_backward import grad paddle.enable_static() @@ -31,7 +32,7 @@ def body(i, ten): return [i, ten] -class TestBuildModuleWithIfOp(unittest.TestCase): +class TestBuildModuleWithWhileOp(unittest.TestCase): def construct_program_with_while(self): main_program = paddle.static.Program() with paddle.pir.core.program_guard(main_program): @@ -41,6 +42,7 @@ def construct_program_with_while(self): ten = paddle.full( shape=[1], fill_value=10, dtype='int64' ) # loop length + i.stop_gradient = False i, ten = paddle.static.nn.while_loop(cond, body, [i, ten]) return main_program @@ -48,9 +50,21 @@ def test_while_base(self): main_program = self.construct_program_with_while() last_op = main_program.global_block().ops[-1] out = last_op.results() + self.assertEqual(out.stop_gradient, False) self.assertEqual(last_op.name(), "pd_op.while") self.assertEqual(len(out), 2) + def test_while_base_backward(self): + main_program = self.construct_program_with_while() + full_op1 = main_program.global_block().ops[0] + while_op = main_program.global_block().ops[-1] + with paddle.pir.core.program_guard(main_program): + out = while_op.result(0) + 1 + grad_outs = grad( + out, + [full_op1.result(0)], + ) + if __name__ == "__main__": unittest.main() From da62e16ec64b2d801671359509c26021c4955938 Mon Sep 17 00:00:00 2001 From: 
"wangruting@baidu.com" Date: Fri, 15 Dec 2023 01:30:37 +0000 Subject: [PATCH 05/21] merge --- .../dialect/operator/ir/control_flow_op.cc | 35 +++++++++---------- python/paddle/autograd/ir_backward.py | 19 +++++++++- test/cpp/pir/pass/pass_manager_test.cc | 2 +- test/ir/pir/test_while_api.py | 25 ++++++------- 4 files changed, 48 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 8644cc3704b422..99c3df6ceffb81 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -48,6 +48,7 @@ void IfOp::Build(pir::Builder &builder, // NOLINT argument.output_types.swap(output_types); argument.AddRegion().emplace_back(); argument.AddRegion().emplace_back(); + cond.set_attribute(kStopGradientAttrName, builder.bool_attr(true)); } void IfOp::Build(pir::Builder &builder, // NOLINT @@ -256,8 +257,11 @@ void WhileOp::Build(pir::Builder &builder, // NOLINT std::vector outs_stop_gradient; for (auto val : inputs) { argument.AddOutput(val.type()); - body.AddArgument(val.type()); + auto arg = body.AddArgument(val.type()); + auto bool_attr = val.attribute(kStopGradientAttrName); + arg.set_attribute(kStopGradientAttrName, + bool_attr ? bool_attr : builder.bool_attr(false)); outs_stop_gradient.push_back(bool_attr ? bool_attr : builder.bool_attr(false)); } @@ -339,6 +343,14 @@ std::vector> WhileOp::Vjp( "the outputs size is %d.", inputs.size(), outputs.size())); + PADDLE_ENFORCE_EQ(inputs.size(), + out_grads.size() + 1, + phi::errors::InvalidArgument( + "while op's inputs' size should equal to " + "output_grads' size, Now the inputs's size is %d ." + "the output_grads size is %d.", + inputs.size(), + out_grads.size())); PADDLE_ENFORCE_EQ(stop_gradients[0][0], true, phi::errors::InvalidArgument( @@ -350,25 +362,12 @@ std::vector> WhileOp::Vjp( std::vector output_types; std::vector loop_vars; - for (size_t index = 0; index < inputs.size(); ++index) { + for (size_t index = 0; index < out_grads.size(); ++index) { if (!stop_gradients[index + 1][0]) { loop_vars.push_back(out_grads[index][0]); } } - // for (++index; index < inputs.size(); ++index) { - // if (!stop_gradients[index][0]) { - // auto fwd_type = inputs[index][0].type().dyn_cast(); - // PADDLE_ENFORCE_NE( - // fwd_type, - // pir::Type(), - // phi::errors::InvalidArgument( - // "The forward value type must be dense tensor type.")); - // auto shape = vectorize(fwd_type.dims()); - // auto dtype = TransToPhiDataType(fwd_type.dtype()); - // auto full_op = builder.Build(shape, 0.0, dtype, - // phi::CPUPlace()); loop_vars.push_back(full_op.out()); - // } - // } + auto while_grad = builder.Build(cond_val, loop_vars); std::vector> res(inputs.size()); @@ -397,9 +396,7 @@ std::vector> TuplePushOpVjpInterfaceModel::Vjp( res[0].resize(1); for (size_t i = 1u; i < inputs.size(); ++i) { res[i].resize(1); - if (!stop_gradients[i][0]) { - res[i][0] = pop_op.result(i - 1); - } + res[i][0] = pop_op.result(i - 1); } return res; } diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index a65a4b7a35f1ed..80a10a9343552c 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -449,6 +449,23 @@ def make_output_with_output_grad(op): outputs.append(new_value) output_grads.append(state.value_to_valuegrad[value][0]) + if op.name() == "pd_op.while": + for i, input in enumerate(get_real_op_inputs(op)): + if i <= len(op.results()): 
+ continue + if ( + input in state.value_to_valuegrad + and len(state.value_to_valuegrad[input]) > 1 + ): + append_add_n(input) + + if ( + input not in state.value_to_valuegrad + or state.value_to_valuegrad[input] == [] + ): + append_full_like(0.0, input, input, state, backward_ops) + output_grads.append(state.value_to_valuegrad[input][0]) + return zero_flag, outputs, output_grads def make_input_with_input_stopgradient(op): @@ -576,7 +593,7 @@ def append_yield(block, base_inputs, base_inputs_grad): # [op4] (op4's inputs and outputs are not vectorType) # -----------------only for control flow-----------------# - # tuple_push value to pop value + # tuple_push value to tuple_pop value control_flow_value_to_copyvalue_map = {} if ( diff --git a/test/cpp/pir/pass/pass_manager_test.cc b/test/cpp/pir/pass/pass_manager_test.cc index 7c00a5d24cb988..9b948032b38b7e 100644 --- a/test/cpp/pir/pass/pass_manager_test.cc +++ b/test/cpp/pir/pass/pass_manager_test.cc @@ -226,7 +226,7 @@ TEST(pass_manager, PassManager) { true, true)); - pm.EnablePassTiming(true); + // pm.EnablePassTiming(true); CHECK_EQ(pm.Run(&program), true); } diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index 9313a53b757998..3dab5f23ca015c 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -72,8 +72,8 @@ def test_while_op_vjp_interface(self): ] self.assertEqual(len(while_input), 4) while_input_stop_graditents = [[True], [False], [True], [True]] - while_output = [while_op.results()] - while_output_grad = [[out_grad, out_grad]] + while_output = [[value] for value in while_op.results()] + while_output_grad = [[out_grad], [out_grad], [out_grad]] self.assertEqual(has_vjp(while_op), True) grad_outs = call_vjp( while_op, @@ -90,16 +90,17 @@ def test_while_op_vjp_interface(self): while_grad_output = while_grad_op.results() self.assertEqual(len(while_grad_output), 1) - def test_while_base_backward(self): - main_program = self.construct_program_with_while() - full_op1 = main_program.global_block().ops[0] - while_op = main_program.global_block().ops[-1] - with paddle.pir.core.program_guard(main_program): - out = while_op.result(0) + 1 - grad_outs = grad( - out, - [full_op1.result(0)], - ) + def test_while_base_backward(self): + main_program = self.construct_program_with_while() + full_op1 = main_program.global_block().ops[0] + while_op = main_program.global_block().ops[-1] + with paddle.pir.core.program_guard(main_program): + out = while_op.result(0) + 1 + grad_outs = grad( + out, + [full_op1.result(0)], + ) + print(main_program) if __name__ == "__main__": From 30bba329c658e2ee730863fecd04a9177d30f717 Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Mon, 18 Dec 2023 01:52:59 +0000 Subject: [PATCH 06/21] modify while grad bug --- .../dialect/operator/ir/control_flow_op.cc | 4 ++ paddle/fluid/pybind/control_flow_api.cc | 28 ++++++++- paddle/fluid/pybind/pir.cc | 8 +++ python/paddle/autograd/ir_backward.py | 61 +++++++++++++------ python/paddle/static/nn/control_flow.py | 1 + test/ir/pir/test_while_api.py | 41 +++++++++++++ 6 files changed, 122 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index e12a95b7bf624d..d0df150e3ea8f8 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -406,6 +406,10 @@ void HasElementsOp::Build(pir::Builder &builder, // NOLINT argument.AddInput(container); 
argument.AddOutput( DenseTensorType::get(builder.ir_context(), builder.bool_type(), {1})); + std::vector outs_stop_gradient{builder.bool_attr(true)}; + argument.AddAttribute( + kStopGradientAttrName, + pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient)); } void HasElementsOp::VerifySig() { VLOG(4) << "Verifying inputs, outputs ,attributes for: HasElementsOp."; diff --git a/paddle/fluid/pybind/control_flow_api.cc b/paddle/fluid/pybind/control_flow_api.cc index d6be95421c234c..dd815f70332a25 100644 --- a/paddle/fluid/pybind/control_flow_api.cc +++ b/paddle/fluid/pybind/control_flow_api.cc @@ -35,6 +35,7 @@ namespace py = pybind11; using paddle::dialect::ApiBuilder; +using paddle::dialect::HasElementsOp; using paddle::dialect::IfOp; using paddle::dialect::WhileOp; using pir::Block; @@ -126,6 +127,30 @@ std::vector GetUsedExternalValue(const Operation& op) { return used_values; } +Value BuildHasElementsOp(Operation& fwd_op) { // NOLINT + PADDLE_ENFORCE(fwd_op.isa(), + phi::errors::PreconditionNotMet( + "param op of BuildHasElementsOp must be while op.")); + auto fwdop = fwd_op.dyn_cast(); + TuplePushOp push_op; + for (auto iter = fwdop.body().rbegin(); iter != fwdop.body().rend(); ++iter) { + if (iter->isa()) { + push_op = iter->dyn_cast(); + PADDLE_ENFORCE_EQ(push_op.container().use_empty(), + false, + phi::errors::InvalidArgument( + "The last container in foward while op must used " + "after construct while_grad op")); + break; + } + } + auto new_cond = ApiBuilder::Instance() + .GetBuilder() + ->Build(push_op.container()) + .out(); + return new_cond; +} + void BuildPipeForBlock(Block* block) { PADDLE_ENFORCE_NOT_NULL( block, @@ -193,6 +218,7 @@ void PyIfOp::UpdateOutput() { void BindControlFlowApi(py::module* m) { m->def("get_used_external_value", GetUsedExternalValue); m->def("build_pipe_for_block", BuildPipeForBlock); + m->def("cf_has_elements", BuildHasElementsOp); m->def("cf_yield", [](py::list inputs) { std::vector input_values; for (auto input : inputs) { @@ -200,9 +226,9 @@ void BindControlFlowApi(py::module* m) { } ApiBuilder::Instance().GetBuilder()->Build(input_values); }); - BindIfOp(m); BindWhileOp(m); } + } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index f0ddaac7a0db58..06f1ad4688e71c 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -697,6 +697,14 @@ void BindValue(py::module *m) { kAttrIsPersisable, BoolAttribute::get(pir::IrContext::Instance(), persistable)); }) + .def("all_used_ops", + [](Value &self) -> py::list { + py::list op_list; + for (auto it = self.use_begin(); it != self.use_end(); ++it) { + op_list.append(it.owner()); + } + return op_list; + }) .def( "get_defining_op", [](Value self) -> pir::Operation * { diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index 80a10a9343552c..c65009dd97b0f2 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -81,6 +81,9 @@ def get_real_op_inputs(op): def update_no_grad_set_by_stopgradient(block, no_grad_set): for op in block.ops: + # if op.name() in ["pd_op.if" , "pd_op.while"]: + # for sub_block in op.blocks(): + # update_no_grad_set_by_stopgradient(sub_block, no_grad_set) for value in op.results(): if value.stop_gradient and value not in no_grad_set: no_grad_set.add(value) @@ -137,6 +140,7 @@ def prepare_grad_outputs(grad_outputs, outputs, state): [feedop], ) state.value_to_valuegrad[output] = [[grad]] + # add input for bwd first 
op complete_outputs = outputs complete_gradoutputs = grad_outputs @@ -255,6 +259,12 @@ def update_no_grad_set_after_prune( from outputs to inputs add value not in the path to no_grad_set, ''' inputs_set = set(inputs) + for input in inputs: + if not input.use_empty(): + for used_op in input.all_used_ops(): + for item in get_real_op_inputs(used_op): + inputs_set.add(item) + if inputs_set: for op in block.ops: if some_in_set(get_real_op_inputs(op), inputs_set): @@ -465,12 +475,9 @@ def make_output_with_output_grad(op): ): append_full_like(0.0, input, input, state, backward_ops) output_grads.append(state.value_to_valuegrad[input][0]) - return zero_flag, outputs, output_grads - def make_input_with_input_stopgradient(op): - inputs = [] - input_grad_stopgradients = [] + def get_grad_semantic_info(op): if op.name() in [ "builtin.combine", "pd_op.if", @@ -482,9 +489,13 @@ def make_input_with_input_stopgradient(op): ] else: grad_semantic_info = op.get_input_grad_semantics() + return grad_semantic_info + def make_input_with_input_stopgradient(op): + inputs = [] + input_grad_stopgradients = [] for input, grad_semantic in zip( - get_real_op_inputs(op), grad_semantic_info + get_real_op_inputs(op), get_grad_semantic_info(op) ): if not grad_semantic: if ( @@ -528,7 +539,8 @@ def make_input_with_input_stopgradient(op): else [input] ) inputs.append(tmp_input) - if input.get_defining_op() is None or input in no_grad_set: + + if input in no_grad_set or input.stop_gradient is True: input_grad_stopgradients.append([True]) else: input_grad_stopgradients.append([False]) @@ -537,15 +549,9 @@ def make_input_with_input_stopgradient(op): def update_input_grad_map(op, input_grads, origin_inputs): i = 0 - if ( - op.name() == "builtin.combine" - or op.name() == "pd_op.if" - or op.name() == "pd_op.while" + for input, grad_semantic in zip( + origin_inputs, get_grad_semantic_info(op) ): - grad_semantic_info = [True for _ in range(len(origin_inputs))] - else: - grad_semantic_info = op.get_input_grad_semantics() - for input, grad_semantic in zip(origin_inputs, grad_semantic_info): if not grad_semantic: continue if ( @@ -565,9 +571,13 @@ def update_input_grad_map(op, input_grads, origin_inputs): state.value_to_valuegrad[input].append([input_grad]) i += 1 - def append_yield(block, base_inputs, base_inputs_grad): + def append_yield(block, base_op, base_inputs, base_inputs_grad): with block: inputs_grad = [] + if base_op.name() == "pd_op.while": + new_cond = paddle.base.libpaddle.pir.cf_has_elements(base_op) + inputs_grad.append(new_cond) + for value, value_grad in zip(base_inputs, base_inputs_grad): if value_grad is None: continue @@ -595,14 +605,22 @@ def append_yield(block, base_inputs, base_inputs_grad): # -----------------only for control flow-----------------# # tuple_push value to tuple_pop value control_flow_value_to_copyvalue_map = {} + control_flow_copyvalue_to_value_map = {} if ( len(effective_forward_ops) > 1 and effective_forward_ops[-1].name() == "cf.yield" ): yield_op = effective_forward_ops[-1] + if base_op.name() == "pd_op.while": + # while op yield [cond, loop_vars], + # but outputs only has loop_vars. 
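+                # dropping operand 0 (the next-iteration cond) keeps the yielded
+                # values aligned one-to-one with base_op.results()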
+ inside_outputs = yield_op.operands_source()[1:] + else: + inside_outputs = yield_op.operands_source() + for outside_output, inside_output in zip( - base_op.results(), yield_op.operands_source() + base_op.results(), inside_outputs ): state.inside_value_to_outside_value_map[ inside_output @@ -645,7 +663,9 @@ def append_yield(block, base_inputs, base_inputs_grad): control_flow_value_to_copyvalue_map[ output[0] ] = copy_output[0] - + control_flow_copyvalue_to_value_map[ + copy_output[0] + ] = output[0] else: # all(zero_flag) support this op has no contribution for grad # should be delete (prune sub_graph) @@ -725,15 +745,16 @@ def append_yield(block, base_inputs, base_inputs_grad): state.op_to_opgrad[op] = [] if fwd_block != bwd_block: - append_yield(bwd_block, base_inputs, base_input_grads) + append_yield(bwd_block, base_op, base_inputs, base_input_grads) def prepare_backward_prune_set(inputs, outputs): outputs_fwd_set = set() for input_ in inputs: if not input_.use_empty(): - for item in get_real_op_inputs(input_.first_use().owner()): - outputs_fwd_set.add(item) + for used_op in input_.all_used_ops(): + for item in get_real_op_inputs(used_op): + outputs_fwd_set.add(item) else: logging.warning("input privided by inputs has no use") diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index 1ddfce6205cce4..851125e004feea 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -671,6 +671,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): args = cur_block.args() next_var = body(*args) next_cond = cond(*next_var) + next_cond.stop_gradient = True cf_yield([next_cond, *next_var]) return while_op.as_operation().results() diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index 3dab5f23ca015c..fe447fbde3a342 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -102,6 +102,47 @@ def test_while_base_backward(self): ) print(main_program) + self.assertEqual( + grad_outs[0].get_defining_op().name(), "pd_op.while" + ) + + +def cond2(i, j, ten): + return i < ten + + +def body2(i, j, ten): + i = i + j + return [i, j, ten] + + +class TestBuildModuleWithWhile2Op(unittest.TestCase): + def test_add_n_program(self): + main_program = paddle.static.Program() + with paddle.pir.core.program_guard(main_program): + i = paddle.full( + shape=[1], fill_value=0, dtype='int64' + ) # loop counter + j = paddle.full( + shape=[1], fill_value=2, dtype='int64' + ) # loop counter + ten = paddle.full( + shape=[1], fill_value=10, dtype='int64' + ) # loop length + i.stop_gradient = False + j.stop_gradient = False + i_, j_, ten_ = paddle.static.nn.while_loop( + cond2, body2, [i, j, ten] + ) + out = i_ - j_ + + grad_outs = grad( + out, + [i, j], + ) + + print(main_program) + if __name__ == "__main__": unittest.main() From fde161cde96e4b913fd3437f7f2eaff454902756 Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Mon, 18 Dec 2023 06:45:48 +0000 Subject: [PATCH 07/21] modify while grad op --- .../dialect/operator/ir/control_flow_op.cc | 1 + .../pir/dialect/operator/ir/control_flow_op.h | 2 + paddle/fluid/pybind/control_flow_api.cc | 5 +- paddle/fluid/pybind/pir.cc | 6 +- python/paddle/autograd/ir_backward.py | 62 ++++++++++++++++--- test/ir/pir/test_while_api.py | 22 ++++++- 6 files changed, 88 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 
6b81f786f1eb17..34ba4ac806e41e 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -313,6 +313,7 @@ pir::Block &WhileOp::body() { if (body_region.empty()) body_region.emplace_back(); return body_region.front(); } + pir::Value WhileOp::cond() { return (*this)->operand_source(0); } void WhileOp::Print(pir::IrPrinter &printer) { diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h index 231da23e289095..96acf0dc68b30b 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -16,6 +16,7 @@ #include #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" +#include "paddle/pir/core/block.h" #include "paddle/pir/core/op_base.h" namespace paddle { @@ -78,6 +79,7 @@ class WhileOp : public pir::Op { const std::vector &inputs); pir::Block &body(); pir::Value cond(); + const pir::Block::ArgListType &block_args() { return body().args(); } void Print(pir::IrPrinter &printer); // NOLINT void VerifySig() {} void VerifyRegion() {} diff --git a/paddle/fluid/pybind/control_flow_api.cc b/paddle/fluid/pybind/control_flow_api.cc index dd815f70332a25..f4abefb729f650 100644 --- a/paddle/fluid/pybind/control_flow_api.cc +++ b/paddle/fluid/pybind/control_flow_api.cc @@ -89,7 +89,10 @@ void BindWhileOp(py::module* m) { WhileOp in python api. )DOC"); while_op.def("body", &WhileOp::body, return_value_policy::reference) - .def("as_operation", &WhileOp::operation, return_value_policy::reference); + .def("as_operation", &WhileOp::operation, return_value_policy::reference) + .def("block_arguments", + &WhileOp::block_args, + return_value_policy::reference); } void GetUsedExternalValueImpl( diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index bbc92dccfb7c04..c93d69b1781cd3 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -304,12 +304,16 @@ void BindBlock(py::module *m) { "front", [](Block &self) { return &self.front(); }, return_value_policy::reference) + .def_property_readonly( + "parent_op", + [](Block &self) { return self.GetParentOp(); }, + return_value_policy::reference) .def_property_readonly( "program", [](Block &self) { return self.GetParentOp()->GetParentProgram(); }, return_value_policy::reference) .def_property_readonly( - "get_parent", + "parent_block", [](Block &self) { return self.GetParentOp()->GetParent(); }, return_value_policy::reference) .def_property_readonly("ops", diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index b2a593e1310e85..6c153bba9b92a4 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -494,6 +494,8 @@ def get_grad_semantic_info(op): def make_input_with_input_stopgradient(op): inputs = [] input_grad_stopgradients = [] + # if op.name() == "pd_op.add": + # breakpoint() for input, grad_semantic in zip( get_real_op_inputs(op), get_grad_semantic_info(op) ): @@ -565,20 +567,37 @@ def update_input_grad_map(op, input_grads, origin_inputs): ) else: input_grad = input_grads[i] + if input in block_argument_to_value_map: + input = block_argument_to_value_map[input] + if isinstance(input_grad, list): state.value_to_valuegrad[input].append(input_grad) else: state.value_to_valuegrad[input].append([input_grad]) i += 1 - def append_yield(block, base_op, base_inputs, base_inputs_grad): + def append_yield( + block, base_op, base_grad_op, 
base_inputs, base_inputs_grad + ): with block: inputs_grad = [] if base_op.name() == "pd_op.while": new_cond = paddle.base.libpaddle.pir.cf_has_elements(base_op) inputs_grad.append(new_cond) - for value, value_grad in zip(base_inputs, base_inputs_grad): + output_grads = base_grad_op.operands_source() + # output_grad = [new_cond, loop_vars(fwd_output_grad)] + # base_inputs = [cond, loop_vars(fwd_input)] + assert len(output_grads) <= len( + base_inputs + ), "while op's inputs size should less than while_grad op's inputs size" + + else: + output_grads = [None] * len(base_inputs) + + for value, value_grad, output_grad in zip( + base_inputs, base_inputs_grad, output_grads + ): if value_grad is None: continue @@ -588,14 +607,34 @@ def append_yield(block, base_op, base_inputs, base_inputs_grad): if value in state.value_to_valuegrad: if len(state.value_to_valuegrad[value]) > 1: append_add_n(value) - inputs_grad.append(state.value_to_valuegrad[value][0][0]) else: value_grad = append_full_like( 0.0, value, value, state, backward_ops ) - inputs_grad.append(value_grad) + + if base_op.name() == "pd_op.while": + input_grad = paddle.add( + output_grad, state.value_to_valuegrad[value][0][0] + ) + else: + input_grad = state.value_to_valuegrad[value][0][0] + + inputs_grad.append(input_grad) + paddle.base.libpaddle.pir.cf_yield(inputs_grad) + def argument_to_value(while_op): + assert len(while_op.as_while_op().block_arguments()) + 1 == len( + while_op.operands_source() + ), "while op's block_arguments size + 1 should same to whiel op's operands_source" + map = {} + for arg, value in zip( + while_op.as_while_op().block_arguments(), + while_op.operands_source()[1:], + ): + map[arg] = value + return map + # there are four patterns: # [builtin.combine , op1] (op1's one input is vectorType, outputs are not vectorType) # [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType) @@ -606,6 +645,7 @@ def append_yield(block, base_op, base_inputs, base_inputs_grad): # tuple_push value to tuple_pop value control_flow_value_to_copyvalue_map = {} control_flow_copyvalue_to_value_map = {} + block_argument_to_value_map = {} if ( len(effective_forward_ops) > 1 @@ -616,6 +656,7 @@ def append_yield(block, base_op, base_inputs, base_inputs_grad): # while op yield [cond, loop_vars], # but outputs only has loop_vars. 
inside_outputs = yield_op.operands_source()[1:] + block_argument_to_value_map = argument_to_value(base_op) else: inside_outputs = yield_op.operands_source() @@ -677,6 +718,7 @@ def append_yield(block, base_op, base_inputs, base_inputs_grad): for sub_block in op.blocks(): build_pipe_for_block(sub_block) with dynamic_shape_prim_vjp_guard(op, inputs): + # breakpoint() input_grads = paddle.framework.core.call_vjp( op, inputs, @@ -745,7 +787,13 @@ def append_yield(block, base_op, base_inputs, base_inputs_grad): state.op_to_opgrad[op] = [] if fwd_block != bwd_block: - append_yield(bwd_block, base_op, base_inputs, base_input_grads) + append_yield( + bwd_block, + base_op, + bwd_block.parent_op, + base_inputs, + base_input_grads, + ) def prepare_backward_prune_set(inputs, outputs): @@ -830,8 +878,8 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): inputs_set = set(inputs) outputs_set = set(complete_outputs) total_ops = [] - if block.get_parent is not None: - total_ops += block.get_parent.ops + if block.parent_block is not None: + total_ops += block.parent_block.ops total_ops += block.ops effective_forward_ops, _ = prune_ops( diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index fe447fbde3a342..cd4ab6c0be8e73 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -141,7 +141,27 @@ def test_add_n_program(self): [i, j], ) - print(main_program) + self.assertEqual( + grad_outs[0].get_defining_op().name(), "pd_op.while" + ) + self.assertEqual( + main_program.global_block() + .ops[-1] + .as_while_op() + .body() + .ops[-2] + .name(), + "pd_op.add", + ) + self.assertEqual( + main_program.global_block() + .ops[-1] + .as_while_op() + .body() + .ops[-4] + .name(), + "cf.has_elements", + ) if __name__ == "__main__": From fdc12c7322038cd44780c0cdf206a016f2265846 Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Mon, 18 Dec 2023 07:27:14 +0000 Subject: [PATCH 08/21] modify --- paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc | 2 +- python/paddle/autograd/ir_backward.py | 9 +++------ test/ir/pir/test_while_api.py | 1 - 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 34ba4ac806e41e..f4be79bc1c6b8b 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -385,7 +385,7 @@ std::vector> WhileOp::Vjp( out_grads.size() + 1, phi::errors::InvalidArgument( "while op's inputs' size should equal to " - "output_grads' size, Now the inputs's size is %d ." + "output_grads' size + 1, Now the inputs's size is %d ." 
"the output_grads size is %d.", inputs.size(), out_grads.size())); diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index 6c153bba9b92a4..bf23988ebf7842 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -81,9 +81,9 @@ def get_real_op_inputs(op): def update_no_grad_set_by_stopgradient(block, no_grad_set): for op in block.ops: - # if op.name() in ["pd_op.if" , "pd_op.while"]: - # for sub_block in op.blocks(): - # update_no_grad_set_by_stopgradient(sub_block, no_grad_set) + if op.name() in ["pd_op.if", "pd_op.while"]: + for sub_block in op.blocks(): + update_no_grad_set_by_stopgradient(sub_block, no_grad_set) for value in op.results(): if value.stop_gradient and value not in no_grad_set: no_grad_set.add(value) @@ -494,8 +494,6 @@ def get_grad_semantic_info(op): def make_input_with_input_stopgradient(op): inputs = [] input_grad_stopgradients = [] - # if op.name() == "pd_op.add": - # breakpoint() for input, grad_semantic in zip( get_real_op_inputs(op), get_grad_semantic_info(op) ): @@ -718,7 +716,6 @@ def argument_to_value(while_op): for sub_block in op.blocks(): build_pipe_for_block(sub_block) with dynamic_shape_prim_vjp_guard(op, inputs): - # breakpoint() input_grads = paddle.framework.core.call_vjp( op, inputs, diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index cd4ab6c0be8e73..eb1f9d3381ed53 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -100,7 +100,6 @@ def test_while_base_backward(self): out, [full_op1.result(0)], ) - print(main_program) self.assertEqual( grad_outs[0].get_defining_op().name(), "pd_op.while" From e3d19b99bffef630b39f1571e12dab28c94d0b4c Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Tue, 19 Dec 2023 02:21:36 +0000 Subject: [PATCH 09/21] increment vp --- .../fluid/pir/dialect/op_generator/op_gen.py | 2 + .../pir/dialect/operator/ir/manual_op.cc | 244 ++++++++++++++++++ .../fluid/pir/dialect/operator/ir/manual_op.h | 85 ++++++ .../pir/dialect/operator/ir/manual_op_vjp.cc | 65 +++++ test/legacy_test/test_while_loop_op.py | 104 +++++--- test/legacy_test/test_while_op.py | 50 ++-- 6 files changed, 488 insertions(+), 62 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 0d83da9bb74239..584ca52ca2dba5 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -238,6 +238,8 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ 'add_n_with_kernel', 'split_grad', 'expand', + 'increment', + 'increment_', } attr_types_map = { diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 2160e56442d465..7fe6329d37aca7 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -2270,6 +2270,248 @@ void SelectInputOp::VerifySig() { VLOG(4) << "End Verifying for: AssignArray_Op."; } +const char *IncrementOp::attributes_name[1] = {"value"}; + +OpInfoTuple IncrementOp::GetOpInfo() { + std::vector inputs = { + paddle::dialect::OpInputInfo( + "x", "paddle::dialect::DenseTensorType", false, false, false, false)}; + std::vector attributes = { + paddle::dialect::OpAttributeInfo("value", "pir::FloatAttribute", "")}; + std::vector outputs = { + paddle::dialect::OpOutputInfo( + "out", "paddle::dialect::DenseTensorType", false, false)}; + 
paddle::dialect::OpRunTimeInfo run_time_info = + paddle::dialect::OpRunTimeInfo("IncrementInferMeta", + {"x", "value"}, + "increment", + {"x", "value"}, + {}, + {}, + {}, + {}); + return std::make_tuple( + inputs, attributes, outputs, run_time_info, "increment"); +} + +void IncrementOp::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::Value x_, + float value) { + VLOG(4) << "Start build IncrementOp"; + + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {x_}; + argument.AddInputs(argument_inputs); + + VLOG(4) << "Builder construction attributes"; + pir::Attribute attr_value = + pir::FloatAttribute::get(pir::IrContext::Instance(), value); + argument.AddAttribute("value", attr_value); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x = + x_.type().dyn_cast(); + (void)x; + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::IncrementInferMeta(meta_x, value, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + +void IncrementOp::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::Value x_, + pir::AttributeMap attributes) { + VLOG(4) << "Start build IncrementOp"; + + IR_ENFORCE(attributes.find("value") != attributes.end(), + "'value' Attribute is expected for IncrementOp. 
"); + float value = attributes.at("value").dyn_cast().data(); + + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {x_}; + argument.AddInputs(argument_inputs); + + VLOG(4) << "Builder construction attributes"; + pir::Attribute attr_value = + pir::FloatAttribute::get(pir::IrContext::Instance(), value); + argument.AddAttribute("value", attr_value); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x = + x_.type().dyn_cast(); + (void)x; + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::IncrementInferMeta(meta_x, value, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + +void IncrementOp::VerifySig() { + VLOG(4) << "Start Verifying inputs, outputs and attributes for: IncrementOp."; + VLOG(4) << "Verifying inputs:"; + { + auto input_size = num_operands(); + IR_ENFORCE(input_size == 1u, + "The size %d of inputs must be equal to 1.", + input_size); + IR_ENFORCE((*this) + ->operand_source(0) + .type() + .isa(), + "Type validation failed for the 0th input, got %s.", + (*this)->operand_source(0).type()); + } + VLOG(4) << "Verifying attributes:"; + { + auto &attributes = this->attributes(); + IR_ENFORCE(attributes.count("value") > 0, "value does not exist."); + IR_ENFORCE(attributes.at("value").isa(), + "Type of attribute: value is not pir::FloatAttribute."); + } + VLOG(4) << "Verifying outputs:"; + { + auto output_size = num_results(); + IR_ENFORCE(output_size == 1u, + "The size %d of outputs must be equal to 1.", + output_size); + IR_ENFORCE( + (*this)->result(0).type().isa(), + "Type validation failed for the 0th output."); + } + VLOG(4) << "End Verifying for: IncrementOp."; +} + +void IncrementOp::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(phi::IncrementInferMeta); + fn(infer_meta); +} + +phi::DataType IncrementOp::GetKernelTypeForVar( + const std::string &var_name, + const phi::DataType &tensor_dtype, + const phi::DataType &expected_kernel_dtype) { + VLOG(4) << "Get KernelType for Var of op: IncrementOp"; + + return expected_kernel_dtype; +} + +const char *Increment_Op::attributes_name[1] = {"value"}; + +OpInfoTuple Increment_Op::GetOpInfo() { + std::vector inputs = { + paddle::dialect::OpInputInfo( + "x", "paddle::dialect::DenseTensorType", false, false, false, false)}; + std::vector attributes = { + paddle::dialect::OpAttributeInfo("value", "pir::FloatAttribute", "")}; + std::vector outputs = { + paddle::dialect::OpOutputInfo( + "out", "paddle::dialect::DenseTensorType", false, false)}; + paddle::dialect::OpRunTimeInfo run_time_info = + paddle::dialect::OpRunTimeInfo("IncrementInferMeta", + {"x", "value"}, + "increment", + {"x", "value"}, + {}, + {}, + {{"out", "x"}}, + {}); + return std::make_tuple( + inputs, attributes, outputs, run_time_info, "increment"); 
+} + +void Increment_Op::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::Value x_, + float value) { + VLOG(4) << "Start build Increment_Op"; + + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {x_}; + argument.AddInputs(argument_inputs); + + VLOG(4) << "Builder construction attributes"; + pir::Attribute attr_value = + pir::FloatAttribute::get(pir::IrContext::Instance(), value); + argument.AddAttribute("value", attr_value); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x = + x_.type().dyn_cast(); + (void)x; + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::IncrementInferMeta(meta_x, value, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + } // namespace dialect } // namespace paddle @@ -2289,4 +2531,6 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op) #endif diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 460356039d84ab..50c2f2649d12c6 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -412,6 +412,89 @@ class SelectInputOp : public pir::Op { pir::OpResult out() { return result(0); } }; +class IncrementOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.increment"; } + static const char *attributes_name[1]; + static constexpr uint32_t attributes_num = 1; + static OpInfoTuple GetOpInfo(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value x_, + float value = 1.0); + + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value x_, + pir::AttributeMap attributes); + + void VerifySig(); + + static phi::DataType GetKernelTypeForVar( + const std::string &var_name, + const phi::DataType &tensor_dtype, + const phi::DataType &expected_kernel_dtype); + + pir::Value x() { return operand_source(0); } + pir::OpResult out() { return result(0); } + + static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector> Vjp( + pir::Operation *op, + const std::vector> &inputs_, + const std::vector> &outputs, + const std::vector> &out_grads, + const std::vector> &stop_gradients); +}; + +class Increment_Op + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.increment_"; } + static const char *attributes_name[1]; + static constexpr uint32_t 
attributes_num = 1; + static OpInfoTuple GetOpInfo(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value x_, + float value = 1.0); + + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value x_, + pir::AttributeMap attributes); + + void VerifySig(); + + static phi::DataType GetKernelTypeForVar( + const std::string &var_name, + const phi::DataType &tensor_dtype, + const phi::DataType &expected_kernel_dtype); + + pir::Value x() { return operand_source(0); } + pir::OpResult out() { return result(0); } + + static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector> Vjp( + pir::Operation *op, + const std::vector> &inputs_, + const std::vector> &outputs, + const std::vector> &out_grads, + const std::vector> &stop_gradients); +}; + } // namespace dialect } // namespace paddle @@ -431,3 +514,5 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc index a37cbd681d185d..436d138a891605 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" @@ -127,5 +128,69 @@ std::vector> ExpandOp::Vjp( return res; } +std::vector> IncrementOp::Vjp( + pir::Operation* op, + const std::vector>& inputs_, + const std::vector>& outputs, + const std::vector>& out_grads, + const std::vector>& stop_gradients) { + PADDLE_ENFORCE_EQ( + inputs_.size(), + 1, + platform::errors::InvalidArgument( + "Increment op's inputs size should be 2, but now is %d.", + inputs_.size())); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + platform::errors::InvalidArgument( + "Increment op's outputs size should be 1, but now is %d.", + outputs.size())); + + VLOG(6) << "Vjp prepare Prepare attributes of increment_grad"; + + float value = op->attribute("value").dyn_cast().data(); + + VLOG(6) << "Vjp prepare call increment's vjp inteface"; + + pir::OpResult tensor_res = paddle::dialect::increment(inputs_[0][0], -value); + + std::vector> res{{tensor_res}}; + + return res; +} + +std::vector> Increment_Op::Vjp( + pir::Operation* op, + const std::vector>& inputs_, + const std::vector>& outputs, + const std::vector>& out_grads, + const std::vector>& stop_gradients) { + PADDLE_ENFORCE_EQ( + inputs_.size(), + 1, + platform::errors::InvalidArgument( + "Increment_ op's inputs size should be 2, but now is %d.", + inputs_.size())); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + platform::errors::InvalidArgument( + "Increment_ op's outputs size should be 1, but now is %d.", + outputs.size())); + + VLOG(6) << "Vjp prepare Prepare attributes of increment__grad"; + + float value = op->attribute("value").dyn_cast().data(); + + VLOG(6) << "Vjp prepare call increment_'s vjp inteface"; + + pir::OpResult tensor_res = 
paddle::dialect::increment_(inputs_[0][0], -value); + + std::vector> res{{tensor_res}}; + + return res; +} + } // namespace dialect } // namespace paddle diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index 05cfdc53cfb9be..91e28ff9fd6a30 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -22,7 +22,8 @@ from paddle import base from paddle.base import core from paddle.base.backward import append_backward -from paddle.base.framework import Program, program_guard +from paddle.base.framework import program_guard +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -36,8 +37,8 @@ def cond(i): def body(i): return paddle.add(x=i, y=one) - main_program = Program() - startup_program = Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() with program_guard(main_program, startup_program): i = paddle.tensor.fill_constant(shape=[1], dtype='int64', value=0) one = paddle.tensor.fill_constant(shape=[1], dtype='int64', value=1) @@ -67,8 +68,8 @@ def body(i, mem): i = paddle.increment(i) return [i, mem] - main_program = Program() - startup_program = Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() with program_guard(main_program, startup_program): i = paddle.zeros(shape=[1], dtype='int64') ten = paddle.tensor.fill_constant( @@ -113,8 +114,8 @@ def body(i, ten, test_dict, test_list, test_list_dict): i = paddle.increment(i) return [i, ten, test_dict, test_list, test_list_dict] - main_program = Program() - startup_program = Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() with program_guard(main_program, startup_program): i = paddle.zeros(shape=[1], dtype='int64') ten = paddle.tensor.fill_constant( @@ -204,8 +205,8 @@ def internal_body(j, init, sums): i = paddle.increment(i) return [i, j, init, sums] - main_program = Program() - startup_program = Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() with program_guard(main_program, startup_program): i = paddle.zeros(shape=[1], dtype='int64') j = paddle.zeros(shape=[1], dtype='int64') @@ -251,6 +252,7 @@ def internal_body(j, init, sums): class TestApiWhileLoop_Backward(unittest.TestCase): # TODO(zhangbo): Support while grad exe for pir + @test_with_pir_api def test_while_loop_backward(self): def cond(i, x): return paddle.less_than(i, eleven) @@ -260,9 +262,9 @@ def body(i, x): i = paddle.increment(i) return [i, x] - main_program = Program() - startup_program = Program() - with base.program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): i = paddle.static.data(name='i', shape=[1], dtype='float32') i.stop_gradient = False eleven = paddle.tensor.fill_constant( @@ -276,7 +278,8 @@ def body(i, x): out = paddle.static.nn.while_loop(cond, body, [i, x]) mean = paddle.mean(out[1]) - append_backward(mean) + grad_list = append_backward(mean) + print(main_program) place = ( base.CUDAPlace(0) @@ -290,16 +293,27 @@ def body(i, x): data = np.asarray([100]).astype('float32') i_grad = np.asarray([110]).astype('float32') - res = exe.run( - main_program, - feed={'i': feed_i, 'x': feed_x}, - fetch_list=[mean.name, i.grad_name], - ) + if paddle.framework.in_pir_mode(): + for p, g in grad_list: + if p == i: + di = g + res = exe.run( + main_program, + 
feed={'i': feed_i, 'x': feed_x}, + fetch_list=[mean, di], + ) + else: + res = exe.run( + main_program, + feed={'i': feed_i, 'x': feed_x}, + fetch_list=[mean.name, i.grad_name, x.grad_name], + ) np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05) np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05) + print("res[2]: ", res[2]) # TODO(zhangbo): Support while grad exe for pir - def test_while_loop_backward2(self): + def _test_while_loop_backward2(self): def cond(i, x): return i < 3 @@ -308,9 +322,9 @@ def body(i, x): i = i + 1 return [i, x] - main_program = Program() - startup_program = Program() - with base.program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): i = paddle.static.data(name='i', shape=[1], dtype='float32') i.stop_gradient = False x = paddle.static.data(name='x', shape=[1], dtype='float32') @@ -318,7 +332,7 @@ def body(i, x): out = paddle.static.nn.while_loop(cond, body, [i, x]) mean = paddle.mean(out[1]) - append_backward(mean) + grad_list = append_backward(mean) place = ( base.CUDAPlace(0) @@ -333,11 +347,23 @@ def body(i, x): i_grad = np.asarray([3]).astype('float32') x_grad = np.asarray([2]).astype('float32') - res = exe.run( - main_program, - feed={'i': feed_i, 'x': feed_x}, - fetch_list=[mean.name, i.grad_name, x.grad_name], - ) + if paddle.framework.in_pir_mode(): + for p, g in grad_list: + if p == i: + di = g + if p == x: + dx = g + res = exe.run( + main_program, + feed={'i': feed_i, 'x': feed_x}, + fetch_list=[mean, di, dx], + ) + else: + res = exe.run( + main_program, + feed={'i': feed_i, 'x': feed_x}, + fetch_list=[mean.name, i.grad_name, x.grad_name], + ) np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05) np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05) np.testing.assert_allclose(np.asarray(res[2]), x_grad, rtol=1e-05) @@ -345,6 +371,7 @@ def body(i, x): class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase): # TODO(zhangbo): Support while grad exe for pir + @test_with_pir_api def test_nested_net_with_backward_and_lodtensor(self): def external_cond(i, j, x, mem_array): return paddle.less_than(i, array_len) @@ -373,9 +400,9 @@ def internal_body(j, x, mem_array): ) return [i, j, x, mem_array] - main_program = Program() - startup_program = Program() - with base.program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): d0 = paddle.static.data(name='d0', shape=[10], dtype='float32') d1 = paddle.static.data(name='d1', shape=[10], dtype='float32') d2 = paddle.static.data(name='d2', shape=[10], dtype='float32') @@ -409,6 +436,7 @@ def internal_body(j, x, mem_array): sum_result = paddle.tensor.array_read(array=mem_array, i=j) mean = paddle.mean(sum_result) append_backward(mean) + print(main_program) place = ( base.CUDAPlace(0) @@ -457,9 +485,9 @@ def fn_add_one(): default=fn_add_one, ) - main_program = Program() - startup_program = Program() - with base.program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): i = paddle.tensor.fill_constant(shape=[1], dtype='int64', value=1) ten = paddle.tensor.fill_constant( shape=[1], dtype='int64', value=10 @@ -528,8 +556,8 @@ def 
body_returns_with_mutable_list(i, test_list): ) return paddle.increment(i), test_list - main_program = Program() - startup_program = Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() with program_guard(main_program, startup_program): data = paddle.tensor.fill_constant( shape=[1], dtype='int64', value=1 @@ -662,8 +690,8 @@ def body(z, i): i += 1 return z, i - main_program = Program() - startup_program = Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() with program_guard(main_program, startup_program): x = paddle.static.data(name='x', shape=[-1, 5], dtype='int32') z = paddle.tensor.fill_constant([], 'int32', 0) diff --git a/test/legacy_test/test_while_op.py b/test/legacy_test/test_while_op.py index 5ff7698b6b2bc1..6cf13963ec4337 100644 --- a/test/legacy_test/test_while_op.py +++ b/test/legacy_test/test_while_op.py @@ -23,6 +23,7 @@ from paddle.base.backward import append_backward from paddle.base.executor import Executor from paddle.incubate.layers.nn import shuffle_batch +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -57,7 +58,7 @@ def simple_net(self): cond2 = paddle.less_than(x=j, y=array_len2) while_op = paddle.static.nn.control_flow.While(cond=cond) while_op2 = paddle.static.nn.control_flow.While(cond=cond2) - with while_op.block(): + with while_op.body(): d = paddle.tensor.array_read(array=data_array, i=i) prev = paddle.tensor.array_read(array=mem_array, i=i) result = paddle.add_n([d, prev]) @@ -65,7 +66,7 @@ def simple_net(self): i = paddle.increment(x=i) paddle.tensor.array_write(result, i=i, array=mem_array) - with while_op2.block(): + with while_op2.body(): d2 = paddle.tensor.array_read(array=data_array, i=j) prev2 = paddle.tensor.array_read(array=mem_array, i=j) result2 = paddle.add_n([d2, prev2]) @@ -80,10 +81,10 @@ def simple_net(self): return loss, sum_result # TODO(zhangbo): Support pir test(support write_to_array and read_from_array, support while_grad). 
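Editor's note on the test/legacy_test/test_while_loop_op.py updates above: under PIR, append_backward returns a list of (parameter, gradient) pairs, so the updated tests look up the gradient Value for each input and fetch it directly, while the legacy branch keeps fetching by grad_name. The fragment below is a condensed, illustrative sketch of that pattern only; it is not part of the patch, and it assumes main_program, exe, the inputs i/x, the feeds feed_i/feed_x, and the loss mean are built exactly as in TestApiWhileLoop_Backward above.

    # Illustrative sketch (not in the patch); assumes the surrounding test setup.
    grad_list = append_backward(mean)  # under PIR: list of (param, grad) pairs

    if paddle.framework.in_pir_mode():
        # Find the gradient Value that corresponds to i, as the updated tests do.
        di = None
        for p, g in grad_list:
            if p == i:
                di = g
        res = exe.run(
            main_program,
            feed={'i': feed_i, 'x': feed_x},
            fetch_list=[mean, di],  # fetch the Values directly
        )
    else:
        # Legacy static graph: gradients are still fetched by name.
        res = exe.run(
            main_program,
            feed={'i': feed_i, 'x': feed_x},
            fetch_list=[mean.name, i.grad_name],
        )
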
- def test_simple_net(self): - main_program = base.Program() - startup_program = base.Program() - with base.program_guard(main_program, startup_program): + def _test_simple_net(self): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): loss, sum_result = self.simple_net() append_backward(loss) @@ -103,20 +104,21 @@ def test_simple_net(self): # TODO(zhangbo): Support pir test(support write_to_array and read_from_array) def test_simple_net_forward(self): - main_program = base.Program() - startup_program = base.Program() - with base.program_guard(main_program, startup_program): - self.simple_net() - binary = base.compiler.CompiledProgram(main_program) - cpu = core.CPUPlace() - exe = Executor(cpu) - d = [] - - for i in range(3): - d.append(numpy.random.random(size=[10]).astype('float32')) - - for _ in range(2): - exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}) + with paddle.pir_utils.IrGuard(): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + self.simple_net() + binary = base.compiler.CompiledProgram(main_program) + cpu = core.CPUPlace() + exe = Executor(cpu) + d = [] + + for i in range(3): + d.append(numpy.random.random(size=[10]).astype('float32')) + + for _ in range(2): + exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}) @compare_legacy_with_pt def test_exceptions(self): @@ -135,7 +137,7 @@ def test_exceptions(self): class BadInputTest(unittest.TestCase): @compare_legacy_with_pt def test_error(self): - with base.program_guard(base.Program()): + with paddle.static.program_guard(paddle.static.Program()): def test_bad_x(): x = [1, 2, 3] @@ -194,9 +196,9 @@ def test_outputs_exists_inputs(self): """ We guarantee that the output tensor must be in the input tensor, so that the output and input can correspond to each other, but the input can be greater than the number of outputs. It's required in paddle2onnx. """ - main_program = base.Program() - startup_program = base.Program() - with base.program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): def func(x): s = paddle.zeros([]) From 0913436d8cc9b0091e5ac6ed0efde2ceee45011b Mon Sep 17 00:00:00 2001 From: winter-wang <1030748926@qq.com> Date: Tue, 19 Dec 2023 13:21:24 +0000 Subject: [PATCH 10/21] [PIR] add get_used_external_value interface for block. 
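Reviewer note (not part of the commit): the new Block overload returns the Values that are defined outside the block but used by operations inside it, excluding the block's own arguments, while the existing Operation overload also covers the op's operands. A minimal Python sketch follows, mirroring the unit test added later in this series (test/ir/pir/test_while_api.py); the import path for get_used_external_value is an assumption taken from that test file.

    # Minimal sketch, assuming the import path used by test/ir/pir/test_while_api.py.
    import paddle
    from paddle.base.libpaddle.pir import get_used_external_value  # assumed path

    paddle.enable_static()
    main_program = paddle.static.Program()
    with paddle.pir.core.program_guard(main_program):
        i = paddle.full(shape=[1], fill_value=0)
        x = paddle.full(shape=[1], fill_value=10)
        y = paddle.full(shape=[1], fill_value=5)
        paddle.static.nn.while_loop(
            lambda p, q: p < q, lambda p, q: [p + y, q + i], [i, x]
        )
        while_op = main_program.global_block().ops[-1]
        body_block = while_op.as_while_op().body()

        # Block overload: external values used inside the body (y and i here);
        # the body's own block arguments are not reported.
        print(len(get_used_external_value(body_block)))  # 2
        # Operation overload: also includes the while op's own operands.
        print(len(get_used_external_value(while_op)))    # 4
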
--- paddle/fluid/pybind/control_flow_api.cc | 15 ++++++++++++++- test/ir/pir/test_while_api.py | 7 ++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pybind/control_flow_api.cc b/paddle/fluid/pybind/control_flow_api.cc index f4abefb729f650..7b564d1a3b04b2 100644 --- a/paddle/fluid/pybind/control_flow_api.cc +++ b/paddle/fluid/pybind/control_flow_api.cc @@ -130,6 +130,16 @@ std::vector GetUsedExternalValue(const Operation& op) { return used_values; } +std::vector GetUsedExternalValue(const Block& block) { + auto& args = block.args(); + std::unordered_set defined_values(args.begin(), args.end()); + std::vector used_values; + for (auto& op : block) { + GetUsedExternalValueImpl(defined_values, used_values, op); + } + return used_values; +} + Value BuildHasElementsOp(Operation& fwd_op) { // NOLINT PADDLE_ENFORCE(fwd_op.isa(), phi::errors::PreconditionNotMet( @@ -219,7 +229,10 @@ void PyIfOp::UpdateOutput() { } void BindControlFlowApi(py::module* m) { - m->def("get_used_external_value", GetUsedExternalValue); + m->def("get_used_external_value", + [](const Operation& op) { return GetUsedExternalValue(op); }); + m->def("get_used_external_value", + [](const Block& block) { return GetUsedExternalValue(block); }); m->def("build_pipe_for_block", BuildPipeForBlock); m->def("cf_has_elements", BuildHasElementsOp); m->def("cf_yield", [](py::list inputs) { diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index eb1f9d3381ed53..fe5e2a923a9860 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -63,12 +63,13 @@ def test_while_op_vjp_interface(self): main_program = self.construct_program_with_while() while_op = main_program.global_block().ops[-1] self.assertEqual(while_op.name(), "pd_op.while") - build_pipe_for_block(while_op.as_while_op().body()) + body_block = while_op.as_while_op().body() + build_pipe_for_block(body_block) with paddle.pir.core.program_guard(main_program): out_grad = paddle.full(shape=[6, 1], dtype='float32', fill_value=3) # check vjp interface for while_op - while_input = [ - [input] for input in get_used_external_value(while_op) + while_input = [[input] for input in while_op.operands_source()] + [ + [input] for input in get_used_external_value(body_block) ] self.assertEqual(len(while_input), 4) while_input_stop_graditents = [[True], [False], [True], [True]] From 63344b71d2693ab0048162e7c9edb914d936cd2e Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Wed, 20 Dec 2023 06:51:23 +0000 Subject: [PATCH 11/21] while case --- .../dialect/op_generator/op_interface_gen.py | 4 +- .../pir/dialect/operator/interface/vjp.cc | 20 +--- .../pir/dialect/operator/interface/vjp.h | 15 +-- .../dialect/operator/ir/control_flow_op.cc | 6 +- .../pir/dialect/operator/ir/control_flow_op.h | 6 +- .../pir/dialect/operator/ir/manual_op.cc | 101 ++++++++++++++++++ .../fluid/pir/dialect/operator/ir/manual_op.h | 8 +- .../pir/dialect/operator/ir/manual_op_vjp.cc | 8 +- paddle/fluid/pybind/pybind.cc | 4 +- python/paddle/autograd/ir_backward.py | 70 ++++++++---- test/cpp/prim/test_vjp.cc | 19 ++-- test/legacy_test/test_while_loop_op.py | 28 +++-- test/legacy_test/test_while_op.py | 1 - 13 files changed, 203 insertions(+), 87 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 7d3476b1a65204..2a68a1ad430675 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py +++ 
b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -88,7 +88,7 @@ }""" OP_VJP_DEFINE_TEMPLATE = """ -std::vector> {op_class_name}::Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients){{ +std::vector> {op_class_name}::Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients){{ {check_param} VLOG(6) << "Prepare inputs of {op_grad_name}"; {backward_input_code} @@ -302,5 +302,5 @@ def gen_exclusive_interface_str(op_info, op_info_items): " static void InferMeta( phi::InferMetaContext *infer_meta );" ) if op_info.op_phi_name[0] not in vjp_interface_black_list: - exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients);" + exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients);" return exclusive_interface_str diff --git a/paddle/fluid/pir/dialect/operator/interface/vjp.cc b/paddle/fluid/pir/dialect/operator/interface/vjp.cc index 5a509b3dfc99e2..ea7854670449ea 100644 --- a/paddle/fluid/pir/dialect/operator/interface/vjp.cc +++ b/paddle/fluid/pir/dialect/operator/interface/vjp.cc @@ -14,24 +14,6 @@ #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" -namespace paddle::dialect { -std::vector> VjpInterface::Vjp( - pir::Operation* op, - const std::vector>& inputs, - const std::vector>& outputs, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - std::vector> out_grads_value; - for (const auto& grad : out_grads) { - std::vector grad_value; - for (auto op_result : grad) { - grad_value.emplace_back(op_result); - } - out_grads_value.emplace_back(std::move(grad_value)); - } - return impl_->vjp_(op, inputs, outputs, out_grads_value, stop_gradients); -} - -} // namespace paddle::dialect +namespace paddle::dialect {} // namespace paddle::dialect IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::VjpInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/vjp.h b/paddle/fluid/pir/dialect/operator/interface/vjp.h index 44d1731359beb5..5246a2867665e4 100644 --- a/paddle/fluid/pir/dialect/operator/interface/vjp.h +++ b/paddle/fluid/pir/dialect/operator/interface/vjp.h @@ -23,14 +23,14 @@ class VjpInterface : public pir::OpInterfaceBase { explicit Concept(std::vector> (*vjp)( pir::Operation* op, const std::vector>& inputs, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients)) : vjp_(vjp) {} std::vector> (*vjp_)( pir::Operation* op, const std::vector>& inputs, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients); }; @@ -40,7 +40,7 @@ class VjpInterface : public pir::OpInterfaceBase { static std::vector> Vjp( pir::Operation* op, const std::vector>& inputs, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { return ConcreteOp::Vjp(op, inputs, outputs, out_grads, stop_gradients); @@ -56,19 +56,12 @@ class VjpInterface : public pir::OpInterfaceBase { std::vector> Vjp( pir::Operation* op, const std::vector>& inputs, - const std::vector>& outputs, + const std::vector>& outputs, 
const std::vector>& out_grads, const std::vector>& stop_gradients) { return impl_->vjp_(op, inputs, outputs, out_grads, stop_gradients); } - std::vector> Vjp( - pir::Operation* op, - const std::vector>& inputs, - const std::vector>& outputs, - const std::vector>& out_grads, - const std::vector>& stop_gradients); - private: Concept* impl_; }; diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index f4be79bc1c6b8b..9162a924163362 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -245,7 +245,7 @@ void IfOp::VerifyRegion() { std::vector> IfOp::Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients) { PADDLE_ENFORCE_EQ( @@ -345,7 +345,7 @@ void WhileOp::Print(pir::IrPrinter &printer) { std::vector> WhileOp::Vjp( pir::Operation *op, const std::vector> &inputs, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients) { auto fwd_op = WhileOp::dyn_cast(op); @@ -416,7 +416,7 @@ std::vector> WhileOp::Vjp( std::vector> TuplePushOpVjpInterfaceModel::Vjp( pir::Operation *op, const std::vector> &inputs, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients) { PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h index 96acf0dc68b30b..d6e6ea83fa65cd 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -51,7 +51,7 @@ class IfOp : public pir::Op { static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); }; @@ -86,7 +86,7 @@ class WhileOp : public pir::Op { static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); }; @@ -95,7 +95,7 @@ struct TuplePushOpVjpInterfaceModel : public VjpInterface::Concept { static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 7fe6329d37aca7..5e24cfd66ffa36 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -2512,6 +2512,107 @@ void Increment_Op::Build(pir::Builder &builder, ::pir::PassStopGradientsDefaultly(argument); } +void Increment_Op::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::Value x_, + pir::AttributeMap attributes) { + VLOG(4) << "Start build Increment_Op"; + + IR_ENFORCE(attributes.find("value") != attributes.end(), + "'value' Attribute is expected for Increment_Op. 
"); + float value = attributes.at("value").dyn_cast().data(); + + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {x_}; + argument.AddInputs(argument_inputs); + + VLOG(4) << "Builder construction attributes"; + pir::Attribute attr_value = + pir::FloatAttribute::get(pir::IrContext::Instance(), value); + argument.AddAttribute("value", attr_value); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x = + x_.type().dyn_cast(); + (void)x; + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::IncrementInferMeta(meta_x, value, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + +void Increment_Op::VerifySig() { + VLOG(4) + << "Start Verifying inputs, outputs and attributes for: Increment_Op."; + VLOG(4) << "Verifying inputs:"; + { + auto input_size = num_operands(); + IR_ENFORCE(input_size == 1u, + "The size %d of inputs must be equal to 1.", + input_size); + IR_ENFORCE((*this) + ->operand_source(0) + .type() + .isa(), + "Type validation failed for the 0th input, got %s.", + (*this)->operand_source(0).type()); + } + VLOG(4) << "Verifying attributes:"; + { + auto &attributes = this->attributes(); + IR_ENFORCE(attributes.count("value") > 0, "value does not exist."); + IR_ENFORCE(attributes.at("value").isa(), + "Type of attribute: value is not pir::FloatAttribute."); + } + VLOG(4) << "Verifying outputs:"; + { + auto output_size = num_results(); + IR_ENFORCE(output_size == 1u, + "The size %d of outputs must be equal to 1.", + output_size); + IR_ENFORCE( + (*this)->result(0).type().isa(), + "Type validation failed for the 0th output."); + } + VLOG(4) << "End Verifying for: Increment_Op."; +} + +void Increment_Op::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(phi::IncrementInferMeta); + fn(infer_meta); +} + +phi::DataType Increment_Op::GetKernelTypeForVar( + const std::string &var_name, + const phi::DataType &tensor_dtype, + const phi::DataType &expected_kernel_dtype) { + VLOG(4) << "Get KernelType for Var of op: Increment_Op"; + + return expected_kernel_dtype; +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 50c2f2649d12c6..46914ea3dbb625 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -55,7 +55,7 @@ class AddNOp : public pir::Op> Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); static std::vector> Decomp(pir::Operation *op); @@ -396,7 +396,7 @@ class ExpandOp : public pir::Op> Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> 
&outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); }; @@ -448,7 +448,7 @@ class IncrementOp static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); }; @@ -490,7 +490,7 @@ class Increment_Op static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); }; diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc index 436d138a891605..f35ab01117d2a3 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc @@ -31,7 +31,7 @@ using IntArray = paddle::experimental::IntArray; std::vector> AddNOp::Vjp( pir::Operation* op, const std::vector>& inputs_, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { VLOG(6) << "Prepare inputs of add_n_grad"; @@ -83,7 +83,7 @@ std::vector> AddNOp::Vjp( std::vector> ExpandOp::Vjp( pir::Operation* op, const std::vector>& inputs_, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { PADDLE_ENFORCE_EQ(inputs_.size(), @@ -131,7 +131,7 @@ std::vector> ExpandOp::Vjp( std::vector> IncrementOp::Vjp( pir::Operation* op, const std::vector>& inputs_, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { PADDLE_ENFORCE_EQ( @@ -163,7 +163,7 @@ std::vector> IncrementOp::Vjp( std::vector> Increment_Op::Vjp( pir::Operation* op, const std::vector>& inputs_, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 2221fba3754e0b..efeeb4855205e2 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -702,8 +702,8 @@ void BindVjp(pybind11::module *m) { "call_vjp", [](pir::Operation &fwd_op, const std::vector> &inputs, - const std::vector> &outputs, - const std::vector> &out_grads, + const std::vector> &outputs, + const std::vector> &out_grads, const std::vector> &stop_gradients) { py::list res; paddle::dialect::VjpInterface vjp_interface = diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index 10650501066bff..f20af2cfe439e1 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -352,6 +352,7 @@ def append_backward_ops( no_grad_set, backward_ops, state, + bwd_block_argument_to_value_map, ): ''' add grad_op in order of topological inverse sort @@ -464,7 +465,12 @@ def make_output_with_output_grad(op): zero_flag[i] = True outputs.append(new_value) - output_grads.append(state.value_to_valuegrad[value][0]) + grad_value = state.value_to_valuegrad[value][0] + output_grads.append( + bwd_block_argument_to_value_map[grad_value[0]] + if grad_value[0] in bwd_block_argument_to_value_map + else grad_value + ) if op.name() == "pd_op.while": for i, input in enumerate(get_real_op_inputs(op)): @@ -481,7 +487,13 @@ def make_output_with_output_grad(op): or state.value_to_valuegrad[input] == [] ): append_full_like(0.0, 
input, input, state, backward_ops) - output_grads.append(state.value_to_valuegrad[input][0]) + + grad_value = state.value_to_valuegrad[input][0] + output_grads.append( + bwd_block_argument_to_value_map[grad_value[0]] + if grad_value[0] in bwd_block_argument_to_value_map + else grad_value + ) return zero_flag, outputs, output_grads def get_grad_semantic_info(op): @@ -554,10 +566,10 @@ def make_input_with_input_stopgradient(op): return inputs, input_grad_stopgradients - def update_input_grad_map(op, input_grads, origin_inputs): + def update_input_grad_map(op, input_grads, origin_inputs, external_inputs): i = 0 for input, grad_semantic in zip( - origin_inputs, get_grad_semantic_info(op) + origin_inputs, get_grad_semantic_info(op)[: len(origin_inputs) + 1] ): if not grad_semantic: continue @@ -572,8 +584,8 @@ def update_input_grad_map(op, input_grads, origin_inputs): ) else: input_grad = input_grads[i] - if input in block_argument_to_value_map: - input = block_argument_to_value_map[input] + if input in fwd_block_argument_to_value_map: + input = fwd_block_argument_to_value_map[input] if isinstance(input_grad, list): state.value_to_valuegrad[input].append(input_grad) @@ -617,12 +629,12 @@ def append_yield( 0.0, value, value, state, backward_ops ) - if base_op.name() == "pd_op.while": - input_grad = paddle.add( - output_grad, state.value_to_valuegrad[value][0][0] - ) - else: - input_grad = state.value_to_valuegrad[value][0][0] + # if base_op.name() == "pd_op.while": + # input_grad = paddle.add( + # output_grad, state.value_to_valuegrad[value][0][0] + # ) + # else: + input_grad = state.value_to_valuegrad[value][0][0] inputs_grad.append(input_grad) @@ -632,13 +644,15 @@ def argument_to_value(while_op): assert len(while_op.as_while_op().block_arguments()) + 1 == len( while_op.operands_source() ), "while op's block_arguments size + 1 should same to whiel op's operands_source" - map = ValueDict() + arg_to_value_map = ValueDict() + value_to_arg_map = ValueDict() for arg, value in zip( while_op.as_while_op().block_arguments(), while_op.operands_source()[1:], ): - map[arg] = value - return map + arg_to_value_map[arg] = value + value_to_arg_map[value] = [arg] + return arg_to_value_map, value_to_arg_map # there are four patterns: # [builtin.combine , op1] (op1's one input is vectorType, outputs are not vectorType) @@ -650,7 +664,9 @@ def argument_to_value(while_op): # tuple_push value to pop value control_flow_value_to_copyvalue_map = ValueDict() control_flow_copyvalue_to_value_map = ValueDict() - block_argument_to_value_map = ValueDict() + # fwd_whileop's blockargument to fwd_whileop's input value + fwd_block_argument_to_value_map = ValueDict() + # bwd_whileop's input value to bwd_whileop's blockargument if ( len(effective_forward_ops) > 1 @@ -661,7 +677,7 @@ def argument_to_value(while_op): # while op yield [cond, loop_vars], # but outputs only has loop_vars. 
inside_outputs = yield_op.operands_source()[1:] - block_argument_to_value_map = argument_to_value(base_op) + fwd_block_argument_to_value_map, _ = argument_to_value(base_op) else: inside_outputs = yield_op.operands_source() @@ -694,7 +710,7 @@ def argument_to_value(while_op): input_grad_stopgradients, ) = make_input_with_input_stopgradient(op) - if op.name() == "cf.tuple_push": + if op.name() in ["cf.tuple_push", "pd_op.increment_"]: with dynamic_shape_prim_vjp_guard(op, inputs): copy_out = paddle.framework.core.call_vjp( op, @@ -718,7 +734,7 @@ def argument_to_value(while_op): if len(output_grads) == 0 or all(zero_flag): continue - if op.name() == "pd_op.if" or op.name() == "pd_op.while": + if op.name() in ["pd_op.if", "pd_op.while"]: origin_inputs = get_used_external_value(op) for sub_block in op.blocks(): build_pipe_for_block(sub_block) @@ -736,8 +752,15 @@ def argument_to_value(while_op): for sub_fwd_block, sub_bwd_block in zip( op.blocks(), grad_op.blocks() ): + # update grad_op structure + if grad_op.name() == "pd_op.while": + ( + _, + sub_bwd_block_argument_to_value_map, + ) = argument_to_value(grad_op) sub_state = state.copy(sub_fwd_block) sub_backward_ops = [] + breakpoint() append_backward_ops( op, [input[0] for input in inputs], @@ -748,9 +771,12 @@ def argument_to_value(while_op): no_grad_set, sub_backward_ops, sub_state, + sub_bwd_block_argument_to_value_map, ) # update input_grad map - update_input_grad_map(op, input_grads, origin_inputs) + update_input_grad_map( + op, input_grads, op.operands_source(), [] + ) else: # create grad_op before_ops_num = len(bwd_block.ops) @@ -772,7 +798,7 @@ def argument_to_value(while_op): # update input_grad map update_input_grad_map( - op, input_grads, op.operands_source() + op, input_grads, op.operands_source(), [] ) update_bwdop_structure( @@ -907,6 +933,7 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): no_grad_set, backward_ops, state, + ValueDict(), ) # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue) outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set( @@ -915,7 +942,6 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): _, remove_ops = prune_ops( backward_ops, inputs_set, outputs_set, no_gradvar_set ) - state.turn_map() for bwd_op in inverse_sort_op(remove_ops): diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc index 2c3ba7073e6065..f9393bf6b9f548 100644 --- a/test/cpp/prim/test_vjp.cc +++ b/test/cpp/prim/test_vjp.cc @@ -70,7 +70,7 @@ TEST(VJP, TanhBackwardTest) { std::vector> stop_gradients{{false}}; std::vector> inputs{{op1.out()}}; - std::vector> outputs{{op2.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.tanh"); @@ -125,7 +125,7 @@ TEST(VJP, Tanh_BackwardTest) { std::vector> stop_gradients{{false}}; std::vector> inputs{{op1.out()}}; - std::vector> outputs{{op2.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.tanh_"); @@ -181,7 +181,7 @@ TEST(VJP, MeanBackwardTest) { std::vector> stop_gradients{{false}}; std::vector> inputs{{op1.out()}}; - std::vector> outputs{{op2.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.mean"); @@ -241,7 +241,7 @@ TEST(VJP, ConcatBackwardTest) { std::vector> stop_gradients{{false, false}}; std::vector> inputs{{op1.out(), 
op1.out()}, {op3.axis()}}; - std::vector> outputs{{op3.out()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.concat"); auto concat_vjp_interface_impl = @@ -308,7 +308,7 @@ TEST(VJP, AddBackwardTest) { std::vector> stop_gradients{{false}, {false}}; std::vector> inputs{{op1.out()}, {op2.out()}}; - std::vector> outputs{{op3.out()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd_op.add"); @@ -373,7 +373,7 @@ TEST(VJP, Add_BackwardTest) { std::vector> stop_gradients{{false}, {false}}; std::vector> inputs{{op1.out()}, {op2.out()}}; - std::vector> outputs{{op3.out()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd_op.add_"); @@ -441,7 +441,12 @@ TEST(VJP, SplitBackwardTest) { std::vector> stop_gradients{{false}}; std::vector> inputs{ {op2.x()}, {op2.sections()}, {op2.axis()}}; - std::vector> outputs{{op3.outputs()}}; + std::vector> outputs(1); + std::vector res; + for (uint32_t i = 0; i < op3.outputs().size(); i++) { + res.push_back(op3.outputs()[i]); + } + outputs[0] = res; std::vector> out_grads{{op3.result(0), op4.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.split"); diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index 91e28ff9fd6a30..ce24bc2b4c80eb 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -253,7 +253,7 @@ def internal_body(j, init, sums): class TestApiWhileLoop_Backward(unittest.TestCase): # TODO(zhangbo): Support while grad exe for pir @test_with_pir_api - def test_while_loop_backward(self): + def _test_while_loop_backward(self): def cond(i, x): return paddle.less_than(i, eleven) @@ -313,8 +313,12 @@ def body(i, x): print("res[2]: ", res[2]) # TODO(zhangbo): Support while grad exe for pir - def _test_while_loop_backward2(self): - def cond(i, x): + @test_with_pir_api + def test_while_loop_backward2(self): + def cond1(i, x): + return i < 2 + + def cond2(i, x): return i < 3 def body(i, x): @@ -327,12 +331,15 @@ def body(i, x): with paddle.static.program_guard(main_program, startup_program): i = paddle.static.data(name='i', shape=[1], dtype='float32') i.stop_gradient = False + i.persistable = True x = paddle.static.data(name='x', shape=[1], dtype='float32') x.stop_gradient = False + x.persistable = True - out = paddle.static.nn.while_loop(cond, body, [i, x]) + out = paddle.static.nn.while_loop(cond1, body, [i, x]) mean = paddle.mean(out[1]) grad_list = append_backward(mean) + print(main_program) place = ( base.CUDAPlace(0) @@ -344,6 +351,8 @@ def body(i, x): feed_i = np.ones(1).astype('float32') feed_x = np.ones(1).astype('float32') data = np.asarray([2]).astype('float32') + ans = np.asarray([1]).astype('float32') + x1_grad = np.asarray([1]).astype('float32') i_grad = np.asarray([3]).astype('float32') x_grad = np.asarray([2]).astype('float32') @@ -356,17 +365,18 @@ def body(i, x): res = exe.run( main_program, feed={'i': feed_i, 'x': feed_x}, - fetch_list=[mean, di, dx], + fetch_list=[out[1], di, dx], ) else: res = exe.run( main_program, feed={'i': feed_i, 'x': feed_x}, - fetch_list=[mean.name, i.grad_name, x.grad_name], + fetch_list=[out[1].name, i.grad_name, x.grad_name], ) - np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05) - np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05) - 
np.testing.assert_allclose(np.asarray(res[2]), x_grad, rtol=1e-05) + print(res[0], res[1], res[1]) + np.testing.assert_allclose(np.asarray(res[0]), ans, rtol=1e-05) + np.testing.assert_allclose(np.asarray(res[1]), ans, rtol=1e-05) + np.testing.assert_allclose(np.asarray(res[2]), ans, rtol=1e-05) class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase): diff --git a/test/legacy_test/test_while_op.py b/test/legacy_test/test_while_op.py index 6cf13963ec4337..30607e6bb7a38d 100644 --- a/test/legacy_test/test_while_op.py +++ b/test/legacy_test/test_while_op.py @@ -23,7 +23,6 @@ from paddle.base.backward import append_backward from paddle.base.executor import Executor from paddle.incubate.layers.nn import shuffle_batch -from paddle.pir_utils import test_with_pir_api paddle.enable_static() From 59ad2fc638e977d4b49352fb04ecdd39a8ab1980 Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Wed, 20 Dec 2023 07:12:02 +0000 Subject: [PATCH 12/21] delete print --- test/legacy_test/test_while_loop_op.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index ce24bc2b4c80eb..54a47d2532835c 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -252,8 +252,8 @@ def internal_body(j, init, sums): class TestApiWhileLoop_Backward(unittest.TestCase): # TODO(zhangbo): Support while grad exe for pir - @test_with_pir_api - def _test_while_loop_backward(self): + + def test_while_loop_backward(self): def cond(i, x): return paddle.less_than(i, eleven) @@ -279,7 +279,6 @@ def body(i, x): out = paddle.static.nn.while_loop(cond, body, [i, x]) mean = paddle.mean(out[1]) grad_list = append_backward(mean) - print(main_program) place = ( base.CUDAPlace(0) @@ -310,7 +309,6 @@ def body(i, x): ) np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05) np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05) - print("res[2]: ", res[2]) # TODO(zhangbo): Support while grad exe for pir @test_with_pir_api @@ -373,7 +371,7 @@ def body(i, x): feed={'i': feed_i, 'x': feed_x}, fetch_list=[out[1].name, i.grad_name, x.grad_name], ) - print(res[0], res[1], res[1]) + np.testing.assert_allclose(np.asarray(res[0]), ans, rtol=1e-05) np.testing.assert_allclose(np.asarray(res[1]), ans, rtol=1e-05) np.testing.assert_allclose(np.asarray(res[2]), ans, rtol=1e-05) @@ -381,7 +379,7 @@ def body(i, x): class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase): # TODO(zhangbo): Support while grad exe for pir - @test_with_pir_api + def test_nested_net_with_backward_and_lodtensor(self): def external_cond(i, j, x, mem_array): return paddle.less_than(i, array_len) @@ -446,7 +444,6 @@ def internal_body(j, x, mem_array): sum_result = paddle.tensor.array_read(array=mem_array, i=j) mean = paddle.mean(sum_result) append_backward(mean) - print(main_program) place = ( base.CUDAPlace(0) From f4eceb63e8ec0f6105a1bf2e179e7daaad995f5a Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Wed, 20 Dec 2023 07:13:55 +0000 Subject: [PATCH 13/21] delete print --- test/legacy_test/test_while_loop_op.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index 54a47d2532835c..38bc7bf037dcc2 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -337,7 +337,6 @@ def body(i, x): out = paddle.static.nn.while_loop(cond1, body, [i, x]) 
mean = paddle.mean(out[1]) grad_list = append_backward(mean) - print(main_program) place = ( base.CUDAPlace(0) From 1c9eb96dafc3b75335368ddf687bbd79dbb8019b Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Wed, 20 Dec 2023 15:24:04 +0800 Subject: [PATCH 14/21] Update python/paddle/autograd/ir_backward.py --- python/paddle/autograd/ir_backward.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index f20af2cfe439e1..e6fbdcfaecb651 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -760,7 +760,6 @@ def argument_to_value(while_op): ) = argument_to_value(grad_op) sub_state = state.copy(sub_fwd_block) sub_backward_ops = [] - breakpoint() append_backward_ops( op, [input[0] for input in inputs], From df0b46aeff2c287e5bc44096073bb5a17e8df789 Mon Sep 17 00:00:00 2001 From: winter-wang <1030748926@qq.com> Date: Wed, 20 Dec 2023 09:07:41 +0000 Subject: [PATCH 15/21] [PIR] add unit_test for get_used_external_value --- python/paddle/tensor/logic.py | 4 ++-- test/ir/pir/test_while_api.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 31b38c13ff57d5..32796a73c965fa 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -533,11 +533,11 @@ def equal(x, y, name=None): Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, [True , False, False]) """ - if not isinstance(y, (int, bool, float, Variable, paddle.pir.OpResult)): + if not isinstance(y, (int, bool, float, Variable, paddle.pir.Value)): raise TypeError( f"Type of input args must be float, bool, int or Tensor, but received type {type(y)}" ) - if not isinstance(y, (Variable, paddle.pir.OpResult)): + if not isinstance(y, (Variable, paddle.pir.Value)): y = full(shape=[], dtype=x.dtype, fill_value=y) if in_dynamic_or_pir_mode(): diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index fe5e2a923a9860..c067b4d174f14d 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -59,6 +59,41 @@ def test_while_base(self): self.assertEqual(last_op.name(), "pd_op.while") self.assertEqual(len(out), 2) + def test_get_used_external_value(self): + main_program = paddle.static.Program() + with paddle.pir.core.program_guard(main_program): + print(main_program) + i = paddle.full(shape=[1], fill_value=0) + print(main_program) + x = paddle.full(shape=[1], fill_value=10) + y = paddle.full(shape=[1], fill_value=5) + # i, x = paddle.static.nn.while_loop(cond, body, [i, ten]) + paddle.static.nn.while_loop( + lambda p, q: p < q, lambda p, q: [p + y, q + i], [i, x] + ) + print(main_program) + while_op = main_program.global_block().ops[-1] + self.assertEqual(while_op.name(), "pd_op.while") + body_block = while_op.as_while_op().body() + operand_source = while_op.operands_source() + # 【cond, i , x】 + self.assertEqual(len(operand_source), 3) + self.assertTrue(operand_source[1].is_same(i)) + self.assertTrue(operand_source[2].is_same(x)) + + block_external_values = get_used_external_value(body_block) + # 【y, i】 + self.assertEqual(len(block_external_values), 2) + self.assertTrue(block_external_values[0].is_same(y)) + self.assertTrue(block_external_values[1].is_same(i)) + + op_external_values = get_used_external_value(while_op) + # 【cond, i , x, y】 + self.assertEqual(len(op_external_values), 4) + 
self.assertTrue(op_external_values[1].is_same(i)) + self.assertTrue(op_external_values[2].is_same(x)) + self.assertTrue(op_external_values[3].is_same(y)) + def test_while_op_vjp_interface(self): main_program = self.construct_program_with_while() while_op = main_program.global_block().ops[-1] From 65083dfe720ea1a64fa012122feddcfe0298efbb Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Thu, 21 Dec 2023 02:05:40 +0000 Subject: [PATCH 16/21] modify while_loop --- .../pir/dialect/operator/ir/manual_op.cc | 3 +- .../fluid/pir/dialect/operator/ir/manual_op.h | 8 +- python/paddle/autograd/ir_backward.py | 92 ++++++++++++++----- test/legacy_test/test_while_loop_op.py | 7 +- 4 files changed, 80 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 5e24cfd66ffa36..dad8c36e2f358a 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -20,7 +20,8 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op, paddle::dialect::ArrayLengthOp, paddle::dialect::ArrayReadOp, paddle::dialect::ArrayWrite_Op, paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp, paddle::dialect::AssignArray_Op, - paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp + paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp, + paddle::dialect::IncrementOp, paddle::dialect::Increment_Op #else #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 46914ea3dbb625..1f367b4319d8c9 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -426,12 +426,12 @@ class IncrementOp static OpInfoTuple GetOpInfo(); static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT - pir::Value x_, + pir::Value x_, // NOLINT float value = 1.0); static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT - pir::Value x_, + pir::Value x_, // NOLINT pir::AttributeMap attributes); void VerifySig(); @@ -468,12 +468,12 @@ class Increment_Op static OpInfoTuple GetOpInfo(); static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT - pir::Value x_, + pir::Value x_, // NOLINT float value = 1.0); static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT - pir::Value x_, + pir::Value x_, // NOLINT pir::AttributeMap attributes); void VerifySig(); diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index f20af2cfe439e1..1cf265e776b03e 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -327,7 +327,7 @@ def inverse_sort_op(ops): while queue: op = queue.popleft() sorted_list.append(op) - + # why reverse: tuple_push's input order is fwd op's order for x in get_real_op_inputs(op): x_op = x.get_defining_op() pending_count[x_op] -= 1 @@ -338,6 +338,32 @@ def inverse_sort_op(ops): raise ValueError( "inverse_sort_op wrong, sorted_list size is not equal to origin_list size" ) + change_list = [] + for op in reversed(sorted_list): + print("^^^^", op.name()) + if op.name() == 'pd_op.increment_': + idx_1 = sorted_list.index(op) + idx_2 = sorted_list.index(op) + + for op_in in reversed(sorted_list[: sorted_list.index(op)]): + print("&&&&", op_in.name()) + if ( + some_in_set( 
+ op.operands_source(), + ValueSet(get_real_op_inputs(op_in)), + ) + and op_in.name() != "cf.tuple_push" + ): + idx_2 = sorted_list.index(op_in) + print("$$$$", idx_1, " ", idx_2) + if idx_1 != idx_2: + change_list.append((idx_1, idx_2)) + print("change_list :", change_list) + for idx_1, idx_2 in change_list: + sorted_list[idx_1], sorted_list[idx_2] = ( + sorted_list[idx_2], + sorted_list[idx_1], + ) return sorted_list @@ -394,6 +420,14 @@ def append_backward_ops( else continue to next op. ''' + def return_value_to_copyvalue_map( + value, control_flow_value_to_copyvalue_map + ): + output = value + while output in control_flow_value_to_copyvalue_map: + output = control_flow_value_to_copyvalue_map[output] + return output + def append_add_n(value): # value is input of more than one fwd_op, # so more than one bwd_op create input_grad, @@ -416,11 +450,11 @@ def make_output_with_output_grad(op): outputs = [] output_grads = [] for i, value in enumerate(op.results()): - new_value = ( - [control_flow_value_to_copyvalue_map[value]] - if value in control_flow_value_to_copyvalue_map - else [value] - ) + new_value = [ + return_value_to_copyvalue_map( + value, control_flow_value_to_copyvalue_map + ) + ] while value in state.inside_value_to_outside_value_map: value = state.inside_value_to_outside_value_map[value] @@ -502,6 +536,8 @@ def get_grad_semantic_info(op): "pd_op.if", "pd_op.while", "cf.tuple_push", + "pd_op.increment_", + "pd_op.increment", ]: grad_semantic_info = [ True for _ in range(len(get_real_op_inputs(op))) @@ -524,18 +560,18 @@ def make_input_with_input_stopgradient(op): tmp_input = [] for tmp in input.get_defining_op().operands_source(): tmp_input.append( - control_flow_value_to_copyvalue_map[tmp] - if tmp in control_flow_value_to_copyvalue_map - else tmp + return_value_to_copyvalue_map( + tmp, control_flow_value_to_copyvalue_map + ) ) inputs.append(tmp_input) else: - tmp_input = ( - [control_flow_value_to_copyvalue_map[input]] - if input in control_flow_value_to_copyvalue_map - else [input] - ) + tmp_input = [ + return_value_to_copyvalue_map( + input, control_flow_value_to_copyvalue_map + ) + ] inputs.append(tmp_input) continue @@ -552,11 +588,11 @@ def make_input_with_input_stopgradient(op): [info[0] for info in combine_stop_gradient] ) else: - tmp_input = ( - [control_flow_value_to_copyvalue_map[input]] - if input in control_flow_value_to_copyvalue_map - else [input] - ) + tmp_input = [ + return_value_to_copyvalue_map( + input, control_flow_value_to_copyvalue_map + ) + ] inputs.append(tmp_input) if input in no_grad_set or input.stop_gradient is True: @@ -581,6 +617,7 @@ def update_input_grad_map(op, input_grads, origin_inputs, external_inputs): input.get_defining_op(), input_grads[i], input.get_defining_op().operands_source(), + [], ) else: input_grad = input_grads[i] @@ -697,6 +734,7 @@ def argument_to_value(while_op): if op.name() != "builtin.combine" and op.name() != "builtin.split": clear_effective_forward_ops.append(op) with bwd_block: + print([op.name() for op in clear_effective_forward_ops]) for op in clear_effective_forward_ops: if paddle.framework.core.has_vjp(op): # prepare output_grad @@ -721,7 +759,17 @@ def argument_to_value(while_op): ) pop_op = bwd_block.ops[-1] bwd_ops = [pop_op] - for output, copy_output in zip(inputs[1:], copy_out[1:]): + tmp_inputs = ( + inputs + if op.name() == "pd_op.increment_" + else inputs[1:] + ) + tmp_copy_out = ( + copy_out + if op.name() == "pd_op.increment_" + else copy_out[1:] + ) + for output, copy_output in zip(tmp_inputs, 
tmp_copy_out): control_flow_value_to_copyvalue_map[ output[0] ] = copy_output[0] @@ -760,7 +808,6 @@ def argument_to_value(while_op): ) = argument_to_value(grad_op) sub_state = state.copy(sub_fwd_block) sub_backward_ops = [] - breakpoint() append_backward_ops( op, [input[0] for input in inputs], @@ -939,11 +986,12 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set( outputs_fwd_set, inputs_fwd_set, no_grad_set, state ) + _, remove_ops = prune_ops( backward_ops, inputs_set, outputs_set, no_gradvar_set ) - state.turn_map() + state.turn_map() for bwd_op in inverse_sort_op(remove_ops): if bwd_op.result(0) in ValueSet(grad_outputs): continue diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index 38bc7bf037dcc2..fecb96245476a0 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -23,7 +23,6 @@ from paddle.base import core from paddle.base.backward import append_backward from paddle.base.framework import program_guard -from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -252,7 +251,7 @@ def internal_body(j, init, sums): class TestApiWhileLoop_Backward(unittest.TestCase): # TODO(zhangbo): Support while grad exe for pir - + # @test_with_pir_api def test_while_loop_backward(self): def cond(i, x): return paddle.less_than(i, eleven) @@ -267,6 +266,7 @@ def body(i, x): with paddle.static.program_guard(main_program, startup_program): i = paddle.static.data(name='i', shape=[1], dtype='float32') i.stop_gradient = False + i.persistable = True eleven = paddle.tensor.fill_constant( shape=[1], dtype='float32', value=11 ) @@ -275,6 +275,7 @@ def body(i, x): ) x = paddle.static.data(name='x', shape=[1], dtype='float32') x.stop_gradient = False + x.persistable = True out = paddle.static.nn.while_loop(cond, body, [i, x]) mean = paddle.mean(out[1]) @@ -311,7 +312,7 @@ def body(i, x): np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05) # TODO(zhangbo): Support while grad exe for pir - @test_with_pir_api + # @test_with_pir_api def test_while_loop_backward2(self): def cond1(i, x): return i < 2 From 95bc3d7eb04576d773ddc4c9ebe4b38ed531cbd6 Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Thu, 21 Dec 2023 02:35:34 +0000 Subject: [PATCH 17/21] code_style --- test/legacy_test/test_while_loop_op.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index d8e770b29dd98d..4feddf5a7c2df6 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -23,6 +23,7 @@ from paddle.base import core from paddle.base.backward import append_backward from paddle.base.framework import program_guard +from paddle.pir_utils import test_with_pir_api paddle.enable_static() From 37e807c1a6493c033265a52b382cafff5e308fb5 Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Thu, 21 Dec 2023 06:13:19 +0000 Subject: [PATCH 18/21] modofy ci bug --- python/paddle/autograd/ir_backward.py | 4 +++ test/ir/pir/test_while_api.py | 2 +- test/legacy_test/test_while_op.py | 49 +++++++++++++-------------- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index 1cf265e776b03e..820ce11b1615ad 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -806,6 +806,10 @@ def 
argument_to_value(while_op): _, sub_bwd_block_argument_to_value_map, ) = argument_to_value(grad_op) + else: + sub_bwd_block_argument_to_value_map = ( + ValueDict() + ) sub_state = state.copy(sub_fwd_block) sub_backward_ops = [] append_backward_ops( diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index eb1f9d3381ed53..0cb0ef9d43def1 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -150,7 +150,7 @@ def test_add_n_program(self): .body() .ops[-2] .name(), - "pd_op.add", + "cf.has_elements", ) self.assertEqual( main_program.global_block() diff --git a/test/legacy_test/test_while_op.py b/test/legacy_test/test_while_op.py index 30607e6bb7a38d..5ff7698b6b2bc1 100644 --- a/test/legacy_test/test_while_op.py +++ b/test/legacy_test/test_while_op.py @@ -57,7 +57,7 @@ def simple_net(self): cond2 = paddle.less_than(x=j, y=array_len2) while_op = paddle.static.nn.control_flow.While(cond=cond) while_op2 = paddle.static.nn.control_flow.While(cond=cond2) - with while_op.body(): + with while_op.block(): d = paddle.tensor.array_read(array=data_array, i=i) prev = paddle.tensor.array_read(array=mem_array, i=i) result = paddle.add_n([d, prev]) @@ -65,7 +65,7 @@ def simple_net(self): i = paddle.increment(x=i) paddle.tensor.array_write(result, i=i, array=mem_array) - with while_op2.body(): + with while_op2.block(): d2 = paddle.tensor.array_read(array=data_array, i=j) prev2 = paddle.tensor.array_read(array=mem_array, i=j) result2 = paddle.add_n([d2, prev2]) @@ -80,10 +80,10 @@ def simple_net(self): return loss, sum_result # TODO(zhangbo): Support pir test(support write_to_array and read_from_array, support while_grad). - def _test_simple_net(self): - main_program = paddle.static.Program() - startup_program = paddle.static.Program() - with paddle.static.program_guard(main_program, startup_program): + def test_simple_net(self): + main_program = base.Program() + startup_program = base.Program() + with base.program_guard(main_program, startup_program): loss, sum_result = self.simple_net() append_backward(loss) @@ -103,21 +103,20 @@ def _test_simple_net(self): # TODO(zhangbo): Support pir test(support write_to_array and read_from_array) def test_simple_net_forward(self): - with paddle.pir_utils.IrGuard(): - main_program = paddle.static.Program() - startup_program = paddle.static.Program() - with paddle.static.program_guard(main_program, startup_program): - self.simple_net() - binary = base.compiler.CompiledProgram(main_program) - cpu = core.CPUPlace() - exe = Executor(cpu) - d = [] - - for i in range(3): - d.append(numpy.random.random(size=[10]).astype('float32')) - - for _ in range(2): - exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}) + main_program = base.Program() + startup_program = base.Program() + with base.program_guard(main_program, startup_program): + self.simple_net() + binary = base.compiler.CompiledProgram(main_program) + cpu = core.CPUPlace() + exe = Executor(cpu) + d = [] + + for i in range(3): + d.append(numpy.random.random(size=[10]).astype('float32')) + + for _ in range(2): + exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}) @compare_legacy_with_pt def test_exceptions(self): @@ -136,7 +135,7 @@ def test_exceptions(self): class BadInputTest(unittest.TestCase): @compare_legacy_with_pt def test_error(self): - with paddle.static.program_guard(paddle.static.Program()): + with base.program_guard(base.Program()): def test_bad_x(): x = [1, 2, 3] @@ -195,9 +194,9 @@ def test_outputs_exists_inputs(self): """ We guarantee that the output 
tensor must be in the input tensor, so that the output and input can correspond to each other, but the input can be greater than the number of outputs. It's required in paddle2onnx. """ - main_program = paddle.static.Program() - startup_program = paddle.static.Program() - with paddle.static.program_guard(main_program, startup_program): + main_program = base.Program() + startup_program = base.Program() + with base.program_guard(main_program, startup_program): def func(x): s = paddle.zeros([]) From 48de1240bc3c258a7d32364a2e1e98255c5fa9a9 Mon Sep 17 00:00:00 2001 From: wangruting Date: Fri, 22 Dec 2023 07:18:31 +0000 Subject: [PATCH 19/21] modify while api --- python/paddle/autograd/ir_backward.py | 190 +++++++++++++------------- test/ir/pir/test_while_api.py | 39 +++++- 2 files changed, 133 insertions(+), 96 deletions(-) diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index 820ce11b1615ad..71d444a3e5c689 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -80,8 +80,12 @@ def append_full_like(float_value, copy_value, value, state, backward_ops): def get_real_op_inputs(op): - if op.name() in ["pd_op.if", "pd_op.while"]: + if op.name() == "pd_op.if": return get_used_external_value(op) + elif op.name() == "pd_op.while": + return op.operands_source() + get_used_external_value( + op.as_while_op().body() + ) else: return op.operands_source() @@ -340,13 +344,11 @@ def inverse_sort_op(ops): ) change_list = [] for op in reversed(sorted_list): - print("^^^^", op.name()) if op.name() == 'pd_op.increment_': idx_1 = sorted_list.index(op) idx_2 = sorted_list.index(op) for op_in in reversed(sorted_list[: sorted_list.index(op)]): - print("&&&&", op_in.name()) if ( some_in_set( op.operands_source(), @@ -355,10 +357,9 @@ def inverse_sort_op(ops): and op_in.name() != "cf.tuple_push" ): idx_2 = sorted_list.index(op_in) - print("$$$$", idx_1, " ", idx_2) if idx_1 != idx_2: change_list.append((idx_1, idx_2)) - print("change_list :", change_list) + for idx_1, idx_2 in change_list: sorted_list[idx_1], sorted_list[idx_2] = ( sorted_list[idx_2], @@ -378,7 +379,7 @@ def append_backward_ops( no_grad_set, backward_ops, state, - bwd_block_argument_to_value_map, + bwd_value_to_block_argument_map=ValueDict(), ): ''' add grad_op in order of topological inverse sort @@ -420,12 +421,10 @@ def append_backward_ops( else continue to next op. 
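    a minimal plain-Python sketch of the grad-aggregation rule above
    (the names are illustrative stand-ins, not pir values; the real code
    folds the partial grads with a pd_op.add_n built by append_add_n):

        from collections import defaultdict

        value_to_valuegrad = defaultdict(list)  # forward value -> list of partial grads

        def record_grad(value, grad):
            # each consumer op's vjp contributes one partial gradient
            value_to_valuegrad[value].append([grad])

        def lookup_grad(value):
            grads = value_to_valuegrad[value]
            if len(grads) > 1:
                # more than one bwd op wrote a grad: fold them into one entry,
                # standing in for the add_n op inserted by append_add_n
                value_to_valuegrad[value] = [[sum(g[0] for g in grads)]]
            return value_to_valuegrad[value][0][0]

        record_grad("v2", 1.5)  # grad produced by op2's vjp
        record_grad("v2", 2.0)  # grad produced by op3's vjp
        assert lookup_grad("v2") == 3.5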
''' - def return_value_to_copyvalue_map( - value, control_flow_value_to_copyvalue_map - ): + def return_map_value(value, map): output = value - while output in control_flow_value_to_copyvalue_map: - output = control_flow_value_to_copyvalue_map[output] + while output in map: + output = map[output] return output def append_add_n(value): @@ -451,9 +450,7 @@ def make_output_with_output_grad(op): output_grads = [] for i, value in enumerate(op.results()): new_value = [ - return_value_to_copyvalue_map( - value, control_flow_value_to_copyvalue_map - ) + return_map_value(value, control_flow_value_to_copyvalue_map) ] while value in state.inside_value_to_outside_value_map: value = state.inside_value_to_outside_value_map[value] @@ -501,33 +498,11 @@ def make_output_with_output_grad(op): outputs.append(new_value) grad_value = state.value_to_valuegrad[value][0] output_grads.append( - bwd_block_argument_to_value_map[grad_value[0]] - if grad_value[0] in bwd_block_argument_to_value_map + [bwd_value_to_block_argument_map[grad_value[0]]] + if grad_value[0] in bwd_value_to_block_argument_map else grad_value ) - if op.name() == "pd_op.while": - for i, input in enumerate(get_real_op_inputs(op)): - if i <= len(op.results()): - continue - if ( - input in state.value_to_valuegrad - and len(state.value_to_valuegrad[input]) > 1 - ): - append_add_n(input) - - if ( - input not in state.value_to_valuegrad - or state.value_to_valuegrad[input] == [] - ): - append_full_like(0.0, input, input, state, backward_ops) - - grad_value = state.value_to_valuegrad[input][0] - output_grads.append( - bwd_block_argument_to_value_map[grad_value[0]] - if grad_value[0] in bwd_block_argument_to_value_map - else grad_value - ) return zero_flag, outputs, output_grads def get_grad_semantic_info(op): @@ -560,7 +535,7 @@ def make_input_with_input_stopgradient(op): tmp_input = [] for tmp in input.get_defining_op().operands_source(): tmp_input.append( - return_value_to_copyvalue_map( + return_map_value( tmp, control_flow_value_to_copyvalue_map ) ) @@ -568,7 +543,7 @@ def make_input_with_input_stopgradient(op): inputs.append(tmp_input) else: tmp_input = [ - return_value_to_copyvalue_map( + return_map_value( input, control_flow_value_to_copyvalue_map ) ] @@ -589,9 +564,7 @@ def make_input_with_input_stopgradient(op): ) else: tmp_input = [ - return_value_to_copyvalue_map( - input, control_flow_value_to_copyvalue_map - ) + return_map_value(input, control_flow_value_to_copyvalue_map) ] inputs.append(tmp_input) @@ -602,13 +575,13 @@ def make_input_with_input_stopgradient(op): return inputs, input_grad_stopgradients - def update_input_grad_map(op, input_grads, origin_inputs, external_inputs): + def update_input_grad_map(op, input_grads, all_inputs): + _, fwd_value_to_block_argument_map = argument_to_value(op) i = 0 - for input, grad_semantic in zip( - origin_inputs, get_grad_semantic_info(op)[: len(origin_inputs) + 1] - ): + for input, grad_semantic in zip(all_inputs, get_grad_semantic_info(op)): if not grad_semantic: continue + if ( input.get_defining_op() is not None and input.get_defining_op().name() == "builtin.combine" @@ -617,13 +590,9 @@ def update_input_grad_map(op, input_grads, origin_inputs, external_inputs): input.get_defining_op(), input_grads[i], input.get_defining_op().operands_source(), - [], ) else: input_grad = input_grads[i] - if input in fwd_block_argument_to_value_map: - input = fwd_block_argument_to_value_map[input] - if isinstance(input_grad, list): state.value_to_valuegrad[input].append(input_grad) else: @@ -631,33 +600,32 @@ def 
update_input_grad_map(op, input_grads, origin_inputs, external_inputs): i += 1 def append_yield( - block, base_op, base_grad_op, base_inputs, base_inputs_grad + block, + base_op, + base_grad_op, + base_inputs, + base_inputs_grad, ): + ( + fwd_block_argument_to_value_map, + fwd_value_to_block_argument_map, + ) = argument_to_value(base_op) with block: inputs_grad = [] if base_op.name() == "pd_op.while": new_cond = paddle.base.libpaddle.pir.cf_has_elements(base_op) inputs_grad.append(new_cond) - output_grads = base_grad_op.operands_source() - # output_grad = [new_cond, loop_vars(fwd_output_grad)] - # base_inputs = [cond, loop_vars(fwd_input)] - assert len(output_grads) <= len( - base_inputs - ), "while op's inputs size should less than while_grad op's inputs size" + for idx in range(len(base_inputs[: base_op.num_operands()])): + operands = base_inputs[idx] + if operands in fwd_value_to_block_argument_map: + operands = fwd_value_to_block_argument_map[operands] + base_inputs[idx] = operands - else: - output_grads = [None] * len(base_inputs) - - for value, value_grad, output_grad in zip( - base_inputs, base_inputs_grad, output_grads - ): + for value, value_grad in zip(base_inputs, base_inputs_grad): if value_grad is None: continue - while value in state.inside_value_to_outside_value_map: - value = state.inside_value_to_outside_value_map[value] - if value in state.value_to_valuegrad: if len(state.value_to_valuegrad[value]) > 1: append_add_n(value) @@ -666,11 +634,6 @@ def append_yield( 0.0, value, value, state, backward_ops ) - # if base_op.name() == "pd_op.while": - # input_grad = paddle.add( - # output_grad, state.value_to_valuegrad[value][0][0] - # ) - # else: input_grad = state.value_to_valuegrad[value][0][0] inputs_grad.append(input_grad) @@ -678,6 +641,9 @@ def append_yield( paddle.base.libpaddle.pir.cf_yield(inputs_grad) def argument_to_value(while_op): + if while_op.name() != "pd_op.while": + return ValueDict(), ValueDict() + assert len(while_op.as_while_op().block_arguments()) + 1 == len( while_op.operands_source() ), "while op's block_arguments size + 1 should same to whiel op's operands_source" @@ -688,7 +654,7 @@ def argument_to_value(while_op): while_op.operands_source()[1:], ): arg_to_value_map[arg] = value - value_to_arg_map[value] = [arg] + value_to_arg_map[value] = arg return arg_to_value_map, value_to_arg_map # there are four patterns: @@ -701,9 +667,6 @@ def argument_to_value(while_op): # tuple_push value to pop value control_flow_value_to_copyvalue_map = ValueDict() control_flow_copyvalue_to_value_map = ValueDict() - # fwd_whileop's blockargument to fwd_whileop's input value - fwd_block_argument_to_value_map = ValueDict() - # bwd_whileop's input value to bwd_whileop's blockargument if ( len(effective_forward_ops) > 1 @@ -714,7 +677,6 @@ def argument_to_value(while_op): # while op yield [cond, loop_vars], # but outputs only has loop_vars. 
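            # A minimal sketch of the off-by-one pairing done by argument_to_value
            # above (plain-Python stand-ins, not pir Values, and pair_arguments is a
            # hypothetical helper): a while op's operands are [cond, *loop_vars],
            # while its body's block arguments correspond to loop_vars only, so
            # operand 0 is skipped when building the two maps.
            def pair_arguments(block_arguments, operands):
                assert len(block_arguments) + 1 == len(operands)
                arg_to_value = dict(zip(block_arguments, operands[1:]))
                value_to_arg = {v: a for a, v in arg_to_value.items()}
                return arg_to_value, value_to_arg

            arg2val, val2arg = pair_arguments(["arg_i", "arg_x"], ["cond", "i", "x"])
            assert arg2val == {"arg_i": "i", "arg_x": "x"}
            assert val2arg["x"] == "arg_x"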
inside_outputs = yield_op.operands_source()[1:] - fwd_block_argument_to_value_map, _ = argument_to_value(base_op) else: inside_outputs = yield_op.operands_source() @@ -734,7 +696,6 @@ def argument_to_value(while_op): if op.name() != "builtin.combine" and op.name() != "builtin.split": clear_effective_forward_ops.append(op) with bwd_block: - print([op.name() for op in clear_effective_forward_ops]) for op in clear_effective_forward_ops: if paddle.framework.core.has_vjp(op): # prepare output_grad @@ -782,8 +743,9 @@ def argument_to_value(while_op): if len(output_grads) == 0 or all(zero_flag): continue - if op.name() in ["pd_op.if", "pd_op.while"]: - origin_inputs = get_used_external_value(op) + if op.name() == "pd_op.if": + origin_inputs = get_real_op_inputs(op) + for sub_block in op.blocks(): build_pipe_for_block(sub_block) with dynamic_shape_prim_vjp_guard(op, inputs): @@ -800,16 +762,6 @@ def argument_to_value(while_op): for sub_fwd_block, sub_bwd_block in zip( op.blocks(), grad_op.blocks() ): - # update grad_op structure - if grad_op.name() == "pd_op.while": - ( - _, - sub_bwd_block_argument_to_value_map, - ) = argument_to_value(grad_op) - else: - sub_bwd_block_argument_to_value_map = ( - ValueDict() - ) sub_state = state.copy(sub_fwd_block) sub_backward_ops = [] append_backward_ops( @@ -822,12 +774,61 @@ def argument_to_value(while_op): no_grad_set, sub_backward_ops, sub_state, - sub_bwd_block_argument_to_value_map, ) # update input_grad map - update_input_grad_map( - op, input_grads, op.operands_source(), [] + update_input_grad_map(op, input_grads, origin_inputs) + elif op.name() == "pd_op.while": + origin_inputs = get_real_op_inputs(op) + # prepare while[cond, loop_vars, other_input] other_input's grad + while_block = op.as_while_op().body() + sub_state = state.copy(while_block) + for i, input in enumerate( + get_used_external_value(while_block) + ): + append_full_like( + 0.0, input, input, sub_state, backward_ops + ) + grad_value = sub_state.value_to_valuegrad[input][0] + output_grads.append( + [bwd_value_to_block_argument_map[grad_value[0]]] + if grad_value[0] + in bwd_value_to_block_argument_map + else grad_value + ) + + build_pipe_for_block(while_block) + with dynamic_shape_prim_vjp_guard(op, inputs): + input_grads = paddle.framework.core.call_vjp( + op, + inputs, + outputs, + output_grads, + input_grad_stopgradients, + ) + grad_op = bwd_block.ops[-1] + bwd_ops = [grad_op] + + # update grad_op structure + ( + _, + sub_bwd_value_to_block_argument_map, + ) = argument_to_value(grad_op) + while_grad_block = grad_op.as_while_op().body() + sub_backward_ops = [] + append_backward_ops( + op, + [input[0] for input in inputs], + [input_grad[0] for input_grad in input_grads], + while_block, + while_grad_block, + while_block.ops, + no_grad_set, + sub_backward_ops, + sub_state, + sub_bwd_value_to_block_argument_map, ) + # update input_grad map + update_input_grad_map(op, input_grads, origin_inputs) else: # create grad_op before_ops_num = len(bwd_block.ops) @@ -849,7 +850,7 @@ def argument_to_value(while_op): # update input_grad map update_input_grad_map( - op, input_grads, op.operands_source(), [] + op, input_grads, op.operands_source() ) update_bwdop_structure( @@ -984,7 +985,6 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): no_grad_set, backward_ops, state, - ValueDict(), ) # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue) outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set( diff --git 
a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index 231ac07c2e98b1..42c51a532fd012 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -152,7 +152,7 @@ def body2(i, j, ten): class TestBuildModuleWithWhile2Op(unittest.TestCase): - def test_add_n_program(self): + def test_backward(self): main_program = paddle.static.Program() with paddle.pir.core.program_guard(main_program): i = paddle.full( @@ -198,6 +198,43 @@ def test_add_n_program(self): "cf.has_elements", ) + def test_backward_with_loop_var_same_to_extral_var(self): + main_program = paddle.static.Program() + with paddle.pir.core.program_guard(main_program): + i = paddle.full(shape=[1], fill_value=0) + x = paddle.full(shape=[1], fill_value=5) + y = paddle.full(shape=[1], fill_value=10) + i.stop_gradient = False + x.stop_gradient = False + y.stop_gradient = False + new_i, new_x = paddle.static.nn.while_loop( + lambda p, q: p < q, lambda p, q: [p + y, q + x], [i, x] + ) + + out = new_i - new_x + grad_outs = grad(out, [i, x, y]) + + self.assertEqual( + grad_outs[0].get_defining_op().name(), "pd_op.while" + ) + self.assertEqual( + grad_outs[1].get_defining_op().name(), "pd_op.add_n" + ) + self.assertEqual( + grad_outs[2].get_defining_op().name(), "pd_op.while" + ) + self.assertEqual( + main_program.global_block() + .ops[-3] + .as_while_op() + .body() + .ops[-1] + .operand_source(1) + .get_defining_op() + .name(), + "pd_op.add_grad", + ) + if __name__ == "__main__": unittest.main() From adb627a10f6a29edd7023a6ef8e503f988f2d361 Mon Sep 17 00:00:00 2001 From: wangruting Date: Mon, 25 Dec 2023 02:01:31 +0000 Subject: [PATCH 20/21] modify ci --- python/paddle/autograd/ir_backward.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index a089628852ed4d..0e2f12f28c437b 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -625,6 +625,9 @@ def append_yield( if value_grad is None: continue + while value in state.inside_value_to_outside_value_map: + value = state.inside_value_to_outside_value_map[value] + if value in state.value_to_valuegrad: if len(state.value_to_valuegrad[value]) > 1: append_add_n(value) From 0a4617a51b5d04d1fc7e9919ad43ae8be8cab38b Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Mon, 25 Dec 2023 14:42:36 +0800 Subject: [PATCH 21/21] Update python/paddle/autograd/ir_backward.py --- python/paddle/autograd/ir_backward.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index 0e2f12f28c437b..a8ac124e6e2b15 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -358,7 +358,6 @@ def inverse_sort_op(ops): idx_2 = sorted_list.index(op_in) if idx_1 != idx_2: change_list.append((idx_1, idx_2)) - for idx_1, idx_2 in change_list: sorted_list[idx_1], sorted_list[idx_2] = ( sorted_list[idx_2],