Skip to content

Commit 12ef958

Browse files
xiaoguoguo626807Wanglongzhi2001
authored andcommitted
【pir】delete wrong old ir while_loop test add pir test (PaddlePaddle#60328)
* optimize backward * modfiy while_loop * delete print * modify append_full_like use copy value * clear * clear
1 parent 68fccb9 commit 12ef958

2 files changed

Lines changed: 51 additions & 61 deletions

File tree

python/paddle/autograd/ir_backward.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,6 @@ def make_input_with_input_stopgradient(op):
574574
return inputs, input_grad_stopgradients
575575

576576
def update_input_grad_map(op, input_grads, all_inputs):
577-
_, fwd_value_to_block_argument_map = argument_to_value(op)
578577
i = 0
579578
for input, grad_semantic in zip(all_inputs, get_grad_semantic_info(op)):
580579
if not grad_semantic:
@@ -631,8 +630,11 @@ def append_yield(
631630
if len(state.value_to_valuegrad[value]) > 1:
632631
append_add_n(value)
633632
else:
633+
new_value = return_map_value(
634+
value, control_flow_value_to_copyvalue_map
635+
)
634636
value_grad = append_full_like(
635-
0.0, value, value, state, backward_ops
637+
0.0, new_value, value, state, backward_ops
636638
)
637639
input_grad = state.value_to_valuegrad[value][0][0]
638640

@@ -762,16 +764,6 @@ def argument_to_value(while_op):
762764
for sub_fwd_block, sub_bwd_block in zip(
763765
op.blocks(), grad_op.blocks()
764766
):
765-
# update grad_op structure
766-
if grad_op.name() == "pd_op.while":
767-
(
768-
_,
769-
sub_bwd_block_argument_to_value_map,
770-
) = argument_to_value(grad_op)
771-
else:
772-
sub_bwd_block_argument_to_value_map = (
773-
ValueDict()
774-
)
775767
sub_state = state.copy(sub_fwd_block)
776768
sub_backward_ops = []
777769
append_backward_ops(
@@ -784,7 +776,6 @@ def argument_to_value(while_op):
784776
no_grad_set,
785777
sub_backward_ops,
786778
sub_state,
787-
sub_bwd_block_argument_to_value_map,
788779
)
789780
# update input_grad map
790781
update_input_grad_map(op, input_grads, origin_inputs)

test/legacy_test/test_while_loop_op.py

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -254,66 +254,63 @@ def internal_body(j, init, sums):
254254

255255

256256
class TestApiWhileLoop_Backward(unittest.TestCase):
257-
# TODO(zhangbo): Support while grad exe for pir
258-
# @test_with_pir_api
259257
def test_while_loop_backward(self):
260-
def cond(i, x):
261-
return paddle.less_than(i, eleven)
258+
with paddle.pir_utils.IrGuard():
259+
260+
def cond(i, x):
261+
return paddle.less_than(i, eleven)
262+
263+
def body(i, x):
264+
x = paddle.multiply(x=i, y=i)
265+
i = paddle.increment(i)
266+
return [i, x]
267+
268+
main_program = paddle.static.Program()
269+
startup_program = paddle.static.Program()
270+
with paddle.static.program_guard(main_program, startup_program):
271+
i = paddle.static.data(name='i', shape=[1], dtype='float32')
272+
i.stop_gradient = False
273+
i.persistable = True
274+
eleven = paddle.tensor.fill_constant(
275+
shape=[1], dtype='float32', value=11
276+
)
277+
one = paddle.tensor.fill_constant(
278+
shape=[1], dtype='float32', value=1
279+
)
280+
x = paddle.static.data(name='x', shape=[1], dtype='float32')
281+
x.stop_gradient = False
282+
x.persistable = True
262283

263-
def body(i, x):
264-
x = paddle.multiply(x=i, y=i)
265-
i = paddle.increment(i)
266-
return [i, x]
284+
out = paddle.static.nn.while_loop(cond, body, [i, x])
285+
mean = paddle.mean(out[1])
286+
grad_list = append_backward(mean)
267287

268-
main_program = paddle.static.Program()
269-
startup_program = paddle.static.Program()
270-
with paddle.static.program_guard(main_program, startup_program):
271-
i = paddle.static.data(name='i', shape=[1], dtype='float32')
272-
i.stop_gradient = False
273-
i.persistable = True
274-
eleven = paddle.tensor.fill_constant(
275-
shape=[1], dtype='float32', value=11
276-
)
277-
one = paddle.tensor.fill_constant(
278-
shape=[1], dtype='float32', value=1
288+
place = (
289+
base.CUDAPlace(0)
290+
if core.is_compiled_with_cuda()
291+
else base.CPUPlace()
279292
)
280-
x = paddle.static.data(name='x', shape=[1], dtype='float32')
281-
x.stop_gradient = False
282-
x.persistable = True
283-
284-
out = paddle.static.nn.while_loop(cond, body, [i, x])
285-
mean = paddle.mean(out[1])
286-
grad_list = append_backward(mean)
293+
exe = base.Executor(place)
287294

288-
place = (
289-
base.CUDAPlace(0)
290-
if core.is_compiled_with_cuda()
291-
else base.CPUPlace()
292-
)
293-
exe = base.Executor(place)
295+
feed_i = np.ones(1).astype('float32')
296+
feed_x = np.ones(1).astype('float32')
297+
data = np.asarray([100]).astype('float32')
298+
i_grad = np.asarray([0]).astype('float32')
299+
x_grad = np.asarray([0]).astype('float32')
294300

295-
feed_i = np.ones(1).astype('float32')
296-
feed_x = np.ones(1).astype('float32')
297-
data = np.asarray([100]).astype('float32')
298-
i_grad = np.asarray([110]).astype('float32')
299-
300-
if paddle.framework.in_pir_mode():
301301
for p, g in grad_list:
302-
if p == i:
302+
if p.is_same(i):
303303
di = g
304+
elif p.is_same(x):
305+
dx = g
304306
res = exe.run(
305307
main_program,
306308
feed={'i': feed_i, 'x': feed_x},
307-
fetch_list=[mean, di],
309+
fetch_list=[mean, di, dx],
308310
)
309-
else:
310-
res = exe.run(
311-
main_program,
312-
feed={'i': feed_i, 'x': feed_x},
313-
fetch_list=[mean.name, i.grad_name, x.grad_name],
314-
)
315-
np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05)
316-
np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05)
311+
np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05)
312+
np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05)
313+
np.testing.assert_allclose(np.asarray(res[2]), x_grad, rtol=1e-05)
317314

318315
@test_with_pir_api
319316
def test_while_loop_backward2(self):
@@ -356,6 +353,7 @@ def body(i, x):
356353
fetch_list = [out[1]]
357354
for p, g in grad_list:
358355
fetch_list.append(g)
356+
359357
res = exe.run(
360358
main_program,
361359
feed={'i': feed_i, 'x': feed_x},
@@ -367,6 +365,7 @@ def body(i, x):
367365
feed={'i': feed_i, 'x': feed_x},
368366
fetch_list=[out[1].name, i.grad_name, x.grad_name],
369367
)
368+
370369
np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05)
371370
np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05)
372371
np.testing.assert_allclose(np.asarray(res[2]), x_grad, rtol=1e-05)

0 commit comments

Comments
 (0)