diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 5a8b433cea52ef..5cde0289541730 100755 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -20,7 +20,7 @@ from ..base import core from ..base.data_feeder import check_type, check_variable_and_dtype -from ..base.framework import Variable, in_dygraph_mode +from ..base.framework import Variable, in_dygraph_mode, in_dynamic_or_pir_mode from ..base.layer_helper import LayerHelper from ..framework import _current_expected_place from ..nn import BatchNorm2D, Conv2D, Layer, ReLU, Sequential @@ -1547,7 +1547,7 @@ def roi_pool(x, boxes, boxes_num, output_size, spatial_scale=1.0, name=None): output_size = (output_size, output_size) pooled_height, pooled_width = output_size - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): assert ( boxes_num is not None ), "boxes_num should not be None in dygraph mode." @@ -1707,7 +1707,7 @@ def roi_align( output_size = (output_size, output_size) pooled_height, pooled_width = output_size - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): assert ( boxes_num is not None ), "boxes_num should not be None in dygraph mode." diff --git a/test/legacy_test/test_ops_roi_align.py b/test/legacy_test/test_ops_roi_align.py index 61d06b9ebb7170..7432116d95d83b 100644 --- a/test/legacy_test/test_ops_roi_align.py +++ b/test/legacy_test/test_ops_roi_align.py @@ -17,6 +17,7 @@ import numpy as np import paddle +from paddle.pir_utils import test_with_pir_api from paddle.vision.ops import RoIAlign, roi_align @@ -81,11 +82,13 @@ def test_roi_align_functional_dynamic(self): self.roi_align_functional(3) self.roi_align_functional(output_size=(3, 4)) + @test_with_pir_api def test_roi_align_functional_static(self): paddle.enable_static() self.roi_align_functional(3) paddle.disable_static() + @test_with_pir_api def test_RoIAlign(self): roi_align_c = RoIAlign(output_size=(4, 3)) data = paddle.to_tensor(self.data) diff --git a/test/legacy_test/test_ops_roi_pool.py b/test/legacy_test/test_ops_roi_pool.py index cad43d2c4ff6a8..7460b15c7ea432 100644 --- a/test/legacy_test/test_ops_roi_pool.py +++ b/test/legacy_test/test_ops_roi_pool.py @@ -17,6 +17,7 @@ import numpy as np import paddle +from paddle.pir_utils import test_with_pir_api from paddle.vision.ops import RoIPool, roi_pool @@ -81,6 +82,7 @@ def test_roi_pool_functional_dynamic(self): self.roi_pool_functional(3) self.roi_pool_functional(output_size=(3, 4)) + @test_with_pir_api def test_roi_pool_functional_static(self): paddle.enable_static() self.roi_pool_functional(3)