diff --git a/test/ir/pir/cinn/symbolic/test_while_dy.py b/test/ir/pir/cinn/symbolic/test_while_dy.py index 64b4a936409b01..a8ba57ed394942 100644 --- a/test/ir/pir/cinn/symbolic/test_while_dy.py +++ b/test/ir/pir/cinn/symbolic/test_while_dy.py @@ -33,10 +33,11 @@ def __init__(self): def forward(self, x): loop_count = 0 - while x.sum() > 0 and loop_count < 1: + while loop_count < 1: y = paddle.exp(x) x = y - x loop_count += 1 + return x class TestWhile(unittest.TestCase): diff --git a/test/ir/pir/cinn/symbolic/test_while_st.py b/test/ir/pir/cinn/symbolic/test_while_st.py index 7ef7ae20ce2cd5..df9ecd9b2fccbf 100644 --- a/test/ir/pir/cinn/symbolic/test_while_st.py +++ b/test/ir/pir/cinn/symbolic/test_while_st.py @@ -33,11 +33,13 @@ def __init__(self): def forward(self, x): loop_count = 0 - while x.sum() > 0 and loop_count < 1: + while loop_count < 1: y = paddle.exp(x) x = y - x loop_count += 1 + return x + class TestWhile(unittest.TestCase): def setUp(self):