diff --git a/test/legacy_test/test_eager_deletion_while_op.py b/test/legacy_test/test_eager_deletion_while_op.py index b909946d9704b7..67bb85c8963adf 100644 --- a/test/legacy_test/test_eager_deletion_while_op.py +++ b/test/legacy_test/test_eager_deletion_while_op.py @@ -22,14 +22,16 @@ import paddle from paddle import base -from paddle.base import core +from paddle.base import core, in_pir_mode from paddle.base.executor import Executor +from paddle.pir_utils import test_with_pir_api paddle.enable_static() base.core._set_eager_deletion_mode(0.0, 1.0, True) class TestEagerDeletionWhileOpBase(unittest.TestCase): + @test_with_pir_api def test_main(self): places = [ core.CPUPlace(), @@ -114,21 +116,21 @@ def run_main(self, place): sum_result.persistable = True tmp = paddle.unsqueeze(sum_result, axis=[0]) tmp = paddle.expand(tmp, [10, -1]) - fc = paddle.static.nn.fc(tmp, size=256) loss = paddle.mean(sum_result) optim = paddle.optimizer.Adam(learning_rate=1e-3) optim.minimize(loss) - gc_vars = core._get_eager_deletion_vars( - base.default_main_program().desc, [loss.name] - ) - self.assertEqual(len(gc_vars), 3) + if not in_pir_mode(): + gc_vars = core._get_eager_deletion_vars( + base.default_main_program().desc, [loss.name] + ) + self.assertEqual(len(gc_vars), 3) exe = Executor(self.place) - exe.run(base.default_startup_program()) + exe.run(paddle.static.default_startup_program()) - prog = base.default_main_program() + prog = paddle.static.default_main_program() for _ in range(5): d = []