From 0e771c519996f67ce3e18f16192855e598717720 Mon Sep 17 00:00:00 2001 From: 6clc Date: Tue, 20 Feb 2024 19:23:08 +0800 Subject: [PATCH 1/2] cinn(test): fix test_while_st.py --- test/ir/pir/cinn/symbolic/test_while_st.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/ir/pir/cinn/symbolic/test_while_st.py b/test/ir/pir/cinn/symbolic/test_while_st.py index 7ef7ae20ce2cd5..df9ecd9b2fccbf 100644 --- a/test/ir/pir/cinn/symbolic/test_while_st.py +++ b/test/ir/pir/cinn/symbolic/test_while_st.py @@ -33,11 +33,13 @@ def __init__(self): def forward(self, x): loop_count = 0 - while x.sum() > 0 and loop_count < 1: + while loop_count < 1: y = paddle.exp(x) x = y - x loop_count += 1 + return x + class TestWhile(unittest.TestCase): def setUp(self): From d7fb121300b1b3f02c75b3d0ae68ea7e1f5281d5 Mon Sep 17 00:00:00 2001 From: 6clc Date: Tue, 20 Feb 2024 19:53:28 +0800 Subject: [PATCH 2/2] cinn(test): fix test_while_dy.py --- test/ir/pir/cinn/symbolic/test_while_dy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/ir/pir/cinn/symbolic/test_while_dy.py b/test/ir/pir/cinn/symbolic/test_while_dy.py index 64b4a936409b01..a8ba57ed394942 100644 --- a/test/ir/pir/cinn/symbolic/test_while_dy.py +++ b/test/ir/pir/cinn/symbolic/test_while_dy.py @@ -33,10 +33,11 @@ def __init__(self): def forward(self, x): loop_count = 0 - while x.sum() > 0 and loop_count < 1: + while loop_count < 1: y = paddle.exp(x) x = y - x loop_count += 1 + return x class TestWhile(unittest.TestCase):