
Commit 8faa1b1

NicolasHug and pmeier authored
Simplify query_bounding_boxes logic (#7786)
Co-authored-by: Philip Meier <[email protected]>
1 parent 9b82df4 commit 8faa1b1

File tree

11 files changed: +67 -97 lines changed


test/common_utils.py

Lines changed: 3 additions & 3 deletions
@@ -691,7 +691,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORT
     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]
 
-    spatial_size = _parse_size(spatial_size, name="canvas_size")
+    spatial_size = _parse_size(spatial_size, name="spatial_size")
 
     def fn(shape, dtype, device):
         *batch_dims, num_coordinates = shape
@@ -702,12 +702,12 @@ def fn(shape, dtype, device):
             format=format, canvas_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
         )
 
-    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
+    return BoundingBoxesLoader(fn, shape=(*extra_dims[-1:], 4), dtype=dtype, format=format, spatial_size=spatial_size)
 
 
 def make_bounding_box_loaders(
     *,
-    extra_dims=DEFAULT_EXTRA_DIMS,
+    extra_dims=tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2),
     formats=tuple(datapoints.BoundingBoxFormat),
     spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
     dtypes=(torch.float32, torch.float64, torch.int64),

test/test_datapoints.py

Lines changed: 7 additions & 1 deletion
@@ -22,7 +22,7 @@ def test_mask_instance(data):
     assert mask.ndim == 3 and mask.shape[0] == 1
 
 
-@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]])
+@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
 @pytest.mark.parametrize(
     "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
 )
@@ -35,6 +35,12 @@ def test_bbox_instance(data, format):
     assert bboxes.format == format
 
 
+def test_bbox_dim_error():
+    data_3d = [[[1, 2, 3, 4]]]
+    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
+        datapoints.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))
+
+
 @pytest.mark.parametrize(
     ("data", "input_requires_grad", "expected_requires_grad"),
     [

test/test_prototype_transforms.py

Lines changed: 4 additions & 2 deletions
@@ -20,7 +20,7 @@
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
 from torchvision.transforms.v2._utils import _convert_fill_arg
-from torchvision.transforms.v2.functional import InterpolationMode, pil_to_tensor, to_image_pil
+from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor
 
 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -306,7 +306,9 @@ def test__transform_bounding_boxes_clamping(self, mocker):
         bounding_boxes = make_bounding_box(
             format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
         )
-        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes")
+        mock = mocker.patch(
+            "torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes", wraps=clamp_bounding_boxes
+        )
 
         transform = transforms.FixedSizeCrop((-1, -1))
         mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)

test/test_transforms_v2.py

Lines changed: 0 additions & 12 deletions
@@ -1654,18 +1654,6 @@ def test_sanitize_bounding_boxes_errors():
         different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
         transforms.SanitizeBoundingBoxes()(different_sizes)
 
-    with pytest.raises(ValueError, match="boxes must be of shape"):
-        bad_bbox = datapoints.BoundingBoxes(  # batch with 2 elements
-            [
-                [[0, 0, 10, 10]],
-                [[0, 0, 10, 10]],
-            ],
-            format=datapoints.BoundingBoxFormat.XYXY,
-            canvas_size=(20, 20),
-        )
-        different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
-        transforms.SanitizeBoundingBoxes()(different_sizes)
-
 
 @pytest.mark.parametrize(
     "import_statement",

test/test_transforms_v2_functional.py

Lines changed: 28 additions & 50 deletions
@@ -711,21 +711,20 @@ def _parse_padding(padding):
 @pytest.mark.parametrize("device", cpu_and_cuda())
 @pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
 def test_correctness_pad_bounding_boxes(device, padding):
-    def _compute_expected_bbox(bbox, padding_):
+    def _compute_expected_bbox(bbox, format, padding_):
         pad_left, pad_up, _, _ = _parse_padding(padding_)
 
         dtype = bbox.dtype
-        format = bbox.format
         bbox = (
             bbox.clone()
             if format == datapoints.BoundingBoxFormat.XYXY
-            else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+            else convert_format_bounding_boxes(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
         )
 
         bbox[0::2] += pad_left
         bbox[1::2] += pad_up
 
-        bbox = convert_format_bounding_boxes(bbox, new_format=format)
+        bbox = convert_format_bounding_boxes(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
         if bbox.dtype != dtype:
             # Temporary cast to original dtype
             # e.g. float32 -> int
@@ -737,7 +736,7 @@ def _compute_expected_canvas_size(bbox, padding_):
         height, width = bbox.canvas_size
         return height + pad_up + pad_down, width + pad_left + pad_right
 
-    for bboxes in make_bounding_boxes():
+    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
         bboxes = bboxes.to(device)
         bboxes_format = bboxes.format
         bboxes_canvas_size = bboxes.canvas_size
@@ -748,18 +747,10 @@ def _compute_expected_canvas_size(bbox, padding_):
 
         torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))
 
-        if bboxes.ndim < 2 or bboxes.shape[0] == 0:
-            bboxes = [bboxes]
-
-        expected_bboxes = []
-        for bbox in bboxes:
-            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
-            expected_bboxes.append(_compute_expected_bbox(bbox, padding))
+        expected_bboxes = torch.stack(
+            [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()]
+        ).reshape(bboxes.shape)
 
-        if len(expected_bboxes) > 1:
-            expected_bboxes = torch.stack(expected_bboxes)
-        else:
-            expected_bboxes = expected_bboxes[0]
         torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
 
 
@@ -784,7 +775,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
     ],
 )
 def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
-    def _compute_expected_bbox(bbox, pcoeffs_):
+    def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_):
         m1 = np.array(
             [
                 [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
@@ -798,7 +789,9 @@ def _compute_expected_bbox(bbox, pcoeffs_):
             ]
         )
 
-        bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+        bbox_xyxy = convert_format_bounding_boxes(
+            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY
+        )
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -818,14 +811,11 @@ def _compute_expected_bbox(bbox, pcoeffs_):
                 np.max(transformed_points[:, 1]),
             ]
         )
-        out_bbox = datapoints.BoundingBoxes(
-            out_bbox,
-            format=datapoints.BoundingBoxFormat.XYXY,
-            canvas_size=bbox.canvas_size,
-            dtype=bbox.dtype,
-            device=bbox.device,
+        out_bbox = torch.from_numpy(out_bbox)
+        out_bbox = convert_format_bounding_boxes(
+            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
         )
-        return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format))
+        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)
 
     canvas_size = (32, 38)
 
@@ -844,17 +834,13 @@ def _compute_expected_bbox(bbox, pcoeffs_):
         coefficients=pcoeffs,
     )
 
-    if bboxes.ndim < 2:
-        bboxes = [bboxes]
+    expected_bboxes = torch.stack(
+        [
+            _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs)
+            for b in bboxes.reshape(-1, 4).unbind()
+        ]
+    ).reshape(bboxes.shape)
 
-    expected_bboxes = []
-    for bbox in bboxes:
-        bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, canvas_size=bboxes.canvas_size)
-        expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
     torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)
 
 
@@ -864,9 +850,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
     [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
 )
 def test_correctness_center_crop_bounding_boxes(device, output_size):
-    def _compute_expected_bbox(bbox, output_size_):
-        format_ = bbox.format
-        canvas_size_ = bbox.canvas_size
+    def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
         dtype = bbox.dtype
         bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
 
@@ -895,18 +879,12 @@ def _compute_expected_bbox(bbox, output_size_):
         bboxes, bboxes_format, bboxes_canvas_size, output_size
     )
 
-    if bboxes.ndim < 2:
-        bboxes = [bboxes]
-
-    expected_bboxes = []
-    for bbox in bboxes:
-        bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
-        expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
-
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
+    expected_bboxes = torch.stack(
+        [
+            _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size)
+            for b in bboxes.reshape(-1, 4).unbind()
+        ]
+    ).reshape(bboxes.shape)
 
     torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
     torch.testing.assert_close(output_canvas_size, output_size)

test/transforms_v2_kernel_infos.py

Lines changed: 3 additions & 10 deletions
@@ -222,16 +222,9 @@ def transform(bbox, affine_matrix_, format_, canvas_size_):
         out_bbox = out_bbox.to(dtype=in_dtype)
         return out_bbox
 
-    if bounding_boxes.ndim < 2:
-        bounding_boxes = [bounding_boxes]
-
-    expected_bboxes = [transform(bbox, affine_matrix, format, canvas_size) for bbox in bounding_boxes]
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
-
-    return expected_bboxes
+    return torch.stack(
+        [transform(b, affine_matrix, format, canvas_size) for b in bounding_boxes.reshape(-1, 4).unbind()]
+    ).reshape(bounding_boxes.shape)
 
 
 def sample_inputs_convert_format_bounding_boxes():
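
The reference helpers in the test suite all converge on the same idiom: since a bounding box tensor is now guaranteed to be 2D of shape (N, 4), they flatten it with reshape(-1, 4).unbind(), map a per-box function, and stack the results back into the original shape, which replaces the old ndim < 2 branching. A minimal standalone sketch of that pattern (the helper name apply_per_box and the shift function are illustrative only, not part of this commit):

import torch

def apply_per_box(boxes, fn):
    # Flatten to (N, 4), apply ``fn`` to every individual box, then restore the input shape.
    return torch.stack([fn(b) for b in boxes.reshape(-1, 4).unbind()]).reshape(boxes.shape)

boxes = torch.tensor([[0, 0, 10, 10], [2, 2, 7, 7]])
shifted = apply_per_box(boxes, lambda b: b + 1)  # shift each box by one pixel
print(shifted.shape)  # torch.Size([2, 4])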

torchvision/datapoints/_bounding_box.py

Lines changed: 10 additions & 0 deletions
@@ -26,6 +26,12 @@ class BoundingBoxFormat(Enum):
 class BoundingBoxes(Datapoint):
     """[BETA] :class:`torch.Tensor` subclass for bounding boxes.
 
+    .. note::
+        There should be only one :class:`~torchvision.datapoints.BoundingBoxes`
+        instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
+        although one :class:`~torchvision.datapoints.BoundingBoxes` object can
+        contain multiple bounding boxes.
+
     Args:
         data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
         format (BoundingBoxFormat, str): Format of the bounding box.
@@ -43,6 +49,10 @@ class BoundingBoxes(Datapoint):
 
     @classmethod
     def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes:  # type: ignore[override]
+        if tensor.ndim == 1:
+            tensor = tensor.unsqueeze(0)
+        elif tensor.ndim != 2:
+            raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
         if isinstance(format, str):
             format = BoundingBoxFormat[format.upper()]
         bounding_boxes = tensor.as_subclass(cls)
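
The docstring note and the _wrap change above pin down the new contract: a BoundingBoxes datapoint is always a 2D (N, 4) tensor, a 1D input is promoted to a batch of one, and anything with more dimensions is rejected. A small usage sketch against the beta datapoints API at this commit, mirroring the new test in test/test_datapoints.py:

import torch
from torchvision import datapoints

# A single 1D box is unsqueezed to shape (1, 4).
single = datapoints.BoundingBoxes([1, 2, 3, 4], format="XYXY", canvas_size=(32, 32))
print(single.shape)  # torch.Size([1, 4])

# A 3D "batch of batches" is no longer allowed.
try:
    datapoints.BoundingBoxes([[[1, 2, 3, 4]]], format="XYXY", canvas_size=(32, 32))
except ValueError as e:
    print(e)  # Expected a 1D or 2D tensor, got 3D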

torchvision/prototype/transforms/_geometry.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
 from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size
+from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size
 
 
 class FixedSizeCrop(Transform):
@@ -61,7 +61,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         bounding_boxes: Optional[torch.Tensor]
         try:
-            bounding_boxes = query_bounding_boxes(flat_inputs)
+            bounding_boxes = get_bounding_boxes(flat_inputs)
         except ValueError:
             bounding_boxes = None
 
torchvision/transforms/v2/_geometry.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
     _setup_float_or_seq,
     _setup_size,
 )
-from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size
+from .utils import get_bounding_boxes, has_all, has_any, is_simple_tensor, query_size
 
 
 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -1137,7 +1137,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         orig_h, orig_w = query_size(flat_inputs)
-        bboxes = query_bounding_boxes(flat_inputs)
+        bboxes = get_bounding_boxes(flat_inputs)
 
         while True:
             # sample an option

torchvision/transforms/v2/_misc.py

Lines changed: 2 additions & 8 deletions
@@ -10,7 +10,7 @@
 from torchvision.transforms.v2 import functional as F, Transform
 
 from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
-from .utils import has_any, is_simple_tensor, query_bounding_boxes
+from .utils import get_bounding_boxes, has_any, is_simple_tensor
 
 
 # TODO: do we want/need to expose this?
@@ -384,13 +384,7 @@ def forward(self, *inputs: Any) -> Any:
         )
 
         flat_inputs, spec = tree_flatten(inputs)
-        # TODO: this enforces one single BoundingBoxes entry.
-        # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
-        # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
-        boxes = query_bounding_boxes(flat_inputs)
-
-        if boxes.ndim != 2:
-            raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
+        boxes = get_bounding_boxes(flat_inputs)
 
         if labels is not None and boxes.shape[0] != labels.shape[0]:
            raise ValueError(
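
With the 2D guarantee in place, SanitizeBoundingBoxes.forward no longer needs its explicit shape check and simply pulls the single BoundingBoxes entry out of the flattened sample via get_bounding_boxes. A hedged usage sketch against the v2 transforms API at this commit (key names and box values are illustrative; by default the transform drops degenerate boxes and the matching "labels" entries):

import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms

boxes = datapoints.BoundingBoxes(
    [[0, 0, 10, 10], [5, 5, 5, 5]],  # the second box has zero area and should be removed
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(20, 20),
)
sample = {"bbox": boxes, "labels": torch.tensor([1, 2])}
out = transforms.SanitizeBoundingBoxes()(sample)
print(out["bbox"].shape, out["labels"])  # expect one surviving box and its label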
