Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/paddle/vision/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np

from paddle import _C_ops, _legacy_C_ops
from paddle.pir import OpResult
from paddle.tensor.math import _add_with_axis
from paddle.utils import convert_to_list

Expand Down Expand Up @@ -674,8 +675,8 @@ def box_coder(
... box_normalized=False)
...
"""
if in_dygraph_mode():
if isinstance(prior_box_var, core.eager.Tensor):
if in_dynamic_or_pir_mode():
if isinstance(prior_box_var, (core.eager.Tensor, OpResult)):
output_box = _C_ops.box_coder(
prior_box,
prior_box_var,
Expand Down
11 changes: 7 additions & 4 deletions test/legacy_test/test_box_coder_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from op_test import OpTest

import paddle
from paddle.pir_utils import test_with_pir_api


def box_decoder(t_box, p_box, pb_v, output_box, norm, axis=0):
Expand Down Expand Up @@ -109,7 +110,7 @@ def batch_box_coder(p_box, pb_v, t_box, lod, code_type, norm, axis=0):

class TestBoxCoderOp(OpTest):
def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def setUp(self):
self.op_type = "box_coder"
Expand Down Expand Up @@ -142,7 +143,7 @@ def setUp(self):

class TestBoxCoderOpWithoutBoxVar(OpTest):
def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def setUp(self):
self.python_api = paddle.vision.ops.box_coder
Expand Down Expand Up @@ -207,7 +208,7 @@ def setUp(self):

class TestBoxCoderOpWithAxis(OpTest):
def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def setUp(self):
self.python_api = paddle.vision.ops.box_coder
Expand Down Expand Up @@ -286,7 +287,7 @@ def wrapper_box_coder(

class TestBoxCoderOpWithVariance(OpTest):
def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def setUp(self):
self.op_type = "box_coder"
Expand Down Expand Up @@ -370,6 +371,7 @@ def setUp(self):
self.prior_box_var_np = np.random.random((80, 4)).astype('float32')
self.target_box_np = np.random.random((20, 80, 4)).astype('float32')

@test_with_pir_api
def test_dygraph_with_static(self):
paddle.enable_static()
exe = paddle.static.Executor()
Expand Down Expand Up @@ -428,6 +430,7 @@ def setUp(self):
self.prior_box_np = np.random.random((80, 4)).astype('float32')
self.target_box_np = np.random.random((20, 80, 4)).astype('float32')

@test_with_pir_api
def test_support_tuple(self):
paddle.enable_static()
exe = paddle.static.Executor()
Expand Down