Skip to content
3 changes: 3 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,9 @@ def make_bounding_box(
dtype=None,
device="cpu",
):
if len(batch_dims) > 1: # This is nasty but this whole thing will be removed soon.
batch_dims = batch_dims[:-1]

def sample_position(values, max_value):
# We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
# However, if we have batch_dims, we need tensors as limits.
Expand Down
8 changes: 7 additions & 1 deletion test/test_datapoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_mask_instance(data):
assert mask.ndim == 3 and mask.shape[0] == 1


@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]])
@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize(
"format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
)
Expand All @@ -35,6 +35,12 @@ def test_bbox_instance(data, format):
assert bboxes.format == format


def test_bbox_dim_error():
    """Constructing ``BoundingBoxes`` from 3D data must raise a ``ValueError``."""
    # One extra level of nesting turns a single box into a 3D input.
    too_many_dims = [[[1, 2, 3, 4]]]
    expected_message = "Expected a 1D or 2D tensor, got 3D"
    with pytest.raises(ValueError, match=expected_message):
        datapoints.BoundingBoxes(too_many_dims, format="XYXY", canvas_size=(32, 32))


@pytest.mark.parametrize(
("data", "input_requires_grad", "expected_requires_grad"),
[
Expand Down
29 changes: 0 additions & 29 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,35 +286,6 @@ def test__transform_culling(self, mocker):
assert_equal(output["masks"], masks[is_valid])
assert_equal(output["labels"], labels[is_valid])

def test__transform_bounding_boxes_clamping(self, mocker):
    """Check that applying ``FixedSizeCrop`` to boxes calls ``F.clamp_bounding_boxes`` exactly once."""
    num_boxes = 3
    size = (10, 10)
    crop_height, crop_width = size

    # Replace parameter sampling with a fixed crop covering the whole canvas,
    # with every box flagged as valid.
    fixed_params = {
        "needs_crop": True,
        "top": 0,
        "left": 0,
        "height": crop_height,
        "width": crop_width,
        "is_valid": torch.full((num_boxes,), fill_value=True),
        "needs_pad": False,
    }
    mocker.patch(
        "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
        return_value=fixed_params,
    )

    boxes = make_bounding_box(format=BoundingBoxFormat.XYXY, canvas_size=size, batch_dims=(num_boxes,))
    clamp_mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes")

    crop = transforms.FixedSizeCrop((-1, -1))
    mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)

    crop(boxes)

    clamp_mock.assert_called_once()


class TestLabelToOneHot:
def test__transform(self):
Expand Down
10 changes: 10 additions & 0 deletions torchvision/datapoints/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ class BoundingBoxFormat(Enum):
class BoundingBoxes(Datapoint):
"""[BETA] :class:`torch.Tensor` subclass for bounding boxes.

.. note::
There should be only one :class:`~torchvision.datapoints.BoundingBoxes`
instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
although one :class:`~torchvision.datapoints.BoundingBoxes` object can
contain multiple bounding boxes.

Args:
data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
format (BoundingBoxFormat, str): Format of the bounding box.
Expand All @@ -43,6 +49,10 @@ class BoundingBoxes(Datapoint):

@classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes:
if tensor.ndim == 1:
tensor = tensor.unsqueeze(0)
elif tensor.ndim != 2:
raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
Comment on lines +52 to +55
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong opinion on this and I don't mind removing it, but IMO passing a tensor with any other number of dimensions indicates a user mistake upstream.

We don't support batches of bounding boxes (i.e. we can't batch BoundingBoxes objects where num_bboxes differs per object) so I'm not sure we should pretend that we do.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Due to this, the following check is obsolete:

if boxes.ndim != 2:
raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")

I've removed it and the (now failing) test for it in b962a4b. If we decide not to go with this, make sure to revert that commit as well.

bounding_boxes = tensor.as_subclass(cls)
bounding_boxes.format = format
bounding_boxes.canvas_size = canvas_size
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size
from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size


class FixedSizeCrop(Transform):
Expand Down Expand Up @@ -61,7 +61,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

bounding_boxes: Optional[torch.Tensor]
try:
bounding_boxes = query_bounding_boxes(flat_inputs)
bounding_boxes = get_bounding_boxes(flat_inputs)
except ValueError:
bounding_boxes = None

Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_setup_float_or_seq,
_setup_size,
)
from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size
from .utils import get_bounding_boxes, has_all, has_any, is_simple_tensor, query_size


class RandomHorizontalFlip(_RandomApplyTransform):
Expand Down Expand Up @@ -1137,7 +1137,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_size(flat_inputs)
bboxes = query_bounding_boxes(flat_inputs)
bboxes = get_bounding_boxes(flat_inputs)

while True:
# sample an option
Expand Down
7 changes: 2 additions & 5 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torchvision.transforms.v2 import functional as F, Transform

from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
from .utils import has_any, is_simple_tensor, query_bounding_boxes
from .utils import get_bounding_boxes, has_any, is_simple_tensor


# TODO: do we want/need to expose this?
Expand Down Expand Up @@ -384,10 +384,7 @@ def forward(self, *inputs: Any) -> Any:
)

flat_inputs, spec = tree_flatten(inputs)
# TODO: this enforces one single BoundingBoxes entry.
# Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
# should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
boxes = query_bounding_boxes(flat_inputs)
boxes = get_bounding_boxes(flat_inputs)

if boxes.ndim != 2:
raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
Expand Down
13 changes: 6 additions & 7 deletions torchvision/transforms/v2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor


def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
    """Return the single ``BoundingBoxes`` instance contained in ``flat_inputs``.

    Args:
        flat_inputs: Flattened sample to search.

    Returns:
        The one element of ``flat_inputs`` that is a ``datapoints.BoundingBoxes``.

    Raises:
        TypeError: If no ``BoundingBoxes`` instance is present.
        ValueError: If more than one ``BoundingBoxes`` instance is present.
    """
    found = [candidate for candidate in flat_inputs if isinstance(candidate, datapoints.BoundingBoxes)]
    if len(found) == 1:
        return found[0]
    if found:
        raise ValueError("Found multiple bounding boxes instances in the sample")
    raise TypeError("No bounding boxes were found in the sample")
def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
    """Return the first ``BoundingBoxes`` instance found in ``flat_inputs``.

    Args:
        flat_inputs: Flattened sample to search.

    Returns:
        The first element that is an instance of ``datapoints.BoundingBoxes``.
        Any later instances are ignored.

    Raises:
        ValueError: If ``flat_inputs`` contains no ``BoundingBoxes`` instance.
    """
    # This assumes there is only one bbox per sample as per the general convention
    try:
        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
    except StopIteration:
        # ``from None`` keeps the internal StopIteration out of the user-facing traceback.
        raise ValueError("No bounding boxes were found in the sample") from None


def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
Expand Down