From 2521f3985372eb88966cecfe7985c1022b8c78d6 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 6 Jul 2023 10:51:44 +0200
Subject: [PATCH 1/6] port remaining resize tests

---
 test/common_utils.py                  | 15 +++++++--
 test/test_transforms_v2_functional.py | 40 ------------------------
 test/test_transforms_v2_refactored.py | 45 +++++++++++++++++++++++++++
 3 files changed, 58 insertions(+), 42 deletions(-)

diff --git a/test/common_utils.py b/test/common_utils.py
index c815786b586..7b9ccd5033f 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -1,6 +1,7 @@
 import contextlib
 import functools
 import itertools
+import math
 import os
 import pathlib
 import random
@@ -380,14 +381,24 @@ def make_image(
     num_channels = NUM_CHANNELS_MAP[color_space]
     dtype = dtype or torch.uint8
     max_value = get_max_value(dtype)
+
+    shape = make_tensor_shape = (*batch_dims, num_channels, *size)
+    # torch.channels_last memory_format is only available for 4D tensors, i.e. (B, C, H, W). However, images coming from
+    # PIL or our own I/O functions do not have a batch dimension and are thus 3D, i.e. (C, H, W). Still, the layout of
+    # the data in memory is channels last. To emulate this when a 3D input is requested here, we create the image as 4D
+    # and create a view with the right shape afterwards. With this the layout in memory is channels last although
+    # PyTorch doesn't recognize it as such.
+    if memory_format is torch.channels_last and len(batch_dims) != 1:
+        make_tensor_shape = (math.prod(shape[:-3]), *shape[-3:])
+
     data = torch.testing.make_tensor(
-        (*batch_dims, num_channels, *size),
+        make_tensor_shape,
         low=0,
         high=max_value,
         dtype=dtype,
         device=device,
         memory_format=memory_format,
-    )
+    ).view(shape)
 
     if color_space in {"GRAY_ALPHA", "RGBA"}:
         data[..., -1, :, :] = max_value
diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py
index 15af5a7a9ed..e7e5253c7ee 100644
--- a/test/test_transforms_v2_functional.py
+++ b/test/test_transforms_v2_functional.py
@@ -1015,43 +1015,3 @@ def test_correctness_uniform_temporal_subsample(device):
     out_video = F.uniform_temporal_subsample(video, 8)
 
     assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
-
-
-# TODO: We can remove this test and related torchvision workaround
-# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
-@make_info_args_kwargs_parametrization(
-    [info for info in KERNEL_INFOS if info.kernel is F.resize_image],
-    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
-)
-def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
-    (input, *other_args), kwargs = args_kwargs.load("cpu")
-
-    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
-
-    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
-    assert input.ndim == 3, error_msg_fn
-    input_stride = input.stride()
-    output_stride = output.stride()
-    # Here we check output memory format according to the input:
-    # if input_stride is (..., 1) then input is most likely channels first and thus
-    #   output strides should match channels first strides (H * W, W, 1)
-    # if input_stride is (1, ...) then input is most likely channels last and thus
-    #   output strides should match channels last strides (1, W * C, C)
-    if input_stride[-1] == 1:
-        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
-        assert expected_stride == output_stride, error_msg_fn("")
-    elif input_stride[0] == 1:
-        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
-        assert expected_stride == output_stride, error_msg_fn("")
-    else:
-        assert False, error_msg_fn("")
-
-
-def test_resize_float16_no_rounding():
-    # Make sure Resize() doesn't round float16 images
-    # Non-regression test for https://github.com/pytorch/vision/issues/7667
-
-    img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16)
-    out = F.resize(img, size=(10, 10))
-    assert out.dtype == torch.float16
-    assert (out.round() - out).sum() > 0
diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index f57736e5abd..c3ef28fd72e 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -732,6 +732,51 @@ def test_no_regression_5405(self, make_input):
 
         assert max(F.get_size(output)) == max_size
 
+    def _check_stride(self, image, *, memory_format):
+        C, H, W = F.get_dimensions(image)
+        if memory_format is torch.contiguous_format:
+            expected_stride = (H * W, W, 1)
+        elif memory_format is torch.channels_last:
+            expected_stride = (1, W * C, C)
+        else:
+            raise ValueError(f"Unknown memory_format: {memory_format}")
+
+        assert image.stride() == expected_stride
+
+    # TODO: We can remove this test and related torchvision workaround
+    # once we fix the related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
+    @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
+    @pytest.mark.parametrize("use_max_size", [True, False])
+    @pytest.mark.parametrize("antialias", [True, False])
+    @pytest.mark.parametrize("memory_format", [torch.contiguous_format, torch.channels_last])
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_image_memory_format_consistency(
+        self, interpolation, use_max_size, antialias, memory_format, dtype, device
+    ):
+        size = self.OUTPUT_SIZES[0]
+        if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
+            return
+
+        input = make_image(self.INPUT_SIZE, dtype=dtype, device=device, memory_format=memory_format)
+
+        # Smoke test to make sure we aren't starting with wrong assumptions
+        self._check_stride(input, memory_format=memory_format)
+
+        output = F.resize_image(input, size=size, interpolation=interpolation, **max_size_kwarg, antialias=antialias)
+
+        self._check_stride(output, memory_format=memory_format)
+
+    def test_no_regression_7667(self):
+        # Checks that float16 images are not rounded
+        # See https://github.com/pytorch/vision/issues/7667
+
+        input = make_image_tensor(self.INPUT_SIZE, dtype=torch.float16)
+        output = F.resize_image(input, size=self.OUTPUT_SIZES[0])
+
+        assert output.dtype is torch.float16
+        assert (output.round() - output).abs().sum() > 0
+
 
 class TestHorizontalFlip:
     @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])

