diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py
index f804b3cdc91716..a8ac124e6e2b15 100644
--- a/python/paddle/autograd/ir_backward.py
+++ b/python/paddle/autograd/ir_backward.py
@@ -80,8 +80,12 @@ def append_full_like(float_value, copy_value, value, state, backward_ops):
 
 
 def get_real_op_inputs(op):
-    if op.name() in ["pd_op.if", "pd_op.while"]:
+    if op.name() == "pd_op.if":
         return get_used_external_value(op)
+    elif op.name() == "pd_op.while":
+        return op.operands_source() + get_used_external_value(
+            op.as_while_op().body()
+        )
     else:
         return op.operands_source()
 
@@ -373,7 +377,7 @@ def append_backward_ops(
     no_grad_set,
     backward_ops,
     state,
-    bwd_block_argument_to_value_map,
+    bwd_value_to_block_argument_map=ValueDict(),
 ):
     '''
     add grad_op in order of topological inverse sort
@@ -415,12 +419,10 @@ def append_backward_ops(
         else continue to next op.
     '''
 
-    def return_value_to_copyvalue_map(
-        value, control_flow_value_to_copyvalue_map
-    ):
+    def return_map_value(value, map):
         output = value
-        while output in control_flow_value_to_copyvalue_map:
-            output = control_flow_value_to_copyvalue_map[output]
+        while output in map:
+            output = map[output]
         return output
 
     def append_add_n(value):
@@ -446,9 +448,7 @@ def make_output_with_output_grad(op):
         output_grads = []
         for i, value in enumerate(op.results()):
             new_value = [
-                return_value_to_copyvalue_map(
-                    value, control_flow_value_to_copyvalue_map
-                )
+                return_map_value(value, control_flow_value_to_copyvalue_map)
             ]
             while value in state.inside_value_to_outside_value_map:
                 value = state.inside_value_to_outside_value_map[value]
@@ -496,33 +496,11 @@ def make_output_with_output_grad(op):
             outputs.append(new_value)
             grad_value = state.value_to_valuegrad[value][0]
             output_grads.append(
-                bwd_block_argument_to_value_map[grad_value[0]]
-                if grad_value[0] in bwd_block_argument_to_value_map
+                [bwd_value_to_block_argument_map[grad_value[0]]]
+                if grad_value[0] in bwd_value_to_block_argument_map
                 else grad_value
             )
-        if op.name() == "pd_op.while":
-            for i, input in enumerate(get_real_op_inputs(op)):
-                if i <= len(op.results()):
-                    continue
-                if (
-                    input in state.value_to_valuegrad
-                    and len(state.value_to_valuegrad[input]) > 1
-                ):
-                    append_add_n(input)
-
-                if (
-                    input not in state.value_to_valuegrad
-                    or state.value_to_valuegrad[input] == []
-                ):
-                    append_full_like(0.0, input, input, state, backward_ops)
-
-                grad_value = state.value_to_valuegrad[input][0]
-                output_grads.append(
-                    bwd_block_argument_to_value_map[grad_value[0]]
-                    if grad_value[0] in bwd_block_argument_to_value_map
-                    else grad_value
-                )
 
         return zero_flag, outputs, output_grads
 
     def get_grad_semantic_info(op):
