1 change: 1 addition & 0 deletions test/common_utils.py
@@ -638,6 +638,7 @@ def make_bounding_box(
dtype=None,
device="cpu",
):
+    batch_dims = ()  # This is nasty but this whole thing will be removed soon.
Member Author:

@pmeier I had to add this to fix the tests. I'm sure there's a better way to fix it, but I don't want to spend more time on this (I've already spent 10+ minutes), as these tests and helpers will be removed.

Feel free to push a better fix if you prefer.

Contributor:

"as these tests and helpers will be removed."

You put the hack into make_bounding_box, which will not be removed. I've moved it into the helpers that will actually be removed. Nothing new is failing after that, meaning we are not misusing make_bounding_box anywhere on the new side of the tests.

    def sample_position(values, max_value):
        # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
        # However, if we have batch_dims, we need tensors as limits.
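The collapsed body of sample_position is not shown here. As a rough sketch of the workaround the comment describes (assuming integer values and a scalar max_value; an illustration, not necessarily the actual helper), one can scale uniform floats by the tensor-valued limit instead of calling torch.randint:

import torch

def sample_position(values, max_value):
    # torch.randint only accepts Python integer scalars for low/high, so we
    # draw uniform floats in [0, 1) and scale them per element instead.
    return (torch.rand(values.shape) * (max_value - values + 1)).to(torch.int64)

For example, sample_position(torch.tensor([10, 20]), 32) returns one coordinate per box extent such that each box still fits inside the canvas.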
8 changes: 7 additions & 1 deletion test/test_datapoints.py
@@ -22,7 +22,7 @@ def test_mask_instance(data):
assert mask.ndim == 3 and mask.shape[0] == 1


-@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]])
+@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize(
"format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
)
@@ -35,6 +35,12 @@ def test_bbox_instance(data, format):
assert bboxes.format == format


+def test_bbox_dim_error():
+    data_3d = [[[1, 2, 3, 4]]]
+    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
+        datapoints.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))


@pytest.mark.parametrize(
("data", "input_requires_grad", "expected_requires_grad"),
[
10 changes: 10 additions & 0 deletions torchvision/datapoints/_bounding_box.py
@@ -27,6 +27,12 @@ class BoundingBoxFormat(Enum):
class BoundingBoxes(Datapoint):
"""[BETA] :class:`torch.Tensor` subclass for bounding boxes.

+    .. note::
+        There should be only one :class:`~torchvision.datapoints.BoundingBoxes`
+        instance per sample, e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
+        although one :class:`~torchvision.datapoints.BoundingBoxes` object can
+        contain multiple bounding boxes.

Args:
data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
format (BoundingBoxFormat, str): Format of the bounding box.
@@ -44,6 +50,10 @@ class BoundingBoxes(Datapoint):

@classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes:
+        if tensor.ndim == 1:
+            tensor = tensor.unsqueeze(0)
+        elif tensor.ndim != 2:
+            raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
Comment on lines +52 to +55
Member Author:

I don't have a super strong opinion on this and I don't mind removing it, but IMO passing any other number of dimensions indicates a user mistake upstream.

We don't support batches of bounding boxes (i.e. we can't batch BoundingBoxes objects where num_bboxes differs per object), so I'm not sure we should pretend that we do.

Contributor:

Due to this, the following check is obsolete:

if boxes.ndim != 2:
    raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")

I've removed it and the (now failing) test for it in b962a4b. If we decide to not go with this, make sure to revert the commit as well.
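
To make the new behavior concrete, a small sketch based on the diff above (illustrative, not part of the PR):

from torchvision import datapoints

# A single box given as a 1D sequence is promoted to shape (1, 4).
box = datapoints.BoundingBoxes([1, 2, 3, 4], format="XYXY", canvas_size=(32, 32))
assert box.shape == (1, 4)

# Anything that is neither 1D nor 2D is rejected at construction time, which
# is what makes the shape check quoted above obsolete.
try:
    datapoints.BoundingBoxes([[[1, 2, 3, 4]]], format="XYXY", canvas_size=(32, 32))
except ValueError as err:
    print(err)  # Expected a 1D or 2D tensor, got 3D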

bounding_boxes = tensor.as_subclass(cls)
bounding_boxes.format = format
bounding_boxes.canvas_size = canvas_size
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
@@ -7,7 +7,7 @@
from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size
+from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size


class FixedSizeCrop(Transform):
@@ -61,7 +61,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

bounding_boxes: Optional[torch.Tensor]
try:
-            bounding_boxes = query_bounding_boxes(flat_inputs)
+            bounding_boxes = get_bounding_boxes(flat_inputs)
except ValueError:
bounding_boxes = None

4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_geometry.py
@@ -22,7 +22,7 @@
_setup_float_or_seq,
_setup_size,
)
-from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size
+from .utils import get_bounding_boxes, has_all, has_any, is_simple_tensor, query_size


class RandomHorizontalFlip(_RandomApplyTransform):
@@ -1165,7 +1165,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_size(flat_inputs)
-        bboxes = query_bounding_boxes(flat_inputs)
+        bboxes = get_bounding_boxes(flat_inputs)

while True:
# sample an option
7 changes: 2 additions & 5 deletions torchvision/transforms/v2/_misc.py
@@ -10,7 +10,7 @@
from torchvision.transforms.v2 import functional as F, Transform

from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
-from .utils import has_any, is_simple_tensor, query_bounding_boxes
+from .utils import get_bounding_boxes, has_any, is_simple_tensor


# TODO: do we want/need to expose this?
@@ -384,10 +384,7 @@ def forward(self, *inputs: Any) -> Any:
)

flat_inputs, spec = tree_flatten(inputs)
-        # TODO: this enforces one single BoundingBoxes entry.
-        # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
-        # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
-        boxes = query_bounding_boxes(flat_inputs)
+        boxes = get_bounding_boxes(flat_inputs)

if boxes.ndim != 2:
raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
13 changes: 6 additions & 7 deletions torchvision/transforms/v2/utils.py
@@ -9,13 +9,12 @@
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor


-def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
-    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)]
-    if not bounding_boxes:
-        raise TypeError("No bounding boxes were found in the sample")
-    elif len(bounding_boxes) > 1:
-        raise ValueError("Found multiple bounding boxes instances in the sample")
-    return bounding_boxes.pop()
+def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
+    # This assumes there is only one bbox per sample as per the general convention
+    try:
+        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
+    except StopIteration:
+        raise ValueError("No bounding boxes were found in the sample")
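
A usage sketch for the renamed helper (the sample layout is illustrative; the import paths follow what the transforms above use):

import torch
from torch.utils._pytree import tree_flatten
from torchvision import datapoints
from torchvision.transforms.v2.utils import get_bounding_boxes

sample = {
    "img": torch.rand(3, 32, 32),
    "bbox": datapoints.BoundingBoxes([[0, 0, 5, 5]], format="XYXY", canvas_size=(32, 32)),
}
flat_inputs, _ = tree_flatten(sample)

# Returns the first (by convention, the only) BoundingBoxes entry in the
# flattened sample; raises ValueError if there is none.
boxes = get_bounding_boxes(flat_inputs)
print(boxes.format, boxes.canvas_size)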


def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: