Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions test/test_datapoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ def test_detach_wrapping():
assert type(image_detached) is datapoints.Image


def test_no_wrapping_exceptions_with_metadata():
    # Sanity checks for the ops in _NO_WRAPPING_EXCEPTIONS and datapoints with metadata.
    #
    # BUG FIX: the original checks were written as
    #     assert bbox.format, bbox.canvas_size == (format, canvas_size)
    # which parses as ``assert bbox.format`` with the comparison as the (never
    # evaluated-for-truth) assertion *message* -- the metadata was never checked.
    format, canvas_size = "XYXY", (32, 32)
    bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)

    # The constructor normalizes a string format to a BoundingBoxFormat enum
    # member, so compare against the enum member rather than the raw string.
    expected = (datapoints.BoundingBoxFormat[format], canvas_size)

    bbox = bbox.clone()
    assert (bbox.format, bbox.canvas_size) == expected

    bbox = bbox.to(torch.float64)
    assert (bbox.format, bbox.canvas_size) == expected

    bbox = bbox.detach()
    assert (bbox.format, bbox.canvas_size) == expected

    assert not bbox.requires_grad
    bbox.requires_grad_(True)
    assert (bbox.format, bbox.canvas_size) == expected
    assert bbox.requires_grad


def test_other_op_no_wrapping():
image = datapoints.Image(torch.rand(3, 16, 16))

Expand Down
13 changes: 4 additions & 9 deletions torchvision/datapoints/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ class BoundingBoxes(Datapoint):
canvas_size: Tuple[int, int]

@classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override]
def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since Datapoint._wrap is removed now, do we still need the mypy directive?

Suggested change
def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override]
def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes:

if isinstance(format, str):
format = BoundingBoxFormat[format.upper()]
bounding_boxes = tensor.as_subclass(cls)
bounding_boxes.format = format
bounding_boxes.canvas_size = canvas_size
Expand All @@ -59,10 +61,6 @@ def __new__(
requires_grad: Optional[bool] = None,
) -> BoundingBoxes:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)

if isinstance(format, str):
format = BoundingBoxFormat[format.upper()]

return cls._wrap(tensor, format=format, canvas_size=canvas_size)

@classmethod
Expand All @@ -71,7 +69,7 @@ def wrap_like(
other: BoundingBoxes,
tensor: torch.Tensor,
*,
format: Optional[BoundingBoxFormat] = None,
format: Optional[Union[BoundingBoxFormat, str]] = None,
canvas_size: Optional[Tuple[int, int]] = None,
) -> BoundingBoxes:
"""Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.
Expand All @@ -85,9 +83,6 @@ def wrap_like(
omitted, it is taken from the reference.

"""
if isinstance(format, str):
format = BoundingBoxFormat[format.upper()]

return cls._wrap(
tensor,
format=format if format is not None else other.format,
Expand Down
6 changes: 1 addition & 5 deletions torchvision/datapoints/_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@ def _to_tensor(
requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

@classmethod
def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls)

@classmethod
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
    """Wrap ``tensor`` as an instance of ``cls``.

    ``other`` is the reference datapoint; this base implementation does not
    use it, but subclasses carrying metadata (e.g. ``BoundingBoxes``)
    override this classmethod and copy the metadata from the reference.
    """
    wrapped = tensor.as_subclass(cls)
    return wrapped

_NO_WRAPPING_EXCEPTIONS = {
torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
Expand Down
2 changes: 1 addition & 1 deletion torchvision/datapoints/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __new__(
elif tensor.ndim == 2:
tensor = tensor.unsqueeze(0)

return cls._wrap(tensor)
return tensor.as_subclass(cls)

def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
    # Delegate to self._make_repr(); ``tensor_contents`` is accepted only to
    # match the torch.Tensor.__repr__ signature and is ignored here.
    return self._make_repr()
Expand Down
2 changes: 1 addition & 1 deletion torchvision/datapoints/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ def __new__(
data = F.pil_to_tensor(data)

tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor)
return tensor.as_subclass(cls)
2 changes: 1 addition & 1 deletion torchvision/datapoints/_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __new__(
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if data.ndim < 4:
raise ValueError
return cls._wrap(tensor)
return tensor.as_subclass(cls)

def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
    # Delegate to self._make_repr(); ``tensor_contents`` is accepted only to
    # match the torch.Tensor.__repr__ signature and is ignored here.
    return self._make_repr()
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/datapoints/_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class _LabelBase(Datapoint):
categories: Optional[Sequence[str]]

@classmethod
def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
    """Wrap ``tensor`` as an instance of ``cls`` and attach ``categories``."""
    wrapped = tensor.as_subclass(cls)
    wrapped.categories = categories
    return wrapped
Expand Down