@@ -555,7 +533,7 @@ def make_input_with_input_stopgradient(op):
                 tmp_input = []
                 for tmp in input.get_defining_op().operands_source():
                     tmp_input.append(
-                        return_value_to_copyvalue_map(
+                        return_map_value(
                             tmp, control_flow_value_to_copyvalue_map
                         )
                     )
@@ -563,7 +541,7 @@ def make_input_with_input_stopgradient(op):
                 inputs.append(tmp_input)
             else:
                 tmp_input = [
-                    return_value_to_copyvalue_map(
+                    return_map_value(
                         input, control_flow_value_to_copyvalue_map
                     )
                 ]
@@ -584,9 +562,7 @@ def make_input_with_input_stopgradient(op):
             )
         else:
             tmp_input = [
-                return_value_to_copyvalue_map(
-                    input, control_flow_value_to_copyvalue_map
-                )
+                return_map_value(input, control_flow_value_to_copyvalue_map)
             ]
             inputs.append(tmp_input)
 
@@ -597,13 +573,13 @@ def make_input_with_input_stopgradient(op):
 
         return inputs, input_grad_stopgradients
 
-    def update_input_grad_map(op, input_grads, origin_inputs):
+    def update_input_grad_map(op, input_grads, all_inputs):
+        _, fwd_value_to_block_argument_map = argument_to_value(op)
         i = 0
-        for input, grad_semantic in zip(
-            origin_inputs, get_grad_semantic_info(op)
-        ):
+        for input, grad_semantic in zip(all_inputs, get_grad_semantic_info(op)):
             if not grad_semantic:
                 continue
+
             if (
                 input.get_defining_op() is not None
                 and input.get_defining_op().name() == "builtin.combine"
@@ -615,9 +591,6 @@ def update_input_grad_map(op, input_grads, origin_inputs):
                 )
             else:
                 input_grad = input_grads[i]
-                if input in fwd_block_argument_to_value_map:
-                    input = fwd_block_argument_to_value_map[input]
-
             if isinstance(input_grad, list):
                 state.value_to_valuegrad[input].append(input_grad)
             else:
@@ -625,27 +598,29 @@ def update_input_grad_map(op, input_grads, origin_inputs):
             i += 1
 
     def append_yield(
-        block, base_op, base_grad_op, base_inputs, base_inputs_grad
+        block,
+        base_op,
+        base_grad_op,
+        base_inputs,
+        base_inputs_grad,
     ):
+        (
+            fwd_block_argument_to_value_map,
+            fwd_value_to_block_argument_map,
+        ) = argument_to_value(base_op)
         with block:
             inputs_grad = []
             if base_op.name() == "pd_op.while":
                 new_cond = paddle.base.libpaddle.pir.cf_has_elements(base_op)
                 inputs_grad.append(new_cond)
 
-                output_grads = base_grad_op.operands_source()
-                # output_grad = [new_cond, loop_vars(fwd_output_grad)]
-                # base_inputs = [cond, loop_vars(fwd_input)]
-                assert len(output_grads) <= len(
-                    base_inputs
-                ), "while op's inputs size should less than while_grad op's inputs size"
-
-            else:
-                output_grads = [None] * len(base_inputs)
+                for idx in range(len(base_inputs[: base_op.num_operands()])):
+                    operands = base_inputs[idx]
+                    if operands in fwd_value_to_block_argument_map:
+                        operands = fwd_value_to_block_argument_map[operands]
+                    base_inputs[idx] = operands
 
-            for value, value_grad, output_grad in zip(
-                base_inputs, base_inputs_grad, output_grads
-            ):
+            for value, value_grad in zip(base_inputs, base_inputs_grad):
                 if value_grad is None:
                     continue
 
@@ -659,12 +634,6 @@ def append_yield(
                     value_grad = append_full_like(
                         0.0, value, value, state, backward_ops
                     )
-
-                # if base_op.name() == "pd_op.while":
-                #     input_grad = paddle.add(
-                #         output_grad, state.value_to_valuegrad[value][0][0]
-                #     )
-                # else:
                 input_grad = state.value_to_valuegrad[value][0][0]
                 inputs_grad.append(input_grad)
             paddle.base.libpaddle.pir.cf_yield(inputs_grad)
@@ -672,6 +641,9 @@ def append_yield(
 
     def argument_to_value(while_op):
+        if while_op.name() != "pd_op.while":
+            return ValueDict(), ValueDict()
+
         assert len(while_op.as_while_op().block_arguments()) + 1 == len(
             while_op.operands_source()
         ), "while op's block_arguments size + 1 should same to whiel op's operands_source"
 
@@ -682,7 +654,7 @@ def argument_to_value(while_op):
             while_op.operands_source()[1:],
         ):
             arg_to_value_map[arg] = value
-            value_to_arg_map[value] = [arg]
+            value_to_arg_map[value] = arg
         return arg_to_value_map, value_to_arg_map
 
     # there are four patterns:
@@ -695,9 +667,6 @@ def argument_to_value(while_op):
     # tuple_push value to pop value
    control_flow_value_to_copyvalue_map = ValueDict()
    control_flow_copyvalue_to_value_map = ValueDict()
-    # fwd_whileop's blockargument to fwd_whileop's input value
-    fwd_block_argument_to_value_map = ValueDict()
-    # bwd_whileop's input value to bwd_whileop's blockargument
 
     if (
         len(effective_forward_ops) > 1
@@ -708,7 +677,6 @@ def argument_to_value(while_op):
             # while op yield [cond, loop_vars],
             # but outputs only has loop_vars.
            inside_outputs = yield_op.operands_source()[1:]
-            fwd_block_argument_to_value_map, _ = argument_to_value(base_op)
         else:
             inside_outputs = yield_op.operands_source()
 
@@ -776,8 +744,8 @@ def argument_to_value(while_op):
             if len(output_grads) == 0 or all(zero_flag):
                 continue
 
-            if op.name() in ["pd_op.if", "pd_op.while"]:
-                origin_inputs = get_used_external_value(op)
+            if op.name() == "pd_op.if":
+                origin_inputs = get_real_op_inputs(op)
                 for sub_block in op.blocks():
                     build_pipe_for_block(sub_block)
                 with dynamic_shape_prim_vjp_guard(op, inputs):
@@ -820,6 +788,58 @@ def argument_to_value(while_op):
                     )
                 # update input_grad map
                 update_input_grad_map(op, input_grads, origin_inputs)
+            elif op.name() == "pd_op.while":
+                origin_inputs = get_real_op_inputs(op)
+                # prepare while[cond, loop_vars, other_input] other_input's grad
+                while_block = op.as_while_op().body()
+                sub_state = state.copy(while_block)
+                for i, input in enumerate(
+                    get_used_external_value(while_block)
+                ):
+                    append_full_like(
+                        0.0, input, input, sub_state, backward_ops
+                    )
+                    grad_value = sub_state.value_to_valuegrad[input][0]
+                    output_grads.append(
+                        [bwd_value_to_block_argument_map[grad_value[0]]]
+                        if grad_value[0]
+                        in bwd_value_to_block_argument_map
+                        else grad_value
+                    )
+
+                build_pipe_for_block(while_block)
+                with dynamic_shape_prim_vjp_guard(op, inputs):
+                    input_grads = paddle.framework.core.call_vjp(
+                        op,
+                        inputs,
+                        outputs,
+                        output_grads,
+                        input_grad_stopgradients,
+                    )
+                grad_op = bwd_block.ops[-1]
+                bwd_ops = [grad_op]
+
+                # update grad_op structure
+                (
+                    _,
+                    sub_bwd_value_to_block_argument_map,
+                ) = argument_to_value(grad_op)
+                while_grad_block = grad_op.as_while_op().body()
+                sub_backward_ops = []
+                append_backward_ops(
+                    op,
+                    [input[0] for input in inputs],
+                    [input_grad[0] for input_grad in input_grads],
+                    while_block,
+                    while_grad_block,
+                    while_block.ops,
+                    no_grad_set,
+                    sub_backward_ops,
+                    sub_state,
+                    sub_bwd_value_to_block_argument_map,
+                )
+                # update input_grad map
+                update_input_grad_map(op, input_grads, origin_inputs)
             else:
                 # create grad_op
                 before_ops_num = len(bwd_block.ops)
diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py
index 9165cec5ac077e..45b68b9fcf125b 100644
--- a/test/ir/pir/test_while_api.py
+++ b/test/ir/pir/test_while_api.py
@@ -152,7 +152,7 @@ def body2(i, j, ten):
 
 
 class TestBuildModuleWithWhile2Op(unittest.TestCase):
-    def test_add_n_program(self):
+    def test_backward(self):
         main_program = paddle.static.Program()
         with paddle.pir.core.program_guard(main_program):
             i = paddle.full(
@@ -189,6 +189,43 @@ def test_add_n_program(self):
             "cf.has_elements",
         )
 
+    def test_backward_with_loop_var_same_to_extral_var(self):
+        main_program = paddle.static.Program()
+        with paddle.pir.core.program_guard(main_program):
+            i = paddle.full(shape=[1], fill_value=0)
+            x = paddle.full(shape=[1], fill_value=5)
+            y = paddle.full(shape=[1], fill_value=10)
+            i.stop_gradient = False
+            x.stop_gradient = False
+            y.stop_gradient = False
+            new_i, new_x = paddle.static.nn.while_loop(
+                lambda p, q: p < q, lambda p, q: [p + y, q + x], [i, x]
+            )
+
+            out = new_i - new_x
+            grad_outs = grad(out, [i, x, y])
+
+            self.assertEqual(
+                grad_outs[0].get_defining_op().name(), "pd_op.while"
+            )
+            self.assertEqual(
+                grad_outs[1].get_defining_op().name(), "pd_op.add_n"
+            )
+            self.assertEqual(
+                grad_outs[2].get_defining_op().name(), "pd_op.while"
+            )
+            self.assertEqual(
+                main_program.global_block()
+                .ops[-3]
+                .as_while_op()
+                .body()
+                .ops[-1]
+                .operand_source(1)
+                .get_defining_op()
+                .name(),
"pd_op.add_grad", + ) + if __name__ == "__main__": unittest.main()