Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion torchvision/prototype/transforms/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,10 @@ def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
self.dtype = dtype

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.to(self.dtype[type(inpt)])
dtype = self.dtype.get(type(inpt))
if dtype is None:
return inpt
return inpt.to(dtype=dtype)


class RemoveSmallBoundingBoxes(Transform):
Expand Down
6 changes: 1 addition & 5 deletions torchvision/prototype/transforms/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
class Transform(nn.Module):

# Class attribute defining transformed types. Other types are passed-through without any transformation
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (
features.is_simple_tensor,
features._Feature,
PIL.Image.Image,
)
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)

def __init__(self) -> None:
super().__init__()
Expand Down