17 changes: 4 additions & 13 deletions python/paddle/autograd/ir_backward.py
@@ -574,7 +574,6 @@ def make_input_with_input_stopgradient(op):
return inputs, input_grad_stopgradients

def update_input_grad_map(op, input_grads, all_inputs):
_, fwd_value_to_block_argument_map = argument_to_value(op)
i = 0
for input, grad_semantic in zip(all_inputs, get_grad_semantic_info(op)):
if not grad_semantic:
@@ -631,8 +630,11 @@ def append_yield(
if len(state.value_to_valuegrad[value]) > 1:
append_add_n(value)
else:
new_value = return_map_value(
value, control_flow_value_to_copyvalue_map
)
value_grad = append_full_like(
0.0, value, value, state, backward_ops
0.0, new_value, value, state, backward_ops
)
input_grad = state.value_to_valuegrad[value][0][0]

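The hunk above stops seeding the zero gradient from the original forward `value` and instead resolves `value` through `control_flow_value_to_copyvalue_map` first, so `append_full_like` allocates the zero against the copy that actually lives inside the control-flow block. A minimal sketch of the lookup-with-fallback idiom this relies on (the chained lookup is an assumption about `return_map_value`, not a verified copy of it):

def return_map_value(value, value_map):
    # Follow the map until the value is no longer a key, so chained
    # copies (value -> copy -> copy-of-copy) resolve to the last one;
    # values with no mapping fall through unchanged.
    out = value
    while out in value_map:
        out = value_map[out]
    return out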
@@ -762,16 +764,6 @@ def argument_to_value(while_op):
for sub_fwd_block, sub_bwd_block in zip(
op.blocks(), grad_op.blocks()
):
# update grad_op structure
if grad_op.name() == "pd_op.while":
(
_,
sub_bwd_block_argument_to_value_map,
) = argument_to_value(grad_op)
else:
sub_bwd_block_argument_to_value_map = (
ValueDict()
)
sub_state = state.copy(sub_fwd_block)
sub_backward_ops = []
append_backward_ops(
@@ -784,7 +776,6 @@ def argument_to_value(while_op):
no_grad_set,
sub_backward_ops,
sub_state,
sub_bwd_block_argument_to_value_map,
)
# update input_grad map
update_input_grad_map(op, input_grads, origin_inputs)
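Both hunks above remove per-sub-block plumbing: the grad op's `sub_bwd_block_argument_to_value_map` is no longer built via `argument_to_value(grad_op)` nor threaded into the recursive `append_backward_ops` call. For orientation, `argument_to_value(while_op)` appears to return a pair of maps between a while op's block arguments and the forward values feeding them, judging from the unpacking `_, fwd_value_to_block_argument_map = argument_to_value(op)` deleted in the first hunk. A toy illustration of that pairing with plain dictionaries (names and structure are illustrative only, not Paddle's API):

# Stand-ins for a PIR while op: the body block declares one argument
# per loop-carried operand.
loop_operands = ['i0', 'x0']          # values entering the loop
block_arguments = ['arg_i', 'arg_x']  # body-block arguments

argument_to_value_map = dict(zip(block_arguments, loop_operands))
value_to_argument_map = dict(zip(loop_operands, block_arguments))

assert argument_to_value_map['arg_x'] == 'x0'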
95 changes: 47 additions & 48 deletions test/legacy_test/test_while_loop_op.py
@@ -254,66 +254,63 @@ def internal_body(j, init, sums):


class TestApiWhileLoop_Backward(unittest.TestCase):
# TODO(zhangbo): Support while grad exe for pir
# @test_with_pir_api
def test_while_loop_backward(self):
def cond(i, x):
return paddle.less_than(i, eleven)
with paddle.pir_utils.IrGuard():

def cond(i, x):
return paddle.less_than(i, eleven)

def body(i, x):
x = paddle.multiply(x=i, y=i)
i = paddle.increment(i)
return [i, x]

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
i = paddle.static.data(name='i', shape=[1], dtype='float32')
i.stop_gradient = False
i.persistable = True
eleven = paddle.tensor.fill_constant(
shape=[1], dtype='float32', value=11
)
one = paddle.tensor.fill_constant(
shape=[1], dtype='float32', value=1
)
x = paddle.static.data(name='x', shape=[1], dtype='float32')
x.stop_gradient = False
x.persistable = True

def body(i, x):
x = paddle.multiply(x=i, y=i)
i = paddle.increment(i)
return [i, x]
out = paddle.static.nn.while_loop(cond, body, [i, x])
mean = paddle.mean(out[1])
grad_list = append_backward(mean)

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
i = paddle.static.data(name='i', shape=[1], dtype='float32')
i.stop_gradient = False
i.persistable = True
eleven = paddle.tensor.fill_constant(
shape=[1], dtype='float32', value=11
)
one = paddle.tensor.fill_constant(
shape=[1], dtype='float32', value=1
place = (
base.CUDAPlace(0)
if core.is_compiled_with_cuda()
else base.CPUPlace()
)
x = paddle.static.data(name='x', shape=[1], dtype='float32')
x.stop_gradient = False
x.persistable = True

out = paddle.static.nn.while_loop(cond, body, [i, x])
mean = paddle.mean(out[1])
grad_list = append_backward(mean)
exe = base.Executor(place)

place = (
base.CUDAPlace(0)
if core.is_compiled_with_cuda()
else base.CPUPlace()
)
exe = base.Executor(place)
feed_i = np.ones(1).astype('float32')
feed_x = np.ones(1).astype('float32')
data = np.asarray([100]).astype('float32')
i_grad = np.asarray([0]).astype('float32')
x_grad = np.asarray([0]).astype('float32')

feed_i = np.ones(1).astype('float32')
feed_x = np.ones(1).astype('float32')
data = np.asarray([100]).astype('float32')
i_grad = np.asarray([110]).astype('float32')

if paddle.framework.in_pir_mode():
for p, g in grad_list:
if p == i:
if p.is_same(i):
di = g
elif p.is_same(x):
dx = g
res = exe.run(
main_program,
feed={'i': feed_i, 'x': feed_x},
fetch_list=[mean, di],
fetch_list=[mean, di, dx],
)
else:
res = exe.run(
main_program,
feed={'i': feed_i, 'x': feed_x},
fetch_list=[mean.name, i.grad_name, x.grad_name],
)
np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05)
np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05)
np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05)
np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05)
np.testing.assert_allclose(np.asarray(res[2]), x_grad, rtol=1e-05)
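One detail of the PIR branch worth calling out: `append_backward` returns `(param, grad)` pairs whose members are pir `Value`s, and the test now matches them with `Value.is_same` instead of `==` (under PIR, `==` presumably traces a compare op rather than testing identity; that rationale is our reading of the change, not stated in the PR). Condensed from the test above, assuming its `grad_list`, `i`, and `x` are in scope:

di, dx = None, None
for p, g in grad_list:
    # Identity check on pir Values; `p == i` would not work here.
    if p.is_same(i):
        di = g
    elif p.is_same(x):
        dx = g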

@test_with_pir_api
def test_while_loop_backward2(self):
@@ -356,6 +353,7 @@ def body(i, x):
fetch_list = [out[1]]
for p, g in grad_list:
fetch_list.append(g)

res = exe.run(
main_program,
feed={'i': feed_i, 'x': feed_x},
@@ -367,6 +365,7 @@ def body(i, x):
feed={'i': feed_i, 'x': feed_x},
fetch_list=[out[1].name, i.grad_name, x.grad_name],
)

np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05)
np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05)
np.testing.assert_allclose(np.asarray(res[2]), x_grad, rtol=1e-05)