From 194a758cd1a249d802f2e9f44431cfd322ed1154 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 24 Aug 2023 10:16:56 +0200
Subject: [PATCH 2/6] revert name

---
 test/test_transforms_v2_refactored.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index ff503943722..4c9e195090a 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -770,8 +770,7 @@ def test_kernel_image_memory_format_consistency(
 
         self._check_stride(output, memory_format=memory_format)
 
-    def test_no_regression_7667(self):
-        # Checks that float16 images are not rounded
+    def test_float16_no_rounding(self):
         # See https://github.com/pytorch/vision/issues/7667
 
         input = make_image_tensor(self.INPUT_SIZE, dtype=torch.float16)

From b53cd8ea0f732d5e1176035e60ea5a33757e3125 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 24 Aug 2023 10:30:59 +0200
Subject: [PATCH 3/6] move channels_last handling to resize test

---
 test/common_utils.py                  | 14 ++------------
 test/test_transforms_v2_refactored.py | 22 +++++++++++++++++++++-
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/test/common_utils.py b/test/common_utils.py
index 4479093d750..45b7453ada4 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -1,7 +1,6 @@
 import contextlib
 import functools
 import itertools
-import math
 import os
 import pathlib
 import random
@@ -382,23 +381,14 @@ def make_image(
     dtype = dtype or torch.uint8
     max_value = get_max_value(dtype)
 
-    shape = make_tensor_shape = (*batch_dims, num_channels, *size)
-    # torch.channels_last memory_format is only available for 4D tensors, i.e. (B, C, H, W). However, images coming from
-    # PIL or our own I/O functions do not have a batch dimension and are thus 3D, i.e. (C, H, W). Still, the layout of
-    # the data in memory is channels last. To emulate this when a 3D input is requested here, we create the image as 4D
-    # and create a view with the right shape afterwards. With this the layout in memory is channels last although
-    # PyTorch doesn't recognize it as such.
-    if memory_format is torch.channels_last and len(batch_dims) != 1:
-        make_tensor_shape = (math.prod(shape[:-3]), *shape[-3:])
-
     data = torch.testing.make_tensor(
-        make_tensor_shape,
+        (*batch_dims, num_channels, *size),
         low=0,
         high=max_value,
         dtype=dtype,
         device=device,
         memory_format=memory_format,
-    ).view(shape)
+    )
 
     if color_space in {"GRAY_ALPHA", "RGBA"}:
         data[..., -1, :, :] = max_value
diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index 4c9e195090a..82764d12e4d 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -746,6 +746,26 @@ def _check_stride(self, image, *, memory_format):
 
         assert image.stride() == expected_stride
 
+    def _make_image(self, *args, batch_dims=(), memory_format=torch.contiguous_format, **kwargs):
+        # torch.channels_last memory_format is only available for 4D tensors, i.e. (B, C, H, W). However, images coming
+        # from PIL or our own I/O functions do not have a batch dimension and are thus 3D, i.e. (C, H, W). Still, the
+        # layout of the data in memory is channels last. To emulate this when a 3D input is requested here, we create
+        # the image as 4D and create a view with the right shape afterwards. With this the layout in memory is channels
+        # last although PyTorch doesn't recognize it as such.
+        emulate_channels_last = memory_format is torch.channels_last or len(batch_dims) != 1
+
+        image = make_image(
+            *args,
+            batch_dims=(math.prod(batch_dims),) if emulate_channels_last else batch_dims,
+            memory_format=memory_format,
+            **kwargs,
+        )
+
+        if emulate_channels_last:
+            image = datapoints.wrap(image.view(*batch_dims, *image.shape[-3:]), like=image)
+
+        return image
+
     # TODO: We can remove this test and related torchvision workaround
     # once we fix the related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
     @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
@@ -761,7 +781,7 @@ def test_kernel_image_memory_format_consistency(
         if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
             return
 
-        input = make_image(self.INPUT_SIZE, dtype=dtype, device=device, memory_format=memory_format)
+        input = self._make_image(self.INPUT_SIZE, dtype=dtype, device=device, memory_format=memory_format)
 
         # Smoke test to make sure we aren't starting with wrong assumptions
         self._check_stride(input, memory_format=memory_format)

From 04c8642e0d57d3db0921c35c5467306f7dd3ce4f Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 24 Aug 2023 10:31:51 +0200
Subject: [PATCH 4/6] cleanup

---
 test/test_transforms_v2_refactored.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index 82764d12e4d..9896c6a511c 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -791,7 +791,8 @@ def test_kernel_image_memory_format_consistency(
         self._check_stride(output, memory_format=memory_format)
 
     def test_float16_no_rounding(self):
-        # See https://github.com/pytorch/vision/issues/7667
+        # Make sure Resize() doesn't round float16 images
+        # Non-regression test for https://github.com/pytorch/vision/issues/7667
 
         input = make_image_tensor(self.INPUT_SIZE, dtype=torch.float16)
         output = F.resize_image(input, size=self.OUTPUT_SIZES[0])

From 1f74dff49ff484c9d7fea1d1704fe2304c3c5205 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 24 Aug 2023 10:32:54 +0200
Subject: [PATCH 5/6] cleanup

---
 test/common_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/common_utils.py b/test/common_utils.py
index 45b7453ada4..61f06994801 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -380,7 +380,6 @@ def make_image(
     num_channels = NUM_CHANNELS_MAP[color_space]
     dtype = dtype or torch.uint8
     max_value = get_max_value(dtype)
-
     data = torch.testing.make_tensor(
         (*batch_dims, num_channels, *size),
         low=0,

From 6ef8fce84a04d4c7441cb0524f313958d2ecc041 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 25 Aug 2023 10:15:54 +0200
Subject: [PATCH 6/6] more cleanup

---
 test/test_transforms_v2_refactored.py | 33 ++++++++++++---------------
 1 file changed, 14 insertions(+), 19 deletions(-)

diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index 9896c6a511c..0b9024c946b 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -735,57 +735,52 @@ def test_no_regression_5405(self, make_input):
 
         assert max(F.get_size(output)) == max_size
 
-    def _check_stride(self, image, *, memory_format):
-        C, H, W = F.get_dimensions(image)
-        if memory_format is torch.contiguous_format:
-            expected_stride = (H * W, W, 1)
-        elif memory_format is torch.channels_last:
-            expected_stride = (1, W * C, C)
-        else:
-            raise ValueError(f"Unknown memory_format: {memory_format}")
ValueError(f"Unknown memory_format: {memory_format}") - - assert image.stride() == expected_stride - def _make_image(self, *args, batch_dims=(), memory_format=torch.contiguous_format, **kwargs): # torch.channels_last memory_format is only available for 4D tensors, i.e. (B, C, H, W). However, images coming # from PIL or our own I/O functions do not have a batch dimensions and are thus 3D, i.e. (C, H, W). Still, the # layout of the data in memory is channels last. To emulate this when a 3D input is requested here, we create # the image as 4D and create a view with the right shape afterwards. With this the layout in memory is channels # last although PyTorch doesn't recognizes it as such. - emulate_channels_last = memory_format is torch.channels_last or len(batch_dims) != 1 + emulate_channels_last = memory_format is torch.channels_last and len(batch_dims) != 1 image = make_image( *args, @@ -766,27 +755,33 @@ def _make_image(self, *args, batch_dims=(), memory_format=torch.contiguous_forma return image + def _check_stride(self, image, *, memory_format): + C, H, W = F.get_dimensions(image) + if memory_format is torch.contiguous_format: + expected_stride = (H * W, W, 1) + elif memory_format is torch.channels_last: + expected_stride = (1, W * C, C) + else: + raise ValueError(f"Unknown memory_format: {memory_format}") + + assert image.stride() == expected_stride + # TODO: We can remove this test and related torchvision workaround # once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430 @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES) - @pytest.mark.parametrize("use_max_size", [True, False]) @pytest.mark.parametrize("antialias", [True, False]) @pytest.mark.parametrize("memory_format", [torch.contiguous_format, torch.channels_last]) @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_image_memory_format_consistency( - self, interpolation, use_max_size, antialias, memory_format, dtype, device - ): + def test_kernel_image_memory_format_consistency(self, interpolation, antialias, memory_format, dtype, device): size = self.OUTPUT_SIZES[0] - if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): - return input = self._make_image(self.INPUT_SIZE, dtype=dtype, device=device, memory_format=memory_format) # Smoke test to make sure we aren't starting with wrong assumptions self._check_stride(input, memory_format=memory_format) - output = F.resize_image(input, size=size, interpolation=interpolation, **max_size_kwarg, antialias=antialias) + output = F.resize_image(input, size=size, interpolation=interpolation, antialias=antialias) self._check_stride(output, memory_format=memory_format)