@@ -254,66 +254,63 @@ def internal_body(j, init, sums):
254254
255255
256256class TestApiWhileLoop_Backward (unittest .TestCase ):
257- # TODO(zhangbo): Support while grad exe for pir
258- # @test_with_pir_api
259257 def test_while_loop_backward (self ):
260- def cond (i , x ):
261- return paddle .less_than (i , eleven )
258+ with paddle .pir_utils .IrGuard ():
259+
260+ def cond (i , x ):
261+ return paddle .less_than (i , eleven )
262+
263+ def body (i , x ):
264+ x = paddle .multiply (x = i , y = i )
265+ i = paddle .increment (i )
266+ return [i , x ]
267+
268+ main_program = paddle .static .Program ()
269+ startup_program = paddle .static .Program ()
270+ with paddle .static .program_guard (main_program , startup_program ):
271+ i = paddle .static .data (name = 'i' , shape = [1 ], dtype = 'float32' )
272+ i .stop_gradient = False
273+ i .persistable = True
274+ eleven = paddle .tensor .fill_constant (
275+ shape = [1 ], dtype = 'float32' , value = 11
276+ )
277+ one = paddle .tensor .fill_constant (
278+ shape = [1 ], dtype = 'float32' , value = 1
279+ )
280+ x = paddle .static .data (name = 'x' , shape = [1 ], dtype = 'float32' )
281+ x .stop_gradient = False
282+ x .persistable = True
262283
263- def body (i , x ):
264- x = paddle .multiply (x = i , y = i )
265- i = paddle .increment (i )
266- return [i , x ]
284+ out = paddle .static .nn .while_loop (cond , body , [i , x ])
285+ mean = paddle .mean (out [1 ])
286+ grad_list = append_backward (mean )
267287
268- main_program = paddle .static .Program ()
269- startup_program = paddle .static .Program ()
270- with paddle .static .program_guard (main_program , startup_program ):
271- i = paddle .static .data (name = 'i' , shape = [1 ], dtype = 'float32' )
272- i .stop_gradient = False
273- i .persistable = True
274- eleven = paddle .tensor .fill_constant (
275- shape = [1 ], dtype = 'float32' , value = 11
276- )
277- one = paddle .tensor .fill_constant (
278- shape = [1 ], dtype = 'float32' , value = 1
288+ place = (
289+ base .CUDAPlace (0 )
290+ if core .is_compiled_with_cuda ()
291+ else base .CPUPlace ()
279292 )
280- x = paddle .static .data (name = 'x' , shape = [1 ], dtype = 'float32' )
281- x .stop_gradient = False
282- x .persistable = True
283-
284- out = paddle .static .nn .while_loop (cond , body , [i , x ])
285- mean = paddle .mean (out [1 ])
286- grad_list = append_backward (mean )
293+ exe = base .Executor (place )
287294
288- place = (
289- base .CUDAPlace (0 )
290- if core .is_compiled_with_cuda ()
291- else base .CPUPlace ()
292- )
293- exe = base .Executor (place )
295+ feed_i = np .ones (1 ).astype ('float32' )
296+ feed_x = np .ones (1 ).astype ('float32' )
297+ data = np .asarray ([100 ]).astype ('float32' )
298+ i_grad = np .asarray ([0 ]).astype ('float32' )
299+ x_grad = np .asarray ([0 ]).astype ('float32' )
294300
295- feed_i = np .ones (1 ).astype ('float32' )
296- feed_x = np .ones (1 ).astype ('float32' )
297- data = np .asarray ([100 ]).astype ('float32' )
298- i_grad = np .asarray ([110 ]).astype ('float32' )
299-
300- if paddle .framework .in_pir_mode ():
301301 for p , g in grad_list :
302- if p == i :
302+ if p . is_same ( i ) :
303303 di = g
304+ elif p .is_same (x ):
305+ dx = g
304306 res = exe .run (
305307 main_program ,
306308 feed = {'i' : feed_i , 'x' : feed_x },
307- fetch_list = [mean , di ],
309+ fetch_list = [mean , di , dx ],
308310 )
309- else :
310- res = exe .run (
311- main_program ,
312- feed = {'i' : feed_i , 'x' : feed_x },
313- fetch_list = [mean .name , i .grad_name , x .grad_name ],
314- )
315- np .testing .assert_allclose (np .asarray (res [0 ]), data , rtol = 1e-05 )
316- np .testing .assert_allclose (np .asarray (res [1 ]), i_grad , rtol = 1e-05 )
311+ np .testing .assert_allclose (np .asarray (res [0 ]), data , rtol = 1e-05 )
312+ np .testing .assert_allclose (np .asarray (res [1 ]), i_grad , rtol = 1e-05 )
313+ np .testing .assert_allclose (np .asarray (res [2 ]), x_grad , rtol = 1e-05 )
317314
318315 @test_with_pir_api
319316 def test_while_loop_backward2 (self ):
@@ -356,6 +353,7 @@ def body(i, x):
356353 fetch_list = [out [1 ]]
357354 for p , g in grad_list :
358355 fetch_list .append (g )
356+
359357 res = exe .run (
360358 main_program ,
361359 feed = {'i' : feed_i , 'x' : feed_x },
@@ -367,6 +365,7 @@ def body(i, x):
367365 feed = {'i' : feed_i , 'x' : feed_x },
368366 fetch_list = [out [1 ].name , i .grad_name , x .grad_name ],
369367 )
368+
370369 np .testing .assert_allclose (np .asarray (res [0 ]), data , rtol = 1e-05 )
371370 np .testing .assert_allclose (np .asarray (res [1 ]), i_grad , rtol = 1e-05 )
372371 np .testing .assert_allclose (np .asarray (res [2 ]), x_grad , rtol = 1e-05 )
0 commit comments