diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 0cfe560fdc0687..29fac939ccbec8 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -4996,10 +4996,10 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
         )
     axis = non_negative_axis(arr, axis)
     broadcast_shape = infer_broadcast_shape(arr, indices, axis)
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         values = (
             paddle.to_tensor(values)
-            if not isinstance(values, paddle.Tensor)
+            if not isinstance(values, (paddle.Tensor, paddle.pir.OpResult))
             else values
         )
         if broadcast_shape:
diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py
index b8027137069ed0..83194145bb18e7 100644
--- a/test/legacy_test/test_put_along_axis_op.py
+++ b/test/legacy_test/test_put_along_axis_op.py
@@ -20,6 +20,7 @@
 
 import paddle
 from paddle.framework import core
+from paddle.pir_utils import test_with_pir_api
 
 paddle.enable_static()
@@ -49,10 +50,10 @@ def setUp(self):
         self.outputs = {'Result': self.target}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad(["Input", "Value"], "Result")
+        self.check_grad(["Input", "Value"], "Result", check_pir=True)
 
     def init_data(self):
         self.dtype = 'float64'
@@ -114,10 +115,12 @@ def setUp(self):
             self.place = core.CUDAPlace(0)
 
     def test_check_output(self):
-        self.check_output_with_place(self.place)
+        self.check_output_with_place(self.place, check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad_with_place(self.place, ["Input", "Value"], "Result")
+        self.check_grad_with_place(
+            self.place, ["Input", "Value"], "Result", check_pir=True
+        )
 
     def init_data(self):
         self.dtype = np.uint16
@@ -146,6 +149,7 @@ def setUp(self):
         if core.is_compiled_with_cuda():
             self.place.append(paddle.CUDAPlace(0))
 
+    @test_with_pir_api
     def test_api_static(self):
         paddle.enable_static()