From 3586b535e69559a4a75113daf650bc1fb2bc99a0 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 7 Oct 2022 11:45:59 +0100 Subject: [PATCH 1/9] Rewrite `get_dimensions`, `get_num_channels` and `get_spatial_size` --- .../transforms/functional/__init__.py | 6 +++ .../prototype/transforms/functional/_meta.py | 41 ++++++++++++++++--- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index f081d101dff..1e5753f2cfb 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -7,9 +7,15 @@ convert_color_space_image_tensor, convert_color_space_image_pil, convert_color_space, + get_dimensions_image_tensor, + get_dimensions_image_pil, get_dimensions, get_image_num_channels, + get_num_channels_image_tensor, + get_num_channels_image_pil, get_num_channels, + get_spatial_size_image_tensor, + get_spatial_size_image_pil, get_spatial_size, ) # usort: skip diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 90cfffcf276..d3c62325102 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -30,12 +30,28 @@ def get_chw(image: features.ImageTypeJIT) -> Tuple[int, int, int]: def get_dimensions(image: features.ImageTypeJIT) -> List[int]: - return list(get_chw(image)) + if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)): + return get_dimensions_image_tensor(image) + elif isinstance(image, features.Image): + channels = image.num_channels + height, width = image.image_size + return [channels, height, width] + else: + return get_dimensions_image_pil(image) + + +get_num_channels_image_tensor = _FT.get_image_num_channels +get_num_channels_image_pil = _FP.get_image_num_channels def get_num_channels(image: features.ImageTypeJIT) -> int: - num_channels, *_ = get_chw(image) - return num_channels + if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)): + channels = _FT.get_image_num_channels(image) + elif isinstance(image, features.Image): + channels = image.num_channels + else: + channels = _FP.get_image_num_channels(image) + return channels # We changed the names to ensure it can be used not only for images but also videos. 
Thus, we just alias it without @@ -43,9 +59,24 @@ def get_num_channels(image: features.ImageTypeJIT) -> int: get_image_num_channels = get_num_channels +def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]: + width, height = _FT.get_image_size(image) + return [height, width] + + +@torch.jit.unused +def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]: + width, height = _FP.get_image_size(image) + return [height, width] + + def get_spatial_size(image: features.ImageTypeJIT) -> List[int]: - _, *size = get_chw(image) - return size + if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)): + return get_spatial_size_image_tensor(image) + elif isinstance(image, features.Image): + return list(image.image_size) + else: + return get_spatial_size_image_pil(image) def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor: From d527fbbc9f94d9072c9f51efed615dae96ee9056 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 7 Oct 2022 11:48:22 +0100 Subject: [PATCH 2/9] Remove `get_chw` --- torchvision/prototype/transforms/_auto_augment.py | 10 +++++----- torchvision/prototype/transforms/_utils.py | 4 ++-- torchvision/prototype/transforms/functional/_meta.py | 12 ------------ 3 files changed, 7 insertions(+), 19 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index c98e5c36e4a..9f348cd4972 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -7,7 +7,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision.prototype import features from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform -from torchvision.prototype.transforms.functional._meta import get_chw +from torchvision.prototype.transforms.functional._meta import get_spatial_size from ._utils import _isinstance, _setup_fill_arg @@ -277,7 +277,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] id, image = self._extract_image(sample) - _, height, width = get_chw(image) + height, width = get_spatial_size(image) policy = self._policies[int(torch.randint(len(self._policies), ()))] @@ -348,7 +348,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] id, image = self._extract_image(sample) - _, height, width = get_chw(image) + height, width = get_spatial_size(image) for _ in range(self.num_ops): transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -402,7 +402,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] id, image = self._extract_image(sample) - _, height, width = get_chw(image) + height, width = get_spatial_size(image) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -472,7 +472,7 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] id, orig_image = self._extract_image(sample) - _, height, width = get_chw(orig_image) + height, width = get_spatial_size(orig_image) if isinstance(orig_image, torch.Tensor): image = orig_image diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 219e6e50586..7db4f710481 100644 --- a/torchvision/prototype/transforms/_utils.py +++ 
b/torchvision/prototype/transforms/_utils.py @@ -10,7 +10,7 @@ from torchvision.prototype import features from torchvision.prototype.features._feature import FillType -from torchvision.prototype.transforms.functional._meta import get_chw +from torchvision.prototype.transforms.functional._meta import get_dimensions from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 from typing_extensions import Literal @@ -80,7 +80,7 @@ def query_bounding_box(sample: Any) -> features.BoundingBox: def query_chw(sample: Any) -> Tuple[int, int, int]: flat_sample, _ = tree_flatten(sample) chws = { - get_chw(item) + get_dimensions(item) for item in flat_sample if isinstance(item, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(item) } diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index d3c62325102..725ed802382 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -10,18 +10,6 @@ get_dimensions_image_pil = _FP.get_dimensions -# TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init? -def get_chw(image: features.ImageTypeJIT) -> Tuple[int, int, int]: - if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)): - channels, height, width = get_dimensions_image_tensor(image) - elif isinstance(image, features.Image): - channels = image.num_channels - height, width = image.image_size - else: # isinstance(image, PIL.Image.Image) - channels, height, width = get_dimensions_image_pil(image) - return channels, height, width - - # The three functions below are here for BC. Whether we want to have two different kernels and how they and the # compound version should be named is still under discussion: https://github.com/pytorch/vision/issues/6491 # Given that these kernels should also support boxes, masks, and videos, it is unlikely that there name will stay. From c9ba0fd60b7ab464955ec35396347ca26a59dfe5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 7 Oct 2022 11:49:24 +0100 Subject: [PATCH 3/9] Remove comments --- torchvision/prototype/transforms/functional/_meta.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 725ed802382..fbe05473857 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -6,17 +6,11 @@ from torchvision.prototype.features import BoundingBoxFormat, ColorSpace from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT + get_dimensions_image_tensor = _FT.get_dimensions get_dimensions_image_pil = _FP.get_dimensions -# The three functions below are here for BC. Whether we want to have two different kernels and how they and the -# compound version should be named is still under discussion: https://github.com/pytorch/vision/issues/6491 -# Given that these kernels should also support boxes, masks, and videos, it is unlikely that there name will stay. -# They will either be deprecated or simply aliased to the new kernels if we have reached consensus about the issue -# detailed above. 
- - def get_dimensions(image: features.ImageTypeJIT) -> List[int]: if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)): return get_dimensions_image_tensor(image) From a059af2f351271e49884656f5bfb4f223e84946c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 7 Oct 2022 12:18:40 +0100 Subject: [PATCH 4/9] Make `get_spatial_size` support non-image input --- torchvision/prototype/features/_mask.py | 7 ++++++- .../prototype/transforms/functional/_meta.py | 18 +++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index 9dd614752a6..e5acd4b05fb 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional, Union +from typing import cast, List, Optional, Tuple, Union import torch from torchvision.transforms import InterpolationMode @@ -9,6 +9,11 @@ class Mask(_Feature): + + @property + def image_size(self) -> Tuple[int, int]: + return cast(Tuple[int, int], tuple(self.shape[-2:])) + def horizontal_flip(self) -> Mask: output = self._F.horizontal_flip_mask(self) return Mask.new_like(self, output) diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index fbe05473857..f3fbdd73b2b 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -1,4 +1,4 @@ -from typing import cast, List, Optional, Tuple +from typing import cast, List, Optional, Tuple, Union import PIL.Image import torch @@ -52,13 +52,17 @@ def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]: return [height, width] -def get_spatial_size(image: features.ImageTypeJIT) -> List[int]: - if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)): - return get_spatial_size_image_tensor(image) - elif isinstance(image, features.Image): - return list(image.image_size) +def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]: + if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): + return get_spatial_size_image_tensor(inpt) + elif isinstance(inpt, features._Feature): + image_size = getattr(inpt, "image_size", None) + if image_size is not None: + return list(image_size) + else: + raise ValueError(f"Type {inpt.__class__} doesn't have spatial size.") else: - return get_spatial_size_image_pil(image) + return get_spatial_size_image_pil(inpt) def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor: From a10ce7afb4c3618fc1b42e51c01760d5bff66248 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 7 Oct 2022 14:22:53 +0100 Subject: [PATCH 5/9] Reduce the unnecessary use of `get_dimensions*` --- torchvision/prototype/features/_mask.py | 1 - .../transforms/functional/_geometry.py | 21 ++++++++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index e5acd4b05fb..3d50c2d8101 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -9,7 +9,6 @@ class Mask(_Feature): - @property def image_size(self) -> Tuple[int, int]: return cast(Tuple[int, int], tuple(self.shape[-2:])) diff --git a/torchvision/prototype/transforms/functional/_geometry.py 
b/torchvision/prototype/transforms/functional/_geometry.py index 7a291967bfd..21cd2008c3a 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -21,7 +21,12 @@ interpolate, ) -from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor +from ._meta import ( + convert_format_bounding_box, + get_dimensions_image_tensor, + get_spatial_size_image_pil, + get_spatial_size_image_tensor, +) horizontal_flip_image_tensor = _FT.hflip horizontal_flip_image_pil = _FP.hflip @@ -305,7 +310,7 @@ def affine_image_pil( # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine if center is None: - _, height, width = get_dimensions_image_pil(image) + height, width = get_spatial_size_image_pil(image) center = [width * 0.5, height * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) @@ -1071,13 +1076,13 @@ def _center_crop_compute_crop_anchor( def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor: crop_height, crop_width = _center_crop_parse_output_size(output_size) - _, image_height, image_width = get_dimensions_image_tensor(image) + image_height, image_width = get_spatial_size_image_tensor(image) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) image = pad_image_tensor(image, padding_ltrb, fill=0) - _, image_height, image_width = get_dimensions_image_tensor(image) + image_height, image_width = get_spatial_size_image_tensor(image) if crop_width == image_width and crop_height == image_height: return image @@ -1088,13 +1093,13 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor @torch.jit.unused def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: crop_height, crop_width = _center_crop_parse_output_size(output_size) - _, image_height, image_width = get_dimensions_image_pil(image) + image_height, image_width = get_spatial_size_image_pil(image) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) image = pad_image_pil(image, padding_ltrb, fill=0) - _, image_height, image_width = get_dimensions_image_pil(image) + image_height, image_width = get_spatial_size_image_pil(image) if crop_width == image_width and crop_height == image_height: return image @@ -1228,7 +1233,7 @@ def five_crop_image_tensor( image: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: crop_height, crop_width = _parse_five_crop_size(size) - _, image_height, image_width = get_dimensions_image_tensor(image) + image_height, image_width = get_spatial_size_image_tensor(image) if crop_width > image_width or crop_height > image_height: msg = "Requested crop size {} is bigger than input size {}" @@ -1248,7 +1253,7 @@ def five_crop_image_pil( image: PIL.Image.Image, size: List[int] ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]: crop_height, crop_width = _parse_five_crop_size(size) - _, image_height, image_width = get_dimensions_image_pil(image) + image_height, image_width = get_spatial_size_image_pil(image) if crop_width > image_width or crop_height > image_height: msg = "Requested 
crop size {} is bigger than input size {}" From 9f964c35caefe973065ae80551f4c301135ff6f0 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 7 Oct 2022 14:42:35 +0100 Subject: [PATCH 6/9] Fix linters --- torchvision/prototype/transforms/_utils.py | 4 ++-- torchvision/prototype/transforms/functional/_meta.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 7db4f710481..4c17544afc8 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,7 +1,7 @@ import numbers from collections import defaultdict -from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, Union import PIL.Image @@ -77,7 +77,7 @@ def query_bounding_box(sample: Any) -> features.BoundingBox: return bounding_boxes.pop() -def query_chw(sample: Any) -> Tuple[int, int, int]: +def query_chw(sample: Any) -> List[int]: flat_sample, _ = tree_flatten(sample) chws = { get_dimensions(item) diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index f3fbdd73b2b..6a157a2056e 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -1,4 +1,4 @@ -from typing import cast, List, Optional, Tuple, Union +from typing import cast, List, Optional, Tuple import PIL.Image import torch From 77148495c8a2100e3248c94936508b0a19cabe0d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 7 Oct 2022 15:49:10 +0100 Subject: [PATCH 7/9] Fix merge bug --- torchvision/prototype/transforms/_auto_augment.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index add12e77366..ed8bfd3d3ee 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -472,9 +472,8 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - - id, image_or_video = self._extract_image_or_video(sample) - height, width = get_spatial_size(image_or_video) + id, orig_image_or_video = self._extract_image_or_video(sample) + height, width = get_spatial_size(orig_image_or_video) if isinstance(orig_image_or_video, torch.Tensor): image_or_video = orig_image_or_video From f48c4f5e5c22b750d1ae526e3d23b53feee5c228 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 7 Oct 2022 15:53:37 +0100 Subject: [PATCH 8/9] Linter --- torchvision/prototype/transforms/_utils.py | 9 +++++---- torchvision/prototype/transforms/functional/_meta.py | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 5ceec7e8b9c..db1ff4b7b6f 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,7 +1,7 @@ import numbers from collections import defaultdict -from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union import PIL.Image @@ -77,10 +77,10 @@ def query_bounding_box(sample: Any) -> features.BoundingBox: return bounding_boxes.pop() -def query_chw(sample: Any) -> List[int]: +def query_chw(sample: Any) 
-> Tuple[int, int, int]: flat_sample, _ = tree_flatten(sample) chws = { - get_dimensions(item) + tuple(get_dimensions(item)) for item in flat_sample if isinstance(item, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(item) } @@ -88,7 +88,8 @@ def query_chw(sample: Any) -> List[int]: raise TypeError("No image or video was found in the sample") elif len(chws) > 1: raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}") - return chws.pop() + c, h, w = chws.pop() + return c, h, w def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool: diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index f69711fcc33..e24b68c9fd6 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -68,7 +68,6 @@ def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]: return get_spatial_size_image_pil(inpt) - def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor: xyxy = xywh.clone() xyxy[..., 2:] += xyxy[..., :2] From 2a6ee3dc3b04f6c3d092c6d3a74e21eec68ad666 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 7 Oct 2022 16:57:37 +0100 Subject: [PATCH 9/9] Fix linter --- torchvision/prototype/features/_mask.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index 37c581c89a5..7b49ce8e85e 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -1,4 +1,5 @@ from __future__ import annotations + from typing import Any, cast, List, Optional, Tuple, Union import torch
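---

For reference, below is a minimal, self-contained sketch of the dispatch pattern this series converges on: plain tensors (and anything under `torch.jit.script`) take a tensor kernel, feature tensors expose their spatial size as metadata, and PIL images fall through to a PIL path returning `[height, width]`. `FakeImageFeature` and the helper below are illustrative stand-ins only, not torchvision code; the real `features.Image`/`features.Mask` classes and their JIT handling differ in detail.

```python
from typing import List, Tuple

import PIL.Image
import torch


class FakeImageFeature(torch.Tensor):
    """Invented stand-in for a prototype feature tensor (e.g. Image/Mask):
    a tensor subclass that reports its spatial size as metadata."""

    @property
    def image_size(self) -> Tuple[int, int]:
        # Last two dimensions are (height, width).
        return tuple(self.shape[-2:])


def get_spatial_size(inpt) -> List[int]:
    # Plain tensors (and all code under scripting) use the tensor kernel;
    # the is_scripting() guard mirrors the dispatch used in the patches.
    if isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, FakeImageFeature)
    ):
        return list(inpt.shape[-2:])
    elif isinstance(inpt, FakeImageFeature):
        # Feature tensors carry the size as metadata; no kernel call needed.
        return list(inpt.image_size)
    else:
        # PIL reports (width, height); flip to the [height, width] convention.
        width, height = inpt.size
        return [height, width]


if __name__ == "__main__":
    plain = torch.rand(3, 32, 48)
    feature = plain.as_subclass(FakeImageFeature)
    pil = PIL.Image.new("RGB", (48, 32))

    print(get_spatial_size(plain))    # [32, 48]
    print(get_spatial_size(feature))  # [32, 48]
    print(get_spatial_size(pil))      # [32, 48]
```

The same shape of dispatch is what `get_dimensions` and `get_num_channels` use in patch 1, and `query_chw` (patches 6 and 8) simply collects these per-item results into a set and errors out if the sample contains conflicting CxHxW values.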