From 62d0a98a0595ee2198d49d8165e7fdd748d4c628 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 1 Aug 2023 11:01:50 +0100 Subject: [PATCH 1/8] Simplify query_bounding_boxes logic --- test/test_datapoints.py | 8 +++++++- torchvision/datapoints/_bounding_box.py | 10 ++++++++++ torchvision/prototype/transforms/_geometry.py | 4 ++-- torchvision/transforms/v2/_geometry.py | 4 ++-- torchvision/transforms/v2/_misc.py | 7 ++----- torchvision/transforms/v2/utils.py | 13 ++++++------- 6 files changed, 29 insertions(+), 17 deletions(-) diff --git a/test/test_datapoints.py b/test/test_datapoints.py index f0a44ec1720..bdc65506554 100644 --- a/test/test_datapoints.py +++ b/test/test_datapoints.py @@ -22,7 +22,7 @@ def test_mask_instance(data): assert mask.ndim == 3 and mask.shape[0] == 1 -@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]]) +@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]]) @pytest.mark.parametrize( "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH] ) @@ -35,6 +35,12 @@ def test_bbox_instance(data, format): assert bboxes.format == format +def test_bbox_dim_error(): + data_3d = [[[1, 2, 3, 4]]] + with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"): + datapoints.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32)) + + @pytest.mark.parametrize( ("data", "input_requires_grad", "expected_requires_grad"), [ diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index 780a950403c..d08c9cfc443 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -27,6 +27,12 @@ class BoundingBoxFormat(Enum): class BoundingBoxes(Datapoint): """[BETA] :class:`torch.Tensor` subclass for bounding boxes. + .. note:: + There should be only one :class:`~torchvision.datapoints.BoundingBoxes` + instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``, + although one :class:`~torchvision.datapoints.BoundingBoxes` object can + contain multiple bounding boxes. + Args: data: Any data that can be turned into a tensor with :func:`torch.as_tensor`. format (BoundingBoxFormat, str): Format of the bounding box. 
@@ -44,6 +50,10 @@ class BoundingBoxes(Datapoint): @classmethod def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes: + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + elif tensor.ndim != 2: + raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D") bounding_boxes = tensor.as_subclass(cls) bounding_boxes.format = format bounding_boxes.canvas_size = canvas_size diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 28aff8416d2..7f4ecadf7fa 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -7,7 +7,7 @@ from torchvision.prototype.datapoints import Label, OneHotLabel from torchvision.transforms.v2 import functional as F, Transform from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size -from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size +from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size class FixedSizeCrop(Transform): @@ -61,7 +61,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: bounding_boxes: Optional[torch.Tensor] try: - bounding_boxes = query_bounding_boxes(flat_inputs) + bounding_boxes = get_bounding_boxes(flat_inputs) except ValueError: bounding_boxes = None diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 9e7ca64d41c..c95e817d1b9 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -22,7 +22,7 @@ _setup_float_or_seq, _setup_size, ) -from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size +from .utils import get_bounding_boxes, has_all, has_any, is_simple_tensor, query_size class RandomHorizontalFlip(_RandomApplyTransform): @@ -1165,7 +1165,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) - bboxes = query_bounding_boxes(flat_inputs) + bboxes = get_bounding_boxes(flat_inputs) while True: # sample an option diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index a799070ee1e..ce6df0ec855 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -10,7 +10,7 @@ from torchvision.transforms.v2 import functional as F, Transform from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size -from .utils import has_any, is_simple_tensor, query_bounding_boxes +from .utils import get_bounding_boxes, has_any, is_simple_tensor # TODO: do we want/need to expose this? @@ -384,10 +384,7 @@ def forward(self, *inputs: Any) -> Any: ) flat_inputs, spec = tree_flatten(inputs) - # TODO: this enforces one single BoundingBoxes entry. - # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes... - # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? 
- boxes = query_bounding_boxes(flat_inputs) + boxes = get_bounding_boxes(flat_inputs) if boxes.ndim != 2: raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}") diff --git a/torchvision/transforms/v2/utils.py b/torchvision/transforms/v2/utils.py index dd9f4489dee..1d9219fb4f5 100644 --- a/torchvision/transforms/v2/utils.py +++ b/torchvision/transforms/v2/utils.py @@ -9,13 +9,12 @@ from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor -def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes: - bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)] - if not bounding_boxes: - raise TypeError("No bounding boxes were found in the sample") - elif len(bounding_boxes) > 1: - raise ValueError("Found multiple bounding boxes instances in the sample") - return bounding_boxes.pop() +def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes: + # This assumes there is only one bbox per sample as per the general convention + try: + return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)) + except StopIteration: + raise ValueError("No bounding boxes were found in the sample") def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: From b7dc5beaa32dd1acb173c75286f5cf02ac956335 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 1 Aug 2023 12:08:28 +0100 Subject: [PATCH 2/8] Nasty fix --- test/common_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/common_utils.py b/test/common_utils.py index b5edda3edb2..1709a6a0cfc 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -638,6 +638,7 @@ def make_bounding_box( dtype=None, device="cpu", ): + batch_dims = () # This is nasty but this whole thing will be removed soon. def sample_position(values, max_value): # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high. # However, if we have batch_dims, we need tensors as limits. From 0676b086e8295e9fe15a4e71c2f82598973589f9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 1 Aug 2023 12:12:06 +0100 Subject: [PATCH 3/8] lint --- test/common_utils.py | 1 + torchvision/prototype/transforms/_geometry.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/common_utils.py b/test/common_utils.py index 1709a6a0cfc..495eb783abb 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -639,6 +639,7 @@ def make_bounding_box( device="cpu", ): batch_dims = () # This is nasty but this whole thing will be removed soon. + def sample_position(values, max_value): # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high. # However, if we have batch_dims, we need tensors as limits. 
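The behaviour introduced in PATCH 1/8 — promoting 1D box data to shape (1, 4) in BoundingBoxes._wrap, rejecting anything that is not 1D or 2D, and fetching the single BoundingBoxes entry of a sample via get_bounding_boxes — can be exercised with the minimal sketch below. This is an illustration, not part of the patch series; it assumes the beta torchvision.datapoints API with this series applied, and the sample dict and its keys are made up for the example.

import torch
from torchvision import datapoints
from torchvision.transforms.v2.utils import get_bounding_boxes

# A single box passed as a 1D sequence is promoted to shape (1, 4) by _wrap.
single = datapoints.BoundingBoxes([1, 2, 3, 4], format="XYXY", canvas_size=(32, 32))
assert single.shape == (1, 4)

# 3D data is rejected, matching the new test_bbox_dim_error test.
try:
    datapoints.BoundingBoxes([[[1, 2, 3, 4]]], format="XYXY", canvas_size=(32, 32))
except ValueError as exc:
    print(exc)  # Expected a 1D or 2D tensor, got 3D

# get_bounding_boxes returns the first (by convention, the only) BoundingBoxes
# entry found in the flattened sample, and raises ValueError if there is none.
sample = {"img": torch.rand(3, 32, 32), "bbox": single}  # illustrative keys
assert get_bounding_boxes(list(sample.values())) is single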
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 90a4c5f9289..c8cc99cb310 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -7,7 +7,7 @@ from torchvision.prototype.datapoints import Label, OneHotLabel from torchvision.transforms.v2 import functional as F, Transform from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size -from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size +from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size, get_bounding_boxes class FixedSizeCrop(Transform): From c5190d5d1a1225c3b764af4edfe0a7b54d09ed6c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Aug 2023 10:12:08 +0100 Subject: [PATCH 4/8] Tired of these tests. Removing --- test/common_utils.py | 3 +- test/test_prototype_transforms.py | 29 ------------------- torchvision/prototype/transforms/_geometry.py | 2 +- 3 files changed, 3 insertions(+), 31 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 7680d096072..f648867266e 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -649,7 +649,8 @@ def make_bounding_box( dtype=None, device="cpu", ): - batch_dims = () # This is nasty but this whole thing will be removed soon. + if len(batch_dims) > 1: # This is nasty but this whole thing will be removed soon. + batch_dims = batch_dims[:-1] def sample_position(values, max_value): # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high. diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index b1760f6f965..01f390ecb68 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -286,35 +286,6 @@ def test__transform_culling(self, mocker): assert_equal(output["masks"], masks[is_valid]) assert_equal(output["labels"], labels[is_valid]) - def test__transform_bounding_boxes_clamping(self, mocker): - batch_size = 3 - canvas_size = (10, 10) - - mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", - return_value=dict( - needs_crop=True, - top=0, - left=0, - height=canvas_size[0], - width=canvas_size[1], - is_valid=torch.full((batch_size,), fill_value=True), - needs_pad=False, - ), - ) - - bounding_boxes = make_bounding_box( - format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,) - ) - mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes") - - transform = transforms.FixedSizeCrop((-1, -1)) - mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) - - transform(bounding_boxes) - - mock.assert_called_once() - class TestLabelToOneHot: def test__transform(self): diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index c8cc99cb310..e3819554d0b 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -7,7 +7,7 @@ from torchvision.prototype.datapoints import Label, OneHotLabel from torchvision.transforms.v2 import functional as F, Transform from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size -from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size, get_bounding_boxes +from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size class 
FixedSizeCrop(Transform): From d360904070acd88e55ba2afe370987dc74b10802 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Aug 2023 22:38:39 +0200 Subject: [PATCH 5/8] fix bad stacking logic in bounding boxes references --- test/common_utils.py | 7 +-- test/test_transforms_v2_functional.py | 78 ++++++++++----------------- test/transforms_v2_kernel_infos.py | 13 ++--- 3 files changed, 33 insertions(+), 65 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index f648867266e..5a7daa58c32 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -649,9 +649,6 @@ def make_bounding_box( dtype=None, device="cpu", ): - if len(batch_dims) > 1: # This is nasty but this whole thing will be removed soon. - batch_dims = batch_dims[:-1] - def sample_position(values, max_value): # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high. # However, if we have batch_dims, we need tensors as limits. @@ -705,12 +702,12 @@ def fn(shape, dtype, device): format=format, canvas_size=canvas_size, batch_dims=batch_dims, dtype=dtype, device=device ) - return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=canvas_size) + return BoundingBoxesLoader(fn, shape=(*extra_dims[-1:], 4), dtype=dtype, format=format, spatial_size=canvas_size) def make_bounding_box_loaders( *, - extra_dims=DEFAULT_EXTRA_DIMS, + extra_dims=tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2), formats=tuple(datapoints.BoundingBoxFormat), canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtypes=(torch.float32, torch.float64, torch.int64), diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index 8d529732610..339a09853a9 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -763,21 +763,20 @@ def _parse_padding(padding): @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]]) def test_correctness_pad_bounding_boxes(device, padding): - def _compute_expected_bbox(bbox, padding_): + def _compute_expected_bbox(bbox, format, padding_): pad_left, pad_up, _, _ = _parse_padding(padding_) dtype = bbox.dtype - format = bbox.format bbox = ( bbox.clone() if format == datapoints.BoundingBoxFormat.XYXY - else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY) + else convert_format_bounding_boxes(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) ) bbox[0::2] += pad_left bbox[1::2] += pad_up - bbox = convert_format_bounding_boxes(bbox, new_format=format) + bbox = convert_format_bounding_boxes(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format) if bbox.dtype != dtype: # Temporary cast to original dtype # e.g. 
float32 -> int @@ -789,7 +788,7 @@ def _compute_expected_canvas_size(bbox, padding_): height, width = bbox.canvas_size return height + pad_up + pad_down, width + pad_left + pad_right - for bboxes in make_bounding_boxes(): + for bboxes in make_bounding_boxes(extra_dims=((4,),)): bboxes = bboxes.to(device) bboxes_format = bboxes.format bboxes_canvas_size = bboxes.canvas_size @@ -800,18 +799,10 @@ def _compute_expected_canvas_size(bbox, padding_): torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding)) - if bboxes.ndim < 2 or bboxes.shape[0] == 0: - bboxes = [bboxes] - - expected_bboxes = [] - for bbox in bboxes: - bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size) - expected_bboxes.append(_compute_expected_bbox(bbox, padding)) + expected_bboxes = torch.stack( + [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()] + ).reshape(bboxes.shape) - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0) @@ -836,7 +827,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device): ], ) def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints): - def _compute_expected_bbox(bbox, pcoeffs_): + def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_): m1 = np.array( [ [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]], @@ -850,7 +841,9 @@ def _compute_expected_bbox(bbox, pcoeffs_): ] ) - bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY) + bbox_xyxy = convert_format_bounding_boxes( + bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY + ) points = np.array( [ [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0], @@ -870,14 +863,11 @@ def _compute_expected_bbox(bbox, pcoeffs_): np.max(transformed_points[:, 1]), ] ) - out_bbox = datapoints.BoundingBoxes( - out_bbox, - format=datapoints.BoundingBoxFormat.XYXY, - canvas_size=bbox.canvas_size, - dtype=bbox.dtype, - device=bbox.device, + out_bbox = torch.from_numpy(out_bbox) + out_bbox = convert_format_bounding_boxes( + out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_ ) - return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format)) + return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox) canvas_size = (32, 38) @@ -896,17 +886,13 @@ def _compute_expected_bbox(bbox, pcoeffs_): coefficients=pcoeffs, ) - if bboxes.ndim < 2: - bboxes = [bboxes] + expected_bboxes = torch.stack( + [ + _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs) + for b in bboxes.reshape(-1, 4).unbind() + ] + ).reshape(bboxes.shape) - expected_bboxes = [] - for bbox in bboxes: - bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, canvas_size=bboxes.canvas_size) - expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs)) - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1) @@ -916,9 +902,7 @@ def _compute_expected_bbox(bbox, pcoeffs_): [(18, 18), [18, 15], (16, 19), [12], [46, 48]], ) def test_correctness_center_crop_bounding_boxes(device, output_size): - def _compute_expected_bbox(bbox, output_size_): - format_ = bbox.format - canvas_size_ = bbox.canvas_size + def 
_compute_expected_bbox(bbox, format_, canvas_size_, output_size_): dtype = bbox.dtype bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH) @@ -947,18 +931,12 @@ def _compute_expected_bbox(bbox, output_size_): bboxes, bboxes_format, bboxes_canvas_size, output_size ) - if bboxes.ndim < 2: - bboxes = [bboxes] - - expected_bboxes = [] - for bbox in bboxes: - bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size) - expected_bboxes.append(_compute_expected_bbox(bbox, output_size)) - - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] + expected_bboxes = torch.stack( + [ + _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size) + for b in bboxes.reshape(-1, 4).unbind() + ] + ).reshape(bboxes.shape) torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0) torch.testing.assert_close(output_canvas_size, output_size) diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index 01605f696b4..ac5651d3217 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -222,16 +222,9 @@ def transform(bbox, affine_matrix_, format_, canvas_size_): out_bbox = out_bbox.to(dtype=in_dtype) return out_bbox - if bounding_boxes.ndim < 2: - bounding_boxes = [bounding_boxes] - - expected_bboxes = [transform(bbox, affine_matrix, format, canvas_size) for bbox in bounding_boxes] - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] - - return expected_bboxes + return torch.stack( + [transform(b, affine_matrix, format, canvas_size) for b in bounding_boxes.reshape(-1, 4).unbind()] + ).reshape(bounding_boxes.shape) def sample_inputs_convert_format_bounding_boxes(): From 0619801b1774434ad724ae8fba42143be850d17e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Aug 2023 23:00:03 +0200 Subject: [PATCH 6/8] cleanup --- torchvision/datapoints/_bounding_box.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index c972dd8469c..d459a55448a 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -48,17 +48,13 @@ class BoundingBoxes(Datapoint): canvas_size: Tuple[int, int] @classmethod - def _wrap( - cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int] - ) -> BoundingBoxes: # type: ignore[override] + def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override] if tensor.ndim == 1: tensor = tensor.unsqueeze(0) elif tensor.ndim != 2: raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D") - if isinstance(format, str): format = BoundingBoxFormat[format.upper()] - bounding_boxes = tensor.as_subclass(cls) bounding_boxes.format = format bounding_boxes.canvas_size = canvas_size From 0189a6bb8ded08b86d74f6d09c974f2ec23f09a2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Aug 2023 23:05:42 +0200 Subject: [PATCH 7/8] put back and fix removed test --- test/test_prototype_transforms.py | 33 ++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 01f390ecb68..d395c224785 100644 --- 
a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -20,7 +20,7 @@ from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video from torchvision.prototype import datapoints, transforms from torchvision.transforms.v2._utils import _convert_fill_arg -from torchvision.transforms.v2.functional import InterpolationMode, pil_to_tensor, to_image_pil +from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil from torchvision.transforms.v2.utils import check_type, is_simple_tensor BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] @@ -286,6 +286,37 @@ def test__transform_culling(self, mocker): assert_equal(output["masks"], masks[is_valid]) assert_equal(output["labels"], labels[is_valid]) + def test__transform_bounding_boxes_clamping(self, mocker): + batch_size = 3 + canvas_size = (10, 10) + + mocker.patch( + "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", + return_value=dict( + needs_crop=True, + top=0, + left=0, + height=canvas_size[0], + width=canvas_size[1], + is_valid=torch.full((batch_size,), fill_value=True), + needs_pad=False, + ), + ) + + bounding_boxes = make_bounding_box( + format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,) + ) + mock = mocker.patch( + "torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes", wraps=clamp_bounding_boxes + ) + + transform = transforms.FixedSizeCrop((-1, -1)) + mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) + + transform(bounding_boxes) + + mock.assert_called_once() + class TestLabelToOneHot: def test__transform(self): From b962a4b2c4475f3d9a57e1e134ad77bada42af86 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Aug 2023 23:09:21 +0200 Subject: [PATCH 8/8] remove obsolete shape check in SanitizeBoundingBoxes --- test/test_transforms_v2.py | 12 ------------ torchvision/transforms/v2/_misc.py | 3 --- 2 files changed, 15 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 49455b05dc5..353cc846bed 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1654,18 +1654,6 @@ def test_sanitize_bounding_boxes_errors(): different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)} transforms.SanitizeBoundingBoxes()(different_sizes) - with pytest.raises(ValueError, match="boxes must be of shape"): - bad_bbox = datapoints.BoundingBoxes( # batch with 2 elements - [ - [[0, 0, 10, 10]], - [[0, 0, 10, 10]], - ], - format=datapoints.BoundingBoxFormat.XYXY, - canvas_size=(20, 20), - ) - different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])} - transforms.SanitizeBoundingBoxes()(different_sizes) - @pytest.mark.parametrize( "import_statement", diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index ce6df0ec855..d2dddd96d5c 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -386,9 +386,6 @@ def forward(self, *inputs: Any) -> Any: flat_inputs, spec = tree_flatten(inputs) boxes = get_bounding_boxes(flat_inputs) - if boxes.ndim != 2: - raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}") - if labels is not None and boxes.shape[0] != labels.shape[0]: raise ValueError( f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
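PATCH 5/8 replaces the ad-hoc ndim branches in the reference tests with a single pattern: flatten the boxes to (N, 4), apply the per-box reference transform, stack the results, and restore the original shape. The standalone sketch below isolates that pattern; the lambda that pads by a fixed offset is a stand-in for the pad/perspective/center-crop reference kernels, not code taken from the patch.

import torch

def reference_boxes(bounding_boxes, transform_single_box):
    # Mirror the reshape(-1, 4).unbind() / stack / reshape pattern from the tests:
    # apply the per-box reference transform to every (4,)-shaped box, then restore
    # the original shape of the input.
    flat = bounding_boxes.reshape(-1, 4)
    transformed = torch.stack([transform_single_box(box) for box in flat.unbind()])
    return transformed.reshape(bounding_boxes.shape)

boxes = torch.tensor([[0.0, 0.0, 5.0, 5.0], [2.0, 2.0, 7.0, 7.0]])
# Stand-in for a real kernel: shift by (left=1, top=2) in XYXY coordinates.
padded = reference_boxes(boxes, lambda box: box + torch.tensor([1.0, 2.0, 1.0, 2.0]))
print(padded)  # tensor([[1., 2., 6., 7.], [3., 4., 8., 9.]])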
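With the (num_boxes, 4) shape now guaranteed by the BoundingBoxes constructor, the explicit ndim check that PATCH 8/8 removes from SanitizeBoundingBoxes is redundant. The usage sketch below is only illustrative — the sample keys, box values, and expected output are assumptions for the example, not taken from the patches.

import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms

boxes = datapoints.BoundingBoxes(
    [[0, 0, 10, 10], [5, 5, 5, 5]],  # the second, degenerate box should be dropped
    format="XYXY",
    canvas_size=(20, 20),
)
sample = {"bbox": boxes, "labels": torch.tensor([1, 2])}
out = transforms.SanitizeBoundingBoxes()(sample)
print(out["bbox"].shape, out["labels"])  # expected: torch.Size([1, 4]) tensor([1])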