6 changes: 3 additions & 3 deletions test/common_utils.py
@@ -691,7 +691,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]

-    spatial_size = _parse_size(spatial_size, name="canvas_size")
+    spatial_size = _parse_size(spatial_size, name="spatial_size")
Contributor: Driveby


def fn(shape, dtype, device):
*batch_dims, num_coordinates = shape
@@ -702,12 +702,12 @@ def fn(shape, dtype, device):
format=format, canvas_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
)

-    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
+    return BoundingBoxesLoader(fn, shape=(*extra_dims[-1:], 4), dtype=dtype, format=format, spatial_size=spatial_size)


def make_bounding_box_loaders(
*,
-    extra_dims=DEFAULT_EXTRA_DIMS,
+    extra_dims=tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2),
formats=tuple(datapoints.BoundingBoxFormat),
spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
dtypes=(torch.float32, torch.float64, torch.int64),
8 changes: 7 additions & 1 deletion test/test_datapoints.py
@@ -22,7 +22,7 @@ def test_mask_instance(data):
assert mask.ndim == 3 and mask.shape[0] == 1


@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]])
@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize(
"format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
)
@@ -35,6 +35,12 @@ def test_bbox_instance(data, format):
assert bboxes.format == format


+def test_bbox_dim_error():
+    data_3d = [[[1, 2, 3, 4]]]
+    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
+        datapoints.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))


@pytest.mark.parametrize(
("data", "input_requires_grad", "expected_requires_grad"),
[
6 changes: 4 additions & 2 deletions test/test_prototype_transforms.py
@@ -20,7 +20,7 @@
from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from torchvision.prototype import datapoints, transforms
from torchvision.transforms.v2._utils import _convert_fill_arg
-from torchvision.transforms.v2.functional import InterpolationMode, pil_to_tensor, to_image_pil
+from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
from torchvision.transforms.v2.utils import check_type, is_simple_tensor

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -306,7 +306,9 @@ def test__transform_bounding_boxes_clamping(self, mocker):
bounding_boxes = make_bounding_box(
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
)
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes")
mock = mocker.patch(
"torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes", wraps=clamp_bounding_boxes
)

transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
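The `wraps=` change above is worth calling out: patching with `wraps=` keeps the real kernel running while the mock records the calls, so the test can assert on call arguments without disabling the actual clamping. A minimal standalone sketch of the same `unittest.mock` pattern (the `clamp` helper here is made up for illustration):

```python
from unittest import mock

def clamp(x, lo, hi):
    # Stand-in for the real function being patched.
    return max(lo, min(x, hi))

# Patch the module-level name, but delegate every call to the real function.
with mock.patch(f"{__name__}.clamp", wraps=clamp) as spy:
    result = clamp(15, 0, 10)

assert result == 10  # the real implementation still ran
spy.assert_called_once_with(15, 0, 10)  # and the call was recorded
```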
12 changes: 0 additions & 12 deletions test/test_transforms_v2.py
@@ -1654,18 +1654,6 @@ def test_sanitize_bounding_boxes_errors():
different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
transforms.SanitizeBoundingBoxes()(different_sizes)

with pytest.raises(ValueError, match="boxes must be of shape"):
bad_bbox = datapoints.BoundingBoxes( # batch with 2 elements
[
[[0, 0, 10, 10]],
[[0, 0, 10, 10]],
],
format=datapoints.BoundingBoxFormat.XYXY,
canvas_size=(20, 20),
)
different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
transforms.SanitizeBoundingBoxes()(different_sizes)


@pytest.mark.parametrize(
"import_statement",
78 changes: 28 additions & 50 deletions test/test_transforms_v2_functional.py
@@ -711,21 +711,20 @@ def _parse_padding(padding):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
def test_correctness_pad_bounding_boxes(device, padding):
-    def _compute_expected_bbox(bbox, padding_):
+    def _compute_expected_bbox(bbox, format, padding_):
pad_left, pad_up, _, _ = _parse_padding(padding_)

dtype = bbox.dtype
-        format = bbox.format
bbox = (
bbox.clone()
if format == datapoints.BoundingBoxFormat.XYXY
-            else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+            else convert_format_bounding_boxes(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
)

bbox[0::2] += pad_left
bbox[1::2] += pad_up

-        bbox = convert_format_bounding_boxes(bbox, new_format=format)
+        bbox = convert_format_bounding_boxes(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
if bbox.dtype != dtype:
# Temporary cast to original dtype
# e.g. float32 -> int
@@ -737,7 +736,7 @@ def _compute_expected_canvas_size(bbox, padding_):
height, width = bbox.canvas_size
return height + pad_up + pad_down, width + pad_left + pad_right

-    for bboxes in make_bounding_boxes():
+    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
bboxes = bboxes.to(device)
bboxes_format = bboxes.format
bboxes_canvas_size = bboxes.canvas_size
@@ -748,18 +747,10 @@ def _compute_expected_canvas_size(bbox, padding_):

torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

-        if bboxes.ndim < 2 or bboxes.shape[0] == 0:
-            bboxes = [bboxes]

-        expected_bboxes = []
-        for bbox in bboxes:
-            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
-            expected_bboxes.append(_compute_expected_bbox(bbox, padding))
+        expected_bboxes = torch.stack(
+            [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()]
+        ).reshape(bboxes.shape)

-        if len(expected_bboxes) > 1:
-            expected_bboxes = torch.stack(expected_bboxes)
-        else:
-            expected_bboxes = expected_bboxes[0]
torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)


@@ -784,7 +775,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
],
)
def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
-    def _compute_expected_bbox(bbox, pcoeffs_):
+    def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_):
m1 = np.array(
[
[pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
Expand All @@ -798,7 +789,9 @@ def _compute_expected_bbox(bbox, pcoeffs_):
]
)

-        bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+        bbox_xyxy = convert_format_bounding_boxes(
+            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY
+        )
points = np.array(
[
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -818,14 +811,11 @@ def _compute_expected_bbox(bbox, pcoeffs_):
np.max(transformed_points[:, 1]),
]
)
-        out_bbox = datapoints.BoundingBoxes(
-            out_bbox,
-            format=datapoints.BoundingBoxFormat.XYXY,
-            canvas_size=bbox.canvas_size,
-            dtype=bbox.dtype,
-            device=bbox.device,
+        out_bbox = torch.from_numpy(out_bbox)
+        out_bbox = convert_format_bounding_boxes(
+            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
        )
-        return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format))
+        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)

canvas_size = (32, 38)

@@ -844,17 +834,13 @@ def _compute_expected_bbox(bbox, pcoeffs_):
coefficients=pcoeffs,
)

-    if bboxes.ndim < 2:
-        bboxes = [bboxes]
+    expected_bboxes = torch.stack(
+        [
+            _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs)
+            for b in bboxes.reshape(-1, 4).unbind()
+        ]
+    ).reshape(bboxes.shape)

-    expected_bboxes = []
-    for bbox in bboxes:
-        bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, canvas_size=bboxes.canvas_size)
-        expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)


@@ -864,9 +850,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
[(18, 18), [18, 15], (16, 19), [12], [46, 48]],
)
def test_correctness_center_crop_bounding_boxes(device, output_size):
-    def _compute_expected_bbox(bbox, output_size_):
-        format_ = bbox.format
-        canvas_size_ = bbox.canvas_size
+    def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
dtype = bbox.dtype
bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)

@@ -895,18 +879,12 @@ def _compute_expected_bbox(bbox, output_size_):
bboxes, bboxes_format, bboxes_canvas_size, output_size
)

-    if bboxes.ndim < 2:
-        bboxes = [bboxes]
-
-    expected_bboxes = []
-    for bbox in bboxes:
-        bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
-        expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
-
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
+    expected_bboxes = torch.stack(
+        [
+            _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size)
+            for b in bboxes.reshape(-1, 4).unbind()
+        ]
+    ).reshape(bboxes.shape)

torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
torch.testing.assert_close(output_canvas_size, output_size)
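A recurring pattern in the rewritten expected-value helpers above: flatten whatever box batch you have to shape (N, 4), map a per-box reference function over the rows, then stack and restore the original shape. This also transparently handles a single 1D box. A small sketch of the pattern in isolation (`shift_right` is a toy transform invented for illustration):

```python
import torch

def shift_right(box: torch.Tensor, dx: int) -> torch.Tensor:
    # Toy per-box reference transform on one XYXY box of shape (4,).
    out = box.clone()
    out[0::2] += dx  # move both x coordinates
    return out

boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 7, 9]])  # shape (2, 4)
expected = torch.stack(
    [shift_right(b, 3) for b in boxes.reshape(-1, 4).unbind()]
).reshape(boxes.shape)
print(expected)  # tensor([[ 3,  0, 13, 10], [ 8,  5, 10,  9]])

# The same code accepts a single 1D box and round-trips its shape.
single = torch.tensor([0, 0, 10, 10])
shifted = torch.stack(
    [shift_right(b, 3) for b in single.reshape(-1, 4).unbind()]
).reshape(single.shape)
assert shifted.shape == (4,)
```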
13 changes: 3 additions & 10 deletions test/transforms_v2_kernel_infos.py
@@ -222,16 +222,9 @@ def transform(bbox, affine_matrix_, format_, canvas_size_):
out_bbox = out_bbox.to(dtype=in_dtype)
return out_bbox

-    if bounding_boxes.ndim < 2:
-        bounding_boxes = [bounding_boxes]
-
-    expected_bboxes = [transform(bbox, affine_matrix, format, canvas_size) for bbox in bounding_boxes]
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
-
-    return expected_bboxes
+    return torch.stack(
+        [transform(b, affine_matrix, format, canvas_size) for b in bounding_boxes.reshape(-1, 4).unbind()]
+    ).reshape(bounding_boxes.shape)


def sample_inputs_convert_format_bounding_boxes():
10 changes: 10 additions & 0 deletions torchvision/datapoints/_bounding_box.py
@@ -26,6 +26,12 @@ class BoundingBoxFormat(Enum):
class BoundingBoxes(Datapoint):
"""[BETA] :class:`torch.Tensor` subclass for bounding boxes.

+    .. note::
+        There should be only one :class:`~torchvision.datapoints.BoundingBoxes`
+        instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
+        although one :class:`~torchvision.datapoints.BoundingBoxes` object can
+        contain multiple bounding boxes.

Args:
data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
format (BoundingBoxFormat, str): Format of the bounding box.
Expand All @@ -43,6 +49,10 @@ class BoundingBoxes(Datapoint):

@classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override]
+        if tensor.ndim == 1:
+            tensor = tensor.unsqueeze(0)
+        elif tensor.ndim != 2:
+            raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
Comment on lines +52 to +55:

Member Author: I don't have a super strong opinion on that and I don't mind removing it, but IMO passing any other number of dims indicates a user mistake upstream.

We don't support batches of bounding boxes (i.e. we can't batch BoundingBoxes objects where num_bboxes differs per object), so I'm not sure we should pretend that we do.

Contributor: Due to this, the following check is obsolete:

    if boxes.ndim != 2:
        raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")

I've removed it and the (now failing) test for it in b962a4b. If we decide not to go with this, make sure to revert the commit as well.

if isinstance(format, str):
format = BoundingBoxFormat[format.upper()]
bounding_boxes = tensor.as_subclass(cls)
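For readers skimming the thread above, here is what the new `_wrap` logic amounts to in practice. A quick sketch, assuming a torchvision build with this PR applied (`datapoints` was a beta namespace at the time), mirroring the tests added in test_datapoints.py:

```python
from torchvision import datapoints

# A single 1D box is promoted to shape (1, 4).
single = datapoints.BoundingBoxes([1, 2, 3, 4], format="XYXY", canvas_size=(32, 32))
assert single.shape == (1, 4)

# A 2D batch of boxes passes through unchanged.
batch = datapoints.BoundingBoxes(
    [[0, 0, 5, 5], [2, 2, 7, 7]], format="XYXY", canvas_size=(32, 32)
)
assert batch.shape == (2, 4)

# Anything beyond 2D is treated as a user mistake upstream and raises.
try:
    datapoints.BoundingBoxes([[[1, 2, 3, 4]]], format="XYXY", canvas_size=(32, 32))
except ValueError as e:
    print(e)  # Expected a 1D or 2D tensor, got 3D
```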
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
@@ -7,7 +7,7 @@
from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size
+from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size


class FixedSizeCrop(Transform):
@@ -61,7 +61,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

bounding_boxes: Optional[torch.Tensor]
try:
-            bounding_boxes = query_bounding_boxes(flat_inputs)
+            bounding_boxes = get_bounding_boxes(flat_inputs)
except ValueError:
bounding_boxes = None

4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_geometry.py
@@ -23,7 +23,7 @@
_setup_float_or_seq,
_setup_size,
)
-from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size
+from .utils import get_bounding_boxes, has_all, has_any, is_simple_tensor, query_size


class RandomHorizontalFlip(_RandomApplyTransform):
@@ -1137,7 +1137,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_size(flat_inputs)
-        bboxes = query_bounding_boxes(flat_inputs)
+        bboxes = get_bounding_boxes(flat_inputs)

while True:
# sample an option
10 changes: 2 additions & 8 deletions torchvision/transforms/v2/_misc.py
@@ -10,7 +10,7 @@
from torchvision.transforms.v2 import functional as F, Transform

from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
-from .utils import has_any, is_simple_tensor, query_bounding_boxes
+from .utils import get_bounding_boxes, has_any, is_simple_tensor


# TODO: do we want/need to expose this?
@@ -384,13 +384,7 @@ def forward(self, *inputs: Any) -> Any:
)

flat_inputs, spec = tree_flatten(inputs)
-        # TODO: this enforces one single BoundingBoxes entry.
-        # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
-        # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
-        boxes = query_bounding_boxes(flat_inputs)
-
-        if boxes.ndim != 2:
-            raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
+        boxes = get_bounding_boxes(flat_inputs)

if labels is not None and boxes.shape[0] != labels.shape[0]:
raise ValueError(
13 changes: 6 additions & 7 deletions torchvision/transforms/v2/utils.py
@@ -9,13 +9,12 @@
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor


-def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
-    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)]
-    if not bounding_boxes:
-        raise TypeError("No bounding boxes were found in the sample")
-    elif len(bounding_boxes) > 1:
-        raise ValueError("Found multiple bounding boxes instances in the sample")
-    return bounding_boxes.pop()
+def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
+    # This assumes there is only one bbox per sample as per the general convention
+    try:
+        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
+    except StopIteration:
+        raise ValueError("No bounding boxes were found in the sample")


def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
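One behavioral shift in the rewrite above is easy to miss: `query_bounding_boxes` raised `TypeError` when a sample had no boxes and `ValueError` when it had several, while `get_bounding_boxes` raises `ValueError` for the empty case and quietly returns the first instance if a sample (against the one-boxes-per-sample convention) carries more than one. A short sketch, again assuming a build with this PR applied:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2.utils import get_bounding_boxes

boxes_a = datapoints.BoundingBoxes([[0, 0, 5, 5]], format="XYXY", canvas_size=(10, 10))
boxes_b = datapoints.BoundingBoxes([[1, 1, 4, 4]], format="XYXY", canvas_size=(10, 10))

# The first BoundingBoxes instance in the flattened sample wins.
assert get_bounding_boxes([torch.rand(3, 10, 10), boxes_a, boxes_b]) is boxes_a

# No boxes at all is now a ValueError (previously a TypeError).
try:
    get_bounding_boxes([torch.rand(3, 10, 10)])
except ValueError as e:
    print(e)  # No bounding boxes were found in the sample
```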