diff --git a/test/test_datapoints.py b/test/test_datapoints.py
index f0a44ec1720..25a2182e050 100644
--- a/test/test_datapoints.py
+++ b/test/test_datapoints.py
@@ -113,6 +113,28 @@ def test_detach_wrapping():
     assert type(image_detached) is datapoints.Image
 
 
+def test_no_wrapping_exceptions_with_metadata():
+    # Sanity checks for the ops in _NO_WRAPPING_EXCEPTIONS and datapoints with metadata
+    format, canvas_size = "XYXY", (32, 32)
+    bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
+    # BoundingBoxes._wrap normalizes the string format to an enum, so compare against the enum member
+    expected = (datapoints.BoundingBoxFormat[format], canvas_size)
+
+    bbox = bbox.clone()
+    assert (bbox.format, bbox.canvas_size) == expected
+
+    bbox = bbox.to(torch.float64)
+    assert (bbox.format, bbox.canvas_size) == expected
+
+    bbox = bbox.detach()
+    assert (bbox.format, bbox.canvas_size) == expected
+
+    assert not bbox.requires_grad
+    bbox.requires_grad_(True)
+    assert (bbox.format, bbox.canvas_size) == expected
+    assert bbox.requires_grad
+
+
 def test_other_op_no_wrapping():
     image = datapoints.Image(torch.rand(3, 16, 16))
 
diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py
index 7477b3652dc..9677cef21e6 100644
--- a/torchvision/datapoints/_bounding_box.py
+++ b/torchvision/datapoints/_bounding_box.py
@@ -42,7 +42,9 @@ class BoundingBoxes(Datapoint):
     canvas_size: Tuple[int, int]
 
     @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes:  # type: ignore[override]
+    def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes:  # type: ignore[override]
+        if isinstance(format, str):
+            format = BoundingBoxFormat[format.upper()]
         bounding_boxes = tensor.as_subclass(cls)
         bounding_boxes.format = format
         bounding_boxes.canvas_size = canvas_size
@@ -59,10 +61,6 @@ def __new__(
         requires_grad: Optional[bool] = None,
     ) -> BoundingBoxes:
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-
-        if isinstance(format, str):
-            format = BoundingBoxFormat[format.upper()]
-
         return cls._wrap(tensor, format=format, canvas_size=canvas_size)
 
     @classmethod
@@ -71,7 +69,7 @@ def wrap_like(
         other: BoundingBoxes,
         tensor: torch.Tensor,
         *,
-        format: Optional[BoundingBoxFormat] = None,
+        format: Optional[Union[BoundingBoxFormat, str]] = None,
         canvas_size: Optional[Tuple[int, int]] = None,
     ) -> BoundingBoxes:
         """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.
@@ -85,9 +83,6 @@ def wrap_like(
                 omitted, it is taken from the reference.
 
""" - if isinstance(format, str): - format = BoundingBoxFormat[format.upper()] - return cls._wrap( tensor, format=format if format is not None else other.format, diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index fae3c18656b..9b1c648648d 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -32,13 +32,9 @@ def _to_tensor( requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) - @classmethod - def _wrap(cls: Type[D], tensor: torch.Tensor) -> D: - return tensor.as_subclass(cls) - @classmethod def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: - return cls._wrap(tensor) + return tensor.as_subclass(cls) _NO_WRAPPING_EXCEPTIONS = { torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py index 9b635e8e034..cf7b8b1fccd 100644 --- a/torchvision/datapoints/_image.py +++ b/torchvision/datapoints/_image.py @@ -41,7 +41,7 @@ def __new__( elif tensor.ndim == 2: tensor = tensor.unsqueeze(0) - return cls._wrap(tensor) + return tensor.as_subclass(cls) def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() diff --git a/torchvision/datapoints/_mask.py b/torchvision/datapoints/_mask.py index 95eda077929..e2bafcd6883 100644 --- a/torchvision/datapoints/_mask.py +++ b/torchvision/datapoints/_mask.py @@ -36,4 +36,4 @@ def __new__( data = F.pil_to_tensor(data) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) - return cls._wrap(tensor) + return tensor.as_subclass(cls) diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py index 842c05bf7e9..19ab0aa8de7 100644 --- a/torchvision/datapoints/_video.py +++ b/torchvision/datapoints/_video.py @@ -31,7 +31,7 @@ def __new__( tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) if data.ndim < 4: raise ValueError - return cls._wrap(tensor) + return tensor.as_subclass(cls) def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() diff --git a/torchvision/prototype/datapoints/_label.py b/torchvision/prototype/datapoints/_label.py index ac9b2d8912a..7ed2f7522b0 100644 --- a/torchvision/prototype/datapoints/_label.py +++ b/torchvision/prototype/datapoints/_label.py @@ -15,7 +15,7 @@ class _LabelBase(Datapoint): categories: Optional[Sequence[str]] @classmethod - def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: # type: ignore[override] + def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: label_base = tensor.as_subclass(cls) label_base.categories = categories return label_base