Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions torchvision/prototype/transforms/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class ToDtype(Transform):
_transformed_types = (torch.Tensor,)
_transformed_types = (features.is_simple_tensor, features._Feature)

def _default_dtype(self, dtype: torch.dtype) -> torch.dtype:
return dtype
Expand All @@ -157,7 +157,10 @@ def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
self.dtype = dtype

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.to(self.dtype[type(inpt)])
dtype = self.dtype.get(type(inpt))
if dtype is None:
return inpt
return inpt.to(dtype=dtype)


class RemoveSmallBoundingBoxes(Transform):
Expand Down