diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py
index 65793dc45df..7b49ce8e85e 100644
--- a/torchvision/prototype/features/_mask.py
+++ b/torchvision/prototype/features/_mask.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, List, Optional, Union
+from typing import Any, cast, List, Optional, Tuple, Union
 
 import torch
 from torchvision.transforms import InterpolationMode
@@ -32,6 +32,10 @@ def wrap_like(
     ) -> Mask:
         return cls._wrap(tensor)
 
+    @property
+    def image_size(self) -> Tuple[int, int]:
+        return cast(Tuple[int, int], tuple(self.shape[-2:]))
+
     def horizontal_flip(self) -> Mask:
         output = self._F.horizontal_flip_mask(self)
         return Mask.wrap_like(self, output)
diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index 7e28d9d6cc6..6ef9edba354 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -7,7 +7,7 @@
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype import features
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
-from torchvision.prototype.transforms.functional._meta import get_chw
+from torchvision.prototype.transforms.functional._meta import get_spatial_size
 
 from ._utils import _isinstance, _setup_fill_arg
 
@@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
 
         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)
 
         policy = self._policies[int(torch.randint(len(self._policies), ()))]
 
@@ -349,7 +349,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
 
         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)
 
         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -403,7 +403,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
 
         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)
 
         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
 
@@ -473,7 +473,7 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
         id, orig_image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(orig_image_or_video)
+        height, width = get_spatial_size(orig_image_or_video)
 
         if isinstance(orig_image_or_video, torch.Tensor):
             image_or_video = orig_image_or_video
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index a76891a348a..db1ff4b7b6f 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -10,7 +10,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.features._feature import FillType
 
-from torchvision.prototype.transforms.functional._meta import get_chw
+from torchvision.prototype.transforms.functional._meta import get_dimensions
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
 
 from typing_extensions import Literal
@@ -80,7 +80,7 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
 def query_chw(sample: Any) -> Tuple[int, int, int]:
     flat_sample, _ = tree_flatten(sample)
     chws = {
-        get_chw(item)
+        tuple(get_dimensions(item))
         for item in flat_sample
         if isinstance(item, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(item)
     }
@@ -88,7 +88,8 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
         raise TypeError("No image or video was found in the sample")
     elif len(chws) > 1:
         raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
-    return chws.pop()
+    c, h, w = chws.pop()
+    return c, h, w
 
 
 def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
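Not part of the diff — a minimal usage sketch of the new `Mask.image_size` property, assuming the prototype `features` API on this branch (the mask shape below is made up for illustration):

```python
import torch
from torchvision.prototype import features

# Three hypothetical binary object masks over a 32x48 image.
mask = features.Mask(torch.zeros(3, 32, 48, dtype=torch.bool))

# The new property exposes the spatial size as an (H, W) tuple taken from shape[-2:].
assert mask.image_size == (32, 48)
```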
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index cb801df73c7..1e918cc3492 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -8,9 +8,15 @@
     convert_color_space_image_pil,
     convert_color_space_video,
     convert_color_space,
+    get_dimensions_image_tensor,
+    get_dimensions_image_pil,
     get_dimensions,
     get_image_num_channels,
+    get_num_channels_image_tensor,
+    get_num_channels_image_pil,
     get_num_channels,
+    get_spatial_size_image_tensor,
+    get_spatial_size_image_pil,
     get_spatial_size,
 )  # usort: skip
 
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index c63fe5b41b1..670b2cb87b8 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -21,7 +21,12 @@
     interpolate,
 )
 
-from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
+from ._meta import (
+    convert_format_bounding_box,
+    get_dimensions_image_tensor,
+    get_spatial_size_image_pil,
+    get_spatial_size_image_tensor,
+)
 
 horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
@@ -323,7 +328,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        _, height, width = get_dimensions_image_pil(image)
+        height, width = get_spatial_size_image_pil(image)
         center = [width * 0.5, height * 0.5]
 
     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
@@ -1189,13 +1194,13 @@ def _center_crop_compute_crop_anchor(
 
 def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_dimensions_image_tensor(image)
+    image_height, image_width = get_spatial_size_image_tensor(image)
 
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         image = pad_image_tensor(image, padding_ltrb, fill=0)
 
-        _, image_height, image_width = get_dimensions_image_tensor(image)
+        image_height, image_width = get_spatial_size_image_tensor(image)
 
     if crop_width == image_width and crop_height == image_height:
         return image
@@ -1206,13 +1211,13 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor
 @torch.jit.unused
 def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_dimensions_image_pil(image)
+    image_height, image_width = get_spatial_size_image_pil(image)
 
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         image = pad_image_pil(image, padding_ltrb, fill=0)
 
-        _, image_height, image_width = get_dimensions_image_pil(image)
+        image_height, image_width = get_spatial_size_image_pil(image)
 
     if crop_width == image_width and crop_height == image_height:
         return image
@@ -1365,7 +1370,7 @@ def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    _, image_height, image_width = get_dimensions_image_tensor(image)
+    image_height, image_width = get_spatial_size_image_tensor(image)
 
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
@@ -1385,7 +1390,7 @@ def five_crop_image_pil(
     image: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    _, image_height, image_width = get_dimensions_image_pil(image)
+    image_height, image_width = get_spatial_size_image_pil(image)
 
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
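Not part of the diff — a quick sketch contrasting the per-backend kernels the geometry ops now call, assuming the exports added to `functional/__init__.py` above:

```python
import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 224, 320)  # plain CHW tensor

# get_dimensions_* keeps returning [C, H, W]; the new get_spatial_size_* kernels
# return only the spatial part as [H, W].
print(F.get_dimensions_image_tensor(img))    # [3, 224, 320]
print(F.get_spatial_size_image_tensor(img))  # [224, 320]
```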
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 1e53edf3940..e24b68c9fd6 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -6,38 +6,37 @@
 from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 
+
 get_dimensions_image_tensor = _FT.get_dimensions
 get_dimensions_image_pil = _FP.get_dimensions
 
 
-# TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init?
-def get_chw(image: features.ImageOrVideoTypeJIT) -> Tuple[int, int, int]:
+def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
     if isinstance(image, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
     ):
-        channels, height, width = get_dimensions_image_tensor(image)
+        return get_dimensions_image_tensor(image)
     elif isinstance(image, (features.Image, features.Video)):
         channels = image.num_channels
         height, width = image.image_size
-    else:  # isinstance(image, PIL.Image.Image)
-        channels, height, width = get_dimensions_image_pil(image)
-    return channels, height, width
-
-
-# The three functions below are here for BC. Whether we want to have two different kernels and how they and the
-# compound version should be named is still under discussion: https://github.com/pytorch/vision/issues/6491
-# Given that these kernels should also support boxes, masks, and videos, it is unlikely that there name will stay.
-# They will either be deprecated or simply aliased to the new kernels if we have reached consensus about the issue
-# detailed above.
+        return [channels, height, width]
+    else:
+        return get_dimensions_image_pil(image)
 
 
-def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
-    return list(get_chw(image))
+get_num_channels_image_tensor = _FT.get_image_num_channels
+get_num_channels_image_pil = _FP.get_image_num_channels
 
 
 def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
-    num_channels, *_ = get_chw(image)
-    return num_channels
+    if isinstance(image, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
+    ):
+        return _FT.get_image_num_channels(image)
+    elif isinstance(image, (features.Image, features.Video)):
+        return image.num_channels
+    else:
+        return _FP.get_image_num_channels(image)
 
 
 # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
@@ -45,9 +44,28 @@ def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
 get_image_num_channels = get_num_channels
 
 
-def get_spatial_size(image: features.ImageOrVideoTypeJIT) -> List[int]:
-    _, *size = get_chw(image)
-    return size
+def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
+    width, height = _FT.get_image_size(image)
+    return [height, width]
+
+
+@torch.jit.unused
+def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
+    width, height = _FP.get_image_size(image)
+    return [height, width]
+
+
+def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return get_spatial_size_image_tensor(inpt)
+    elif isinstance(inpt, features._Feature):
+        image_size = getattr(inpt, "image_size", None)
+        if image_size is not None:
+            return list(image_size)
+        else:
+            raise ValueError(f"Type {inpt.__class__} doesn't have spatial size.")
+    else:
+        return get_spatial_size_image_pil(inpt)
 
 
 def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
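Not part of the diff — a sketch of how the new `get_spatial_size` dispatcher is expected to behave across input types, assuming this branch of the prototype API:

```python
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# Plain tensors go through get_spatial_size_image_tensor ...
print(F.get_spatial_size(torch.rand(3, 32, 48)))  # [32, 48]

# ... features report their image_size attribute ...
print(F.get_spatial_size(features.Image(torch.rand(3, 32, 48))))  # [32, 48]

# ... and PIL images go through get_spatial_size_image_pil (PIL sizes are (W, H)).
print(F.get_spatial_size(PIL.Image.new("RGB", (48, 32))))  # [32, 48]
```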