diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 85d1e1f8541c7e..7f28a2f008b934 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5657,7 +5657,7 @@ def _gcd_body_fn(x, y): ) return (paddle.where(x < y, y, x), paddle.where(x < y, x, y)) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): while _gcd_cond_fn(x, y): x, y = _gcd_body_fn(x, y) @@ -5707,7 +5707,7 @@ def _gcd_body_fn(x, y): paddle.where_(x >= y, x, y), ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): while _gcd_cond_fn(x, y): y, x = _gcd_body_fn(x, y) diff --git a/test/legacy_test/test_gcd.py b/test/legacy_test/test_gcd.py index a7ec34eca42c7c..58a2c8689aafb0 100644 --- a/test/legacy_test/test_gcd.py +++ b/test/legacy_test/test_gcd.py @@ -19,6 +19,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -30,6 +31,7 @@ def setUp(self): self.x_shape = [1] self.y_shape = [1] + @test_with_pir_api def test_static_graph(self): startup_program = base.Program() train_program = base.Program()