Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
from collections import OrderedDict

import numpy as np

Expand Down Expand Up @@ -1789,3 +1790,56 @@ def test__transform(self, mocker):
mock_resize.assert_called_with(
inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)


@pytest.mark.parametrize(
("dtype", "expected_dtypes"),
[
(
torch.float64,
{torch.Tensor: torch.float64, features.Image: torch.float64, features.BoundingBox: torch.float64},
),
(
{torch.Tensor: torch.float64},
{torch.Tensor: torch.float64, features.Image: torch.float64, features.BoundingBox: torch.float64},
),
# this makes sure that plain tensors are not touched if we don't specify them
(
{features.Image: torch.float32, features.BoundingBox: torch.float64},
{torch.Tensor: torch.int64, features.Image: torch.float32, features.BoundingBox: torch.float64},
),
# this makes sure the order of the dtype keys only makes a difference if no exact type match is found
(
OrderedDict([(torch.Tensor, torch.float64), (features.Image, torch.float32)]),
{torch.Tensor: torch.float64, features.Image: torch.float32, features.BoundingBox: torch.float64},
),
# same as above, but relying on the insertion ordering of plain dicts
(
{torch.Tensor: torch.float64, features.Image: torch.float32},
{torch.Tensor: torch.float64, features.Image: torch.float32, features.BoundingBox: torch.float64},
),
],
)
def test_to_dtype(dtype, expected_dtypes):
sample = dict(
plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"),
image=make_image(dtype=torch.uint8),
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY, dtype=torch.float32),
str="str",
int=0,
)

transform = transforms.ToDtype(dtype)
transformed_sample = transform(sample)

for key, value in sample.items():
value_type = type(value)
transformed_value = transformed_sample[key]

# make sure the transformation retains the type
assert isinstance(transformed_value, value_type)

if isinstance(value, torch.Tensor):
assert transformed_value.dtype is expected_dtypes[value_type]
else:
assert transformed_value is value
27 changes: 19 additions & 8 deletions torchvision/prototype/transforms/_misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Any, Callable, Dict, Sequence, Type, Union
from collections import OrderedDict
from typing import Any, Callable, Dict, OrderedDict as OrderedDictType, Sequence, Type, Union

import PIL.Image

Expand Down Expand Up @@ -140,14 +140,25 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.gaussian_blur(inpt, self.kernel_size, **params)


# TODO: Enhance as described at https://github.com/pytorch/vision/issues/6697
class ToDtype(Lambda):
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
class ToDtype(Transform):
def __init__(self, dtype: Union[torch.dtype, OrderedDictType[Type, torch.dtype]]) -> None:
super().__init__()
if isinstance(dtype, torch.dtype):
dtype = OrderedDict([(torch.Tensor, dtype)])
self.dtype = dtype
super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types or (torch.Tensor,))

def extra_repr(self) -> str:
return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
input_type = type(inpt)
if input_type in self.dtype:
dtype = self.dtype[input_type]
else:
for to_type, dtype in self.dtype.items():
if issubclass(input_type, to_type):
break
else:
return inpt

return inpt.to(dtype)


class RemoveSmallBoundingBoxes(Transform):
Expand Down