Merged
33 commits
001d799
optimize backward
xiaoguoguo626807 Dec 8, 2023
05ca298
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 11, 2023
4fd113e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 12, 2023
8f60538
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 13, 2023
8854896
[PIR] add vjp interface for while op
winter-wang Dec 12, 2023
7e177f6
[PIR] fix ci error.
winter-wang Dec 13, 2023
11c8656
modify while stopgradient
xiaoguoguo626807 Dec 14, 2023
d8c3936
merge
xiaoguoguo626807 Dec 14, 2023
da62e16
merge
xiaoguoguo626807 Dec 15, 2023
67ed811
merge
xiaoguoguo626807 Dec 15, 2023
30bba32
modify while grad bug
xiaoguoguo626807 Dec 18, 2023
53f2920
merge
xiaoguoguo626807 Dec 18, 2023
fde161c
modify while grad op
xiaoguoguo626807 Dec 18, 2023
fdc12c7
modify
xiaoguoguo626807 Dec 18, 2023
e3d19b9
increment vp
xiaoguoguo626807 Dec 19, 2023
600d99c
merge
xiaoguoguo626807 Dec 20, 2023
0913436
[PIR] add get_used_external_value interface for block.
winter-wang Dec 19, 2023
63344b7
while case
xiaoguoguo626807 Dec 20, 2023
59ad2fc
delete print
xiaoguoguo626807 Dec 20, 2023
f4eceb6
delete print
xiaoguoguo626807 Dec 20, 2023
1c9eb96
Update python/paddle/autograd/ir_backward.py
xiaoguoguo626807 Dec 20, 2023
4beaa79
Merge branch 'develop' into while_2
xiaoguoguo626807 Dec 20, 2023
df0b46a
[PIR] add unit_test for get_used_external_value
winter-wang Dec 20, 2023
65083df
modify while_loop
xiaoguoguo626807 Dec 21, 2023
f2f4fa0
Merge branch 'while_2' of https://github.com/xiaoguoguo626807/Paddle …
xiaoguoguo626807 Dec 21, 2023
f8e3ac4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 21, 2023
95bc3d7
code_style
xiaoguoguo626807 Dec 21, 2023
37e807c
modify ci bug
xiaoguoguo626807 Dec 21, 2023
52afa31
Merge branch 'develop', commit 'refs/pull/60159/head' of https://gith…
xiaoguoguo626807 Dec 21, 2023
48de124
modify while api
xiaoguoguo626807 Dec 22, 2023
a7f13c9
merge
xiaoguoguo626807 Dec 25, 2023
adb627a
modify ci
xiaoguoguo626807 Dec 25, 2023
0a4617a
Update python/paddle/autograd/ir_backward.py
xiaoguoguo626807 Dec 25, 2023
165 changes: 93 additions & 72 deletions python/paddle/autograd/ir_backward.py
@@ -80,8 +80,12 @@ def append_full_like(float_value, copy_value, value, state, backward_ops):


def get_real_op_inputs(op):
if op.name() in ["pd_op.if", "pd_op.while"]:
if op.name() == "pd_op.if":
return get_used_external_value(op)
elif op.name() == "pd_op.while":
return op.operands_source() + get_used_external_value(
op.as_while_op().body()
)
else:
return op.operands_source()

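For context, the rule implemented above: an if op's "real" inputs are the outer values its branches read, while a while op needs both its explicit operands (cond plus loop vars) and any outer values its body reads. A minimal stand-in sketch (plain Python objects, not PIR ops; `ToyOp` and `toy_get_real_op_inputs` are illustrative names, not Paddle APIs):

```python
# Toy illustration of the dispatch above, with stand-in objects instead of PIR ops.
class ToyOp:
    def __init__(self, name, operands, body_external=()):
        self._name = name
        self._operands = list(operands)
        self._body_external = list(body_external)  # values read from the outer block

    def name(self):
        return self._name

    def operands_source(self):
        return self._operands

    def body_external_values(self):
        return self._body_external


def toy_get_real_op_inputs(op):
    if op.name() == "pd_op.if":
        return op.body_external_values()
    elif op.name() == "pd_op.while":
        # operands (cond + loop vars) plus values captured by the body
        return op.operands_source() + op.body_external_values()
    return op.operands_source()


while_op = ToyOp("pd_op.while", ["cond", "i", "x"], body_external=["y"])
assert toy_get_real_op_inputs(while_op) == ["cond", "i", "x", "y"]
```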
@@ -354,6 +358,7 @@ def inverse_sort_op(ops):
idx_2 = sorted_list.index(op_in)
if idx_1 != idx_2:
change_list.append((idx_1, idx_2))

for idx_1, idx_2 in change_list:
sorted_list[idx_1], sorted_list[idx_2] = (
sorted_list[idx_2],
@@ -373,7 +378,7 @@ def append_backward_ops(
no_grad_set,
backward_ops,
state,
bwd_block_argument_to_value_map,
bwd_value_to_block_argument_map=ValueDict(),
):
'''
add grad_op in order of topological inverse sort
@@ -415,12 +420,10 @@
else continue to next op.
'''

def return_value_to_copyvalue_map(
value, control_flow_value_to_copyvalue_map
):
def return_map_value(value, map):
output = value
while output in control_flow_value_to_copyvalue_map:
output = control_flow_value_to_copyvalue_map[output]
while output in map:
output = map[output]
return output

def append_add_n(value):
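The renamed return_map_value helper simply follows a chain of substitutions until the value is no longer a key of the given map; a minimal, self-contained illustration with a plain dict:

```python
# return_map_value follows the mapping chain until the value has no further entry.
def return_map_value(value, map):
    output = value
    while output in map:
        output = map[output]
    return output


copy_map = {"a": "b", "b": "c"}  # e.g. a tuple_push value mapped to its pop value
assert return_map_value("a", copy_map) == "c"  # follows a -> b -> c
assert return_map_value("x", copy_map) == "x"  # unmapped values pass through unchanged
```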
@@ -446,9 +449,7 @@ def make_output_with_output_grad(op):
output_grads = []
for i, value in enumerate(op.results()):
new_value = [
return_value_to_copyvalue_map(
value, control_flow_value_to_copyvalue_map
)
return_map_value(value, control_flow_value_to_copyvalue_map)
]
while value in state.inside_value_to_outside_value_map:
value = state.inside_value_to_outside_value_map[value]
@@ -496,33 +497,11 @@ def make_output_with_output_grad(op):
outputs.append(new_value)
grad_value = state.value_to_valuegrad[value][0]
output_grads.append(
bwd_block_argument_to_value_map[grad_value[0]]
if grad_value[0] in bwd_block_argument_to_value_map
[bwd_value_to_block_argument_map[grad_value[0]]]
if grad_value[0] in bwd_value_to_block_argument_map
else grad_value
)

if op.name() == "pd_op.while":
for i, input in enumerate(get_real_op_inputs(op)):
if i <= len(op.results()):
continue
if (
input in state.value_to_valuegrad
and len(state.value_to_valuegrad[input]) > 1
):
append_add_n(input)

if (
input not in state.value_to_valuegrad
or state.value_to_valuegrad[input] == []
):
append_full_like(0.0, input, input, state, backward_ops)

grad_value = state.value_to_valuegrad[input][0]
output_grads.append(
bwd_block_argument_to_value_map[grad_value[0]]
if grad_value[0] in bwd_block_argument_to_value_map
else grad_value
)
return zero_flag, outputs, output_grads

def get_grad_semantic_info(op):
@@ -555,15 +534,15 @@ def make_input_with_input_stopgradient(op):
tmp_input = []
for tmp in input.get_defining_op().operands_source():
tmp_input.append(
return_value_to_copyvalue_map(
return_map_value(
tmp, control_flow_value_to_copyvalue_map
)
)

inputs.append(tmp_input)
else:
tmp_input = [
return_value_to_copyvalue_map(
return_map_value(
input, control_flow_value_to_copyvalue_map
)
]
@@ -584,9 +563,7 @@ def make_input_with_input_stopgradient(op):
)
else:
tmp_input = [
return_value_to_copyvalue_map(
input, control_flow_value_to_copyvalue_map
)
return_map_value(input, control_flow_value_to_copyvalue_map)
]
inputs.append(tmp_input)

@@ -597,13 +574,13 @@ def make_input_with_input_stopgradient(op):

return inputs, input_grad_stopgradients

def update_input_grad_map(op, input_grads, origin_inputs):
def update_input_grad_map(op, input_grads, all_inputs):
_, fwd_value_to_block_argument_map = argument_to_value(op)
i = 0
for input, grad_semantic in zip(
origin_inputs, get_grad_semantic_info(op)
):
for input, grad_semantic in zip(all_inputs, get_grad_semantic_info(op)):
if not grad_semantic:
continue

if (
input.get_defining_op() is not None
and input.get_defining_op().name() == "builtin.combine"
@@ -615,37 +592,36 @@ def update_input_grad_map(op, input_grads, origin_inputs):
)
else:
input_grad = input_grads[i]
if input in fwd_block_argument_to_value_map:
input = fwd_block_argument_to_value_map[input]

if isinstance(input_grad, list):
state.value_to_valuegrad[input].append(input_grad)
else:
state.value_to_valuegrad[input].append([input_grad])
i += 1

def append_yield(
block, base_op, base_grad_op, base_inputs, base_inputs_grad
block,
base_op,
base_grad_op,
base_inputs,
base_inputs_grad,
):
(
fwd_block_argument_to_value_map,
fwd_value_to_block_argument_map,
) = argument_to_value(base_op)
with block:
inputs_grad = []
if base_op.name() == "pd_op.while":
new_cond = paddle.base.libpaddle.pir.cf_has_elements(base_op)
inputs_grad.append(new_cond)

output_grads = base_grad_op.operands_source()
# output_grad = [new_cond, loop_vars(fwd_output_grad)]
# base_inputs = [cond, loop_vars(fwd_input)]
assert len(output_grads) <= len(
base_inputs
), "while op's inputs size should less than while_grad op's inputs size"
for idx in range(len(base_inputs[: base_op.num_operands()])):
operands = base_inputs[idx]
if operands in fwd_value_to_block_argument_map:
operands = fwd_value_to_block_argument_map[operands]
base_inputs[idx] = operands

else:
output_grads = [None] * len(base_inputs)

for value, value_grad, output_grad in zip(
base_inputs, base_inputs_grad, output_grads
):
for value, value_grad in zip(base_inputs, base_inputs_grad):
if value_grad is None:
continue

@@ -659,19 +635,16 @@ def append_yield(
value_grad = append_full_like(
0.0, value, value, state, backward_ops
)

# if base_op.name() == "pd_op.while":
# input_grad = paddle.add(
# output_grad, state.value_to_valuegrad[value][0][0]
# )
# else:
input_grad = state.value_to_valuegrad[value][0][0]

inputs_grad.append(input_grad)

paddle.base.libpaddle.pir.cf_yield(inputs_grad)

def argument_to_value(while_op):
if while_op.name() != "pd_op.while":
return ValueDict(), ValueDict()

assert len(while_op.as_while_op().block_arguments()) + 1 == len(
while_op.operands_source()
), "while op's block_arguments size + 1 should same to whiel op's operands_source"
@@ -682,7 +655,7 @@ def argument_to_value(while_op):
while_op.operands_source()[1:],
):
arg_to_value_map[arg] = value
value_to_arg_map[value] = [arg]
value_to_arg_map[value] = arg
return arg_to_value_map, value_to_arg_map

# there are four patterns:
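The argument_to_value helper in the hunk above pairs each block argument of the while body with the matching forward operand; operand 0 is the loop condition, which is why the sizes differ by one and the zip starts at operands_source()[1:]. With this change the reverse map stores the argument itself instead of a single-element list. A toy version with strings standing in for pir values:

```python
# Toy version of argument_to_value: while-op block arguments correspond one-to-one
# to operands_source()[1:]; operand 0 is the loop condition and has no argument.
def toy_argument_to_value(block_arguments, operands_source):
    assert len(block_arguments) + 1 == len(operands_source)
    arg_to_value_map = {}
    value_to_arg_map = {}
    for arg, value in zip(block_arguments, operands_source[1:]):
        arg_to_value_map[arg] = value
        value_to_arg_map[value] = arg  # stored directly now, not wrapped in a list
    return arg_to_value_map, value_to_arg_map


a2v, v2a = toy_argument_to_value(["arg_i", "arg_x"], ["cond", "i", "x"])
assert a2v == {"arg_i": "i", "arg_x": "x"}
assert v2a == {"i": "arg_i", "x": "arg_x"}
```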
@@ -695,9 +668,6 @@ def argument_to_value(while_op):
# tuple_push value to pop value
control_flow_value_to_copyvalue_map = ValueDict()
control_flow_copyvalue_to_value_map = ValueDict()
# fwd_whileop's blockargument to fwd_whileop's input value
fwd_block_argument_to_value_map = ValueDict()
# bwd_whileop's input value to bwd_whileop's blockargument

if (
len(effective_forward_ops) > 1
@@ -708,7 +678,6 @@ def argument_to_value(while_op):
# while op yield [cond, loop_vars],
# but outputs only has loop_vars.
inside_outputs = yield_op.operands_source()[1:]
fwd_block_argument_to_value_map, _ = argument_to_value(base_op)
else:
inside_outputs = yield_op.operands_source()

@@ -776,8 +745,8 @@ def argument_to_value(while_op):
if len(output_grads) == 0 or all(zero_flag):
continue

if op.name() in ["pd_op.if", "pd_op.while"]:
origin_inputs = get_used_external_value(op)
if op.name() == "pd_op.if":
origin_inputs = get_real_op_inputs(op)
for sub_block in op.blocks():
build_pipe_for_block(sub_block)
with dynamic_shape_prim_vjp_guard(op, inputs):
@@ -820,6 +789,58 @@ def argument_to_value(while_op):
)
# update input_grad map
update_input_grad_map(op, input_grads, origin_inputs)
elif op.name() == "pd_op.while":
origin_inputs = get_real_op_inputs(op)
# prepare while[cond, loop_vars, other_input] other_input's grad
while_block = op.as_while_op().body()
sub_state = state.copy(while_block)
for i, input in enumerate(
get_used_external_value(while_block)
):
append_full_like(
0.0, input, input, sub_state, backward_ops
)
grad_value = sub_state.value_to_valuegrad[input][0]
output_grads.append(
[bwd_value_to_block_argument_map[grad_value[0]]]
if grad_value[0]
in bwd_value_to_block_argument_map
else grad_value
)

build_pipe_for_block(while_block)
with dynamic_shape_prim_vjp_guard(op, inputs):
input_grads = paddle.framework.core.call_vjp(
op,
inputs,
outputs,
output_grads,
input_grad_stopgradients,
)
grad_op = bwd_block.ops[-1]
bwd_ops = [grad_op]

# update grad_op structure
(
_,
sub_bwd_value_to_block_argument_map,
) = argument_to_value(grad_op)
while_grad_block = grad_op.as_while_op().body()
sub_backward_ops = []
append_backward_ops(
op,
[input[0] for input in inputs],
[input_grad[0] for input_grad in input_grads],
while_block,
while_grad_block,
while_block.ops,
no_grad_set,
sub_backward_ops,
sub_state,
sub_bwd_value_to_block_argument_map,
)
# update input_grad map
update_input_grad_map(op, input_grads, origin_inputs)
else:
# create grad_op
before_ops_num = len(bwd_block.ops)
39 changes: 38 additions & 1 deletion test/ir/pir/test_while_api.py
@@ -152,7 +152,7 @@ def body2(i, j, ten):


class TestBuildModuleWithWhile2Op(unittest.TestCase):
def test_add_n_program(self):
def test_backward(self):
main_program = paddle.static.Program()
with paddle.pir.core.program_guard(main_program):
i = paddle.full(
@@ -189,6 +189,43 @@ def test_add_n_program(self):
"cf.has_elements",
)

def test_backward_with_loop_var_same_to_extral_var(self):
main_program = paddle.static.Program()
with paddle.pir.core.program_guard(main_program):
i = paddle.full(shape=[1], fill_value=0)
x = paddle.full(shape=[1], fill_value=5)
y = paddle.full(shape=[1], fill_value=10)
i.stop_gradient = False
x.stop_gradient = False
y.stop_gradient = False
new_i, new_x = paddle.static.nn.while_loop(
lambda p, q: p < q, lambda p, q: [p + y, q + x], [i, x]
)

out = new_i - new_x
grad_outs = grad(out, [i, x, y])

self.assertEqual(
grad_outs[0].get_defining_op().name(), "pd_op.while"
)
self.assertEqual(
grad_outs[1].get_defining_op().name(), "pd_op.add_n"
)
self.assertEqual(
grad_outs[2].get_defining_op().name(), "pd_op.while"
)
self.assertEqual(
main_program.global_block()
.ops[-3]
.as_while_op()
.body()
.ops[-1]
.operand_source(1)
.get_defining_op()
.name(),
"pd_op.add_grad",
)


if __name__ == "__main__":
unittest.main()
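For anyone wanting to try the new while backward path outside the unit test, a minimal sketch that mirrors test_backward_with_loop_var_same_to_extral_var (assumptions: `grad` is the PIR backward helper the test above calls, taken here to be paddle.autograd.ir_backward.grad, and any PIR/static-mode setup done elsewhere in test_while_api.py is already in effect):

```python
import paddle
# Assumption: the same backward helper the test above calls as `grad`.
from paddle.autograd.ir_backward import grad

main_program = paddle.static.Program()
with paddle.pir.core.program_guard(main_program):
    i = paddle.full(shape=[1], fill_value=0)
    x = paddle.full(shape=[1], fill_value=5)
    y = paddle.full(shape=[1], fill_value=10)
    for v in (i, x, y):
        v.stop_gradient = False

    # y is read inside the body but is not a loop var: its gradient comes out of
    # the generated pd_op.while grad op, while x accumulates through an add_n.
    new_i, new_x = paddle.static.nn.while_loop(
        lambda p, q: p < q, lambda p, q: [p + y, q + x], [i, x]
    )
    out = new_i - new_x
    di, dx, dy = grad(out, [i, x, y])
    print(di.get_defining_op().name())  # expected: pd_op.while
    print(dx.get_defining_op().name())  # expected: pd_op.add_n
    print(dy.get_defining_op().name())  # expected: pd_op.while
```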