diff --git a/test/common_utils.py b/test/common_utils.py index a1d188efdae..4d40b0b18a4 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -420,7 +420,7 @@ def sample_position(values, max_value): dtype = dtype or torch.float32 num_objects = 1 - h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size] + h, w = [torch.randint(1, s, (num_objects,)) for s in canvas_size] y = sample_position(h, canvas_size[0]) x = sample_position(w, canvas_size[1]) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 9def98dac15..4f9a08b8412 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1,6 +1,4 @@ import itertools -import pathlib -import pickle import random import numpy as np @@ -11,22 +9,11 @@ import torchvision.transforms.v2 as transforms from common_utils import assert_equal, cpu_and_cuda -from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import tv_tensors from torchvision.ops.boxes import box_iou from torchvision.transforms.functional import to_pil_image -from torchvision.transforms.v2 import functional as F -from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw -from transforms_v2_legacy_utils import ( - make_bounding_boxes, - make_detection_mask, - make_image, - make_images, - make_multiple_bounding_boxes, - make_segmentation_mask, - make_video, - make_videos, -) +from torchvision.transforms.v2._utils import is_pure_tensor +from transforms_v2_legacy_utils import make_bounding_boxes, make_detection_mask, make_image, make_images, make_videos def make_vanilla_tensor_images(*args, **kwargs): @@ -41,11 +28,6 @@ def make_pil_images(*args, **kwargs): yield to_pil_image(image) -def make_vanilla_tensor_bounding_boxes(*args, **kwargs): - for bounding_boxes in make_multiple_bounding_boxes(*args, **kwargs): - yield bounding_boxes.data - - def parametrize(transforms_with_inputs): return pytest.mark.parametrize( ("transform", "input"), @@ -61,218 +43,6 @@ def parametrize(transforms_with_inputs): ) -def auto_augment_adapter(transform, input, device): - adapted_input = {} - image_or_video_found = False - for key, value in input.items(): - if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): - # AA transforms don't support bounding boxes or masks - continue - elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)): - if image_or_video_found: - # AA transforms only support a single image or video - continue - image_or_video_found = True - adapted_input[key] = value - return adapted_input - - -def linear_transformation_adapter(transform, input, device): - flat_inputs = list(input.values()) - c, h, w = query_chw( - [ - item - for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs)) - if needs_transform - ] - ) - num_elements = c * h * w - transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device) - transform.mean_vector = torch.randn((num_elements,), device=device) - return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)} - - -def normalize_adapter(transform, input, device): - adapted_input = {} - for key, value in input.items(): - if isinstance(value, PIL.Image.Image): - # normalize doesn't support PIL images - continue - elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)): - # normalize doesn't support integer images - value = F.to_dtype(value, torch.float32, scale=True) - adapted_input[key] = value 
- return adapted_input - - -class TestSmoke: - @pytest.mark.parametrize( - ("transform", "adapter"), - [ - (transforms.RandomErasing(p=1.0), None), - (transforms.AugMix(), auto_augment_adapter), - (transforms.AutoAugment(), auto_augment_adapter), - (transforms.RandAugment(), auto_augment_adapter), - (transforms.TrivialAugmentWide(), auto_augment_adapter), - (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None), - (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None), - (transforms.RandomAutocontrast(p=1.0), None), - (transforms.RandomEqualize(p=1.0), None), - (transforms.RandomInvert(p=1.0), None), - (transforms.RandomChannelPermutation(), None), - (transforms.RandomPosterize(bits=4, p=1.0), None), - (transforms.RandomSolarize(threshold=0.5, p=1.0), None), - (transforms.CenterCrop([16, 16]), None), - (transforms.ElasticTransform(sigma=1.0), None), - (transforms.Pad(4), None), - (transforms.RandomAffine(degrees=30.0), None), - (transforms.RandomCrop([16, 16], pad_if_needed=True), None), - (transforms.RandomHorizontalFlip(p=1.0), None), - (transforms.RandomPerspective(p=1.0), None), - (transforms.RandomResize(min_size=10, max_size=20, antialias=True), None), - (transforms.RandomResizedCrop([16, 16], antialias=True), None), - (transforms.RandomRotation(degrees=30), None), - (transforms.RandomShortestSize(min_size=10, antialias=True), None), - (transforms.RandomVerticalFlip(p=1.0), None), - (transforms.Resize([16, 16], antialias=True), None), - (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None), - (transforms.ClampBoundingBoxes(), None), - (transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None), - (transforms.ConvertImageDtype(), None), - (transforms.GaussianBlur(kernel_size=3), None), - ( - transforms.LinearTransformation( - # These are just dummy values that will be filled by the adapter. 
We can't define them upfront, - # because for we neither know the spatial size nor the device at this point - transformation_matrix=torch.empty((1, 1)), - mean_vector=torch.empty((1,)), - ), - linear_transformation_adapter, - ), - (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter), - (transforms.ToDtype(torch.float64), None), - (transforms.UniformTemporalSubsample(num_samples=2), None), - ], - ids=lambda transform: type(transform).__name__, - ) - @pytest.mark.parametrize("container_type", [dict, list, tuple]) - @pytest.mark.parametrize( - "image_or_video", - [ - make_image(), - make_video(), - next(make_pil_images(color_spaces=["RGB"])), - next(make_vanilla_tensor_images()), - ], - ) - @pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))]) - @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device): - transform = de_serialize(transform) - - canvas_size = F.get_size(image_or_video) - input = dict( - image_or_video=image_or_video, - image_tv_tensor=make_image(size=canvas_size), - video_tv_tensor=make_video(size=canvas_size), - image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])), - bounding_boxes_xyxy=make_bounding_boxes( - format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,) - ), - bounding_boxes_xywh=make_bounding_boxes( - format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,) - ), - bounding_boxes_cxcywh=make_bounding_boxes( - format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,) - ), - bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes( - [ - [0, 0, 0, 0], # no height or width - [0, 0, 0, 1], # no height - [0, 0, 1, 0], # no width - [2, 0, 1, 1], # x1 > x2, y1 < y2 - [0, 2, 1, 1], # x1 < x2, y1 > y2 - [2, 2, 1, 1], # x1 > x2, y1 > y2 - ], - format=tv_tensors.BoundingBoxFormat.XYXY, - canvas_size=canvas_size, - ), - bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes( - [ - [0, 0, 0, 0], # no height or width - [0, 0, 0, 1], # no height - [0, 0, 1, 0], # no width - [0, 0, 1, -1], # negative height - [0, 0, -1, 1], # negative width - [0, 0, -1, -1], # negative height and width - ], - format=tv_tensors.BoundingBoxFormat.XYWH, - canvas_size=canvas_size, - ), - bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes( - [ - [0, 0, 0, 0], # no height or width - [0, 0, 0, 1], # no height - [0, 0, 1, 0], # no width - [0, 0, 1, -1], # negative height - [0, 0, -1, 1], # negative width - [0, 0, -1, -1], # negative height and width - ], - format=tv_tensors.BoundingBoxFormat.CXCYWH, - canvas_size=canvas_size, - ), - detection_mask=make_detection_mask(size=canvas_size), - segmentation_mask=make_segmentation_mask(size=canvas_size), - int=0, - float=0.0, - bool=True, - none=None, - str="str", - path=pathlib.Path.cwd(), - object=object(), - tensor=torch.empty(5), - array=np.empty(5), - ) - if adapter is not None: - input = adapter(transform, input, device) - - if container_type in {tuple, list}: - input = container_type(input.values()) - - input_flat, input_spec = tree_flatten(input) - input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat] - input = tree_unflatten(input_flat, input_spec) - - torch.manual_seed(0) - output = transform(input) - output_flat, output_spec = tree_flatten(output) - - assert output_spec == input_spec - - for output_item, input_item, 
should_be_transformed in zip( - output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat) - ): - if should_be_transformed: - assert type(output_item) is type(input_item) - else: - assert output_item is input_item - - if isinstance(input_item, tv_tensors.BoundingBoxes) and not isinstance( - transform, transforms.ConvertBoundingBoxFormat - ): - assert output_item.format == input_item.format - - # Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future - # transform that does this), back into a valid one. - # TODO: we should test that against all degenerate boxes above - for format in list(tv_tensors.BoundingBoxFormat): - sample = dict( - boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)), - labels=torch.tensor([3]), - ) - assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4) - - @pytest.mark.parametrize( "flat_inputs", itertools.permutations( @@ -543,39 +313,6 @@ def test__get_params(self, min_size, max_size): assert shorter in min_size -class TestLinearTransformation: - def test_assertions(self): - with pytest.raises(ValueError, match="transformation_matrix should be square"): - transforms.LinearTransformation(torch.rand(2, 3), torch.rand(5)) - - with pytest.raises(ValueError, match="mean_vector should have the same length"): - transforms.LinearTransformation(torch.rand(3, 3), torch.rand(5)) - - @pytest.mark.parametrize( - "inpt", - [ - 122 * torch.ones(1, 3, 8, 8), - 122.0 * torch.ones(1, 3, 8, 8), - tv_tensors.Image(122 * torch.ones(1, 3, 8, 8)), - PIL.Image.new("RGB", (8, 8), (122, 122, 122)), - ], - ) - def test__transform(self, inpt): - - v = 121 * torch.ones(3 * 8 * 8) - m = torch.ones(3 * 8 * 8, 3 * 8 * 8) - transform = transforms.LinearTransformation(m, v) - - if isinstance(inpt, PIL.Image.Image): - with pytest.raises(TypeError, match="does not support PIL images"): - transform(inpt) - else: - output = transform(inpt) - assert isinstance(output, torch.Tensor) - assert output.unique() == 3 * 8 * 8 - assert output.dtype == inpt.dtype - - class TestRandomResize: def test__get_params(self): min_size = 3 diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index 0d2fed014c1..b4ce189e758 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -72,28 +72,6 @@ def __init__( LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2) CONSISTENCY_CONFIGS = [ - *[ - ConsistencyConfig( - v2_transforms.LinearTransformation, - legacy_transforms.LinearTransformation, - [ - ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)), - ], - # Make sure that the product of the height, width and number of channels matches the number of elements in - # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36. 
- make_images_kwargs=dict( - DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype] - ), - supports_pil=False, - ) - for matrix_dtype, image_dtype in [ - (torch.float32, torch.float32), - (torch.float64, torch.float64), - (torch.float32, torch.uint8), - (torch.float64, torch.float32), - (torch.float32, torch.float64), - ] - ], ConsistencyConfig( v2_transforms.ToPILImage, legacy_transforms.ToPILImage, diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 99bee49d4b7..0ac51b114f9 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -6,6 +6,7 @@ import math import pickle import re +from copy import deepcopy from pathlib import Path from unittest import mock @@ -37,13 +38,14 @@ from torch import nn from torch.testing import assert_close -from torch.utils._pytree import tree_map +from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping from torchvision.transforms.v2 import functional as F +from torchvision.transforms.v2._utils import check_type, is_pure_tensor from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal @@ -261,7 +263,123 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol): _script(v1_transform)(input) -def check_transform(transform, input, check_v1_compatibility=True): +def _make_transform_sample(transform, *, image_or_video, adapter): + device = image_or_video.device if isinstance(image_or_video, torch.Tensor) else "cpu" + size = F.get_size(image_or_video) + input = dict( + image_or_video=image_or_video, + image_tv_tensor=make_image(size, device=device), + video_tv_tensor=make_video(size, device=device), + image_pil=make_image_pil(size), + bounding_boxes_xyxy=make_bounding_boxes(size, format=tv_tensors.BoundingBoxFormat.XYXY, device=device), + bounding_boxes_xywh=make_bounding_boxes(size, format=tv_tensors.BoundingBoxFormat.XYWH, device=device), + bounding_boxes_cxcywh=make_bounding_boxes(size, format=tv_tensors.BoundingBoxFormat.CXCYWH, device=device), + bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes( + [ + [0, 0, 0, 0], # no height or width + [0, 0, 0, 1], # no height + [0, 0, 1, 0], # no width + [2, 0, 1, 1], # x1 > x2, y1 < y2 + [0, 2, 1, 1], # x1 < x2, y1 > y2 + [2, 2, 1, 1], # x1 > x2, y1 > y2 + ], + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=size, + device=device, + ), + bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes( + [ + [0, 0, 0, 0], # no height or width + [0, 0, 0, 1], # no height + [0, 0, 1, 0], # no width + [0, 0, 1, -1], # negative height + [0, 0, -1, 1], # negative width + [0, 0, -1, -1], # negative height and width + ], + format=tv_tensors.BoundingBoxFormat.XYWH, + canvas_size=size, + device=device, + ), + bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes( + [ + [0, 0, 0, 0], # no height or width + [0, 0, 0, 1], # no height + [0, 0, 1, 0], # no width + [0, 0, 1, -1], # negative height + [0, 0, -1, 1], # negative width + [0, 0, -1, -1], # negative height and width + ], + format=tv_tensors.BoundingBoxFormat.CXCYWH, + canvas_size=size, + device=device, + ), + detection_mask=make_detection_mask(size, device=device), + 
segmentation_mask=make_segmentation_mask(size, device=device), + int=0, + float=0.0, + bool=True, + none=None, + str="str", + path=Path.cwd(), + object=object(), + tensor=torch.empty(5), + array=np.empty(5), + ) + if adapter is not None: + input = adapter(transform, input, device) + return input + + +def _check_transform_sample_input_smoke(transform, input, *, adapter): + # This is a bunch of input / output convention checks, using a big sample with different parts as input. + + if not check_type(input, (is_pure_tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video)): + return + + sample = _make_transform_sample( + # adapter might change transform inplace + transform=transform if adapter is None else deepcopy(transform), + image_or_video=input, + adapter=adapter, + ) + for container_type in [dict, list, tuple]: + if container_type is dict: + input = sample + else: + input = container_type(sample.values()) + + input_flat, input_spec = tree_flatten(input) + + with freeze_rng_state(): + torch.manual_seed(0) + output = transform(input) + output_flat, output_spec = tree_flatten(output) + + assert output_spec == input_spec + + for output_item, input_item, should_be_transformed in zip( + output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat) + ): + if should_be_transformed: + assert type(output_item) is type(input_item) + else: + assert output_item is input_item + + # Enforce that the transform does not turn a degenerate bounding box, e.g. marked by RandomIoUCrop (or any other + # future transform that does this), back into a valid one. + for degenerate_bounding_boxes in ( + bounding_box + for name, bounding_box in sample.items() + if "degenerate" in name and isinstance(bounding_box, tv_tensors.BoundingBoxes) + ): + sample = dict( + boxes=degenerate_bounding_boxes, + labels=torch.randint(10, (degenerate_bounding_boxes.shape[0],), device=degenerate_bounding_boxes.device), + ) + assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4) + + +def check_transform(transform, input, check_v1_compatibility=True, check_sample_input=True): pickle.loads(pickle.dumps(transform)) output = transform(input) @@ -270,6 +388,11 @@ def check_transform(transform, input, check_v1_compatibility=True): if isinstance(input, tv_tensors.BoundingBoxes) and not isinstance(transform, transforms.ConvertBoundingBoxFormat): assert output.format == input.format + if check_sample_input: + _check_transform_sample_input_smoke( + transform, input, adapter=check_sample_input if callable(check_sample_input) else None + ) + if check_v1_compatibility: _check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility)) @@ -1758,7 +1881,7 @@ def test_transform(self, make_input, input_dtype, output_dtype, device, scale, a input = make_input(dtype=input_dtype, device=device) if as_dict: output_dtype = {type(input): output_dtype} - check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input) + check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input, check_sample_input=not as_dict) def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False): input_dtype = image.dtype @@ -2559,9 +2682,13 @@ def test_functional_image_correctness(self, kwargs): def test_transform(self, param, value, make_input): input = make_input(self.INPUT_SIZE) + check_sample_input = True if param == "fill": - if isinstance(input, tv_tensors.Mask) and isinstance(value, (tuple, list)): - pytest.skip("F.pad_mask doesn't support non-scalar 
fill.") + if isinstance(value, (tuple, list)): + if isinstance(input, tv_tensors.Mask): + pytest.skip("F.pad_mask doesn't support non-scalar fill.") + else: + check_sample_input = False kwargs = dict( # 1. size is required @@ -2576,6 +2703,7 @@ def test_transform(self, param, value, make_input): transforms.RandomCrop(**kwargs, pad_if_needed=True), input, check_v1_compatibility=param != "fill" or isinstance(value, (int, float)), + check_sample_input=check_sample_input, ) @pytest.mark.parametrize("padding", [1, (1, 1), (1, 1, 1, 1)]) @@ -2761,9 +2889,13 @@ def test_functional_signature(self, kernel, input_type): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_transform(self, make_input, device): input = make_input(device=device) - check_transform( - transforms.RandomErasing(p=1), input, check_v1_compatibility=not isinstance(input, PIL.Image.Image) - ) + + with pytest.warns(UserWarning, match="currently passing through inputs of type"): + check_transform( + transforms.RandomErasing(p=1), + input, + check_v1_compatibility=not isinstance(input, PIL.Image.Image), + ) def _reference_erase_image(self, image, *, i, j, h, w, v): mask = torch.zeros_like(image, dtype=torch.bool) @@ -2835,18 +2967,6 @@ def test_transform_errors(self): with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"): transform._get_params([make_image()]) - @pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_mask]) - def test_transform_passthrough(self, make_input): - transform = transforms.RandomErasing(p=1) - - input = make_input(self.INPUT_SIZE) - - with pytest.warns(UserWarning, match="currently passing through inputs of type"): - # RandomErasing requires an image or video to be present - _, output = transform(make_image(self.INPUT_SIZE), input) - - assert output is input - class TestGaussianBlur: @pytest.mark.parametrize("kernel_size", [1, 3, (3, 1), [3, 5]]) @@ -3063,6 +3183,21 @@ def test_correctness_shear_translate(self, transform_id, magnitude, interpolatio else: assert_close(actual, expected, rtol=0, atol=1) + def _sample_input_adapter(self, transform, input, device): + adapted_input = {} + image_or_video_found = False + for key, value in input.items(): + if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + # AA transforms don't support bounding boxes or masks + continue + elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)): + if image_or_video_found: + # AA transforms only support a single image or video + continue + image_or_video_found = True + adapted_input[key] = value + return adapted_input + @pytest.mark.parametrize( "transform", [transforms.AutoAugment(), transforms.RandAugment(), transforms.TrivialAugmentWide(), transforms.AugMix()], @@ -3087,7 +3222,9 @@ def test_transform_smoke(self, transform, make_input, dtype, device): # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1 # and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks # here and only check if we can script the v2 transform and subsequently call the result. 
- check_transform(transform, input, check_v1_compatibility=False) + check_transform( + transform, input, check_v1_compatibility=False, check_sample_input=self._sample_input_adapter + ) if type(input) is torch.Tensor and dtype is torch.uint8: _script(transform)(input) @@ -4014,9 +4151,25 @@ def test_functional_error(self): with pytest.raises(ValueError, match="std evaluated to zero, leading to division by zero"): F.normalize_image(make_image(dtype=torch.float32), mean=self.MEAN, std=std) + def _sample_input_adapter(self, transform, input, device): + adapted_input = {} + for key, value in input.items(): + if isinstance(value, PIL.Image.Image): + # normalize doesn't support PIL images + continue + elif check_type(value, (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)): + # normalize doesn't support integer images + value = F.to_dtype(value, torch.float32, scale=True) + adapted_input[key] = value + return adapted_input + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video]) def test_transform(self, make_input): - check_transform(transforms.Normalize(mean=self.MEAN, std=self.STD), make_input(dtype=torch.float32)) + check_transform( + transforms.Normalize(mean=self.MEAN, std=self.STD), + make_input(dtype=torch.float32), + check_sample_input=self._sample_input_adapter, + ) def _reference_normalize_image(self, image, *, mean, std): image = image.numpy() @@ -4543,7 +4696,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ) @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop]) def test_transform(self, make_input, transform_cls): - check_transform(self._TransformWrapper(transform_cls(size=self.OUTPUT_SIZE)), make_input(self.INPUT_SIZE)) + check_transform( + self._TransformWrapper(transform_cls(size=self.OUTPUT_SIZE)), + make_input(self.INPUT_SIZE), + check_sample_input=False, + ) @pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_mask]) @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop]) @@ -4826,3 +4983,66 @@ def test__get_params(self): assert int(input_size[0] * r_min) <= height <= int(input_size[0] * r_max) assert int(input_size[1] * r_min) <= width <= int(input_size[1] * r_max) + + +class TestLinearTransform: + def _make_matrix_and_vector(self, input, *, device=None): + device = device or input.device + numel = math.prod(F.get_dimensions(input)) + transformation_matrix = torch.randn((numel, numel), device=device) + mean_vector = torch.randn((numel,), device=device) + return transformation_matrix, mean_vector + + def _sample_input_adapter(self, transform, input, device): + return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)} + + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video]) + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_transform(self, make_input, dtype, device): + input = make_input(dtype=dtype, device=device) + check_transform( + transforms.LinearTransformation(*self._make_matrix_and_vector(input)), + input, + check_sample_input=self._sample_input_adapter, + ) + + def test_transform_error(self): + with pytest.raises(ValueError, match="transformation_matrix should be square"): + transforms.LinearTransformation(transformation_matrix=torch.rand(2, 3), mean_vector=torch.rand(2)) + + with pytest.raises(ValueError, match="mean_vector should have the same length"): + 
transforms.LinearTransformation(transformation_matrix=torch.rand(2, 2), mean_vector=torch.rand(1)) + + for matrix_dtype, vector_dtype in [(torch.float32, torch.float64), (torch.float64, torch.float32)]: + with pytest.raises(ValueError, match="Input tensors should have the same dtype"): + transforms.LinearTransformation( + transformation_matrix=torch.rand(2, 2, dtype=matrix_dtype), + mean_vector=torch.rand(2, dtype=vector_dtype), + ) + + image = make_image() + transform = transforms.LinearTransformation(transformation_matrix=torch.rand(2, 2), mean_vector=torch.rand(2)) + with pytest.raises(ValueError, match="Input tensor and transformation matrix have incompatible shape"): + transform(image) + + transform = transforms.LinearTransformation(*self._make_matrix_and_vector(image)) + with pytest.raises(TypeError, match="does not support PIL images"): + transform(F.to_pil_image(image)) + + @needs_cuda + def test_transform_error_cuda(self): + for matrix_device, vector_device in [("cuda", "cpu"), ("cpu", "cuda")]: + with pytest.raises(ValueError, match="Input tensors should be on the same device"): + transforms.LinearTransformation( + transformation_matrix=torch.rand(2, 2, device=matrix_device), + mean_vector=torch.rand(2, device=vector_device), + ) + + for input_device, param_device in [("cuda", "cpu"), ("cpu", "cuda")]: + input = make_image(device=input_device) + transform = transforms.LinearTransformation(*self._make_matrix_and_vector(input, device=param_device)) + with pytest.raises( + ValueError, match="Input tensor should be on the same device as transformation matrix and mean vector" + ): + transform(input)
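
For reference, a minimal standalone sketch of the per-test "sample input adapter" idea used above (mirroring the Normalize adapter): the helper name normalize_sample_adapter and the toy sample below are made up for illustration and are not part of the patch; only public torchvision v2 APIs are used.

import torch
import torchvision.transforms.v2 as transforms
from torch.utils._pytree import tree_flatten
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


def normalize_sample_adapter(transform, sample, device):
    # Mirrors the Normalize adapter above: convert integer image-like entries to
    # float, since normalize rejects integer inputs. PIL images would be dropped
    # here as well, but this toy sample does not contain any.
    adapted = {}
    for key, value in sample.items():
        if isinstance(value, (tv_tensors.Image, tv_tensors.Video)) or (
            isinstance(value, torch.Tensor) and not isinstance(value, tv_tensors.TVTensor)
        ):
            value = F.to_dtype(value, torch.float32, scale=True)
        adapted[key] = value
    return adapted


# A toy mixed sample: one image entry plus values the transform should pass through.
sample = dict(
    image=tv_tensors.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)),
    tensor=torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8),
    label=3,
    name="sample",
    nothing=None,
)

transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
adapted = normalize_sample_adapter(transform, sample, device="cpu")
output = transform(adapted)

# The container structure is preserved and non-image entries are passed through untouched.
assert tree_flatten(output)[1] == tree_flatten(adapted)[1]
assert isinstance(output["image"], tv_tensors.Image)
assert output["name"] is adapted["name"] and output["nothing"] is None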
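
Likewise, a small sketch of the degenerate-bounding-box invariant that _check_transform_sample_input_smoke enforces: a box left degenerate after a transform (e.g. one marked by RandomIoUCrop) must still be dropped by SanitizeBoundingBoxes. The box values and canvas size here are illustrative.

import torch
import torchvision.transforms.v2 as transforms
from torchvision import tv_tensors

# A zero-sized XYXY box, i.e. the kind of degenerate box RandomIoUCrop can leave behind.
boxes = tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format="XYXY", canvas_size=(32, 32))
sample = dict(boxes=boxes, labels=torch.tensor([3]))

out = transforms.SanitizeBoundingBoxes()(sample)
assert out["boxes"].shape == (0, 4)  # the degenerate box is removed ...
assert out["labels"].shape == (0,)   # ... together with its label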