Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions test/test_transforms_v2_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,43 +1013,3 @@ def test_correctness_uniform_temporal_subsample(device):

out_video = F.uniform_temporal_subsample(video, 8)
assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]


# TODO: We can remove this test and related torchvision workaround
# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
@make_info_args_kwargs_parametrization(
    [info for info in KERNEL_INFOS if info.kernel is F.resize_image],
    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)
def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
    """Check that the resize kernel keeps the memory layout of its 3D input.

    The expected output strides are derived from the input strides: an input
    whose last stride is 1 is treated as channels first, an input whose first
    stride is 1 is treated as channels last. Any other input layout fails the
    test.
    """
    (image, *other_args), kwargs = args_kwargs.load("cpu")

    output = info.kernel(image.as_subclass(torch.Tensor), *other_args, **kwargs)

    error_msg_fn = parametrized_error_message(image, *other_args, **kwargs)
    # The message factory must be called here; passing the function object
    # itself would only show its repr when the assertion fails.
    assert image.ndim == 3, error_msg_fn("")

    input_stride = image.stride()
    output_stride = output.stride()
    # Here we check the output memory format according to the input:
    # if input_stride is (..., 1) then input is most likely channels first and thus
    # output strides should match channels first strides (H * W, W, 1)
    # if input_stride is (1, ...) then input is most likely channels last and thus
    # output strides should match channels last strides (1, W * C, C)
    if input_stride[-1] == 1:
        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
    elif input_stride[0] == 1:
        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
    else:
        # Raise explicitly rather than `assert False` so that the failure is
        # not optimized away when Python runs with -O.
        raise AssertionError(error_msg_fn(""))

    assert expected_stride == output_stride, error_msg_fn("")


def test_resize_float16_no_rounding():
    """Make sure resizing a float16 image does not round the output values.

    Non-regression test for https://github.com/pytorch/vision/issues/7667
    """
    img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16)
    out = F.resize(img, size=(10, 10))
    assert out.dtype == torch.float16
    # Sum the absolute deviations from the nearest integer: the signed
    # differences can cancel each other out (or add up to a negative value)
    # even when the output contains plenty of non-integer values.
    assert (out.round() - out).abs().sum() > 0
60 changes: 60 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,66 @@ def test_no_regression_5405(self, make_input):

assert max(F.get_size(output)) == max_size

def _make_image(self, *args, batch_dims=(), memory_format=torch.contiguous_format, **kwargs):
    """Create a test image, emulating a channels last layout for non-4D inputs.

    torch.channels_last is only available for 4D tensors, i.e. (B, C, H, W).
    However, images coming from PIL or our own I/O functions do not have a
    batch dimension and are thus 3D, i.e. (C, H, W), while their data is
    still laid out channels last in memory. To emulate this, the image is
    created as 4D with all requested batch dimensions collapsed into one and
    is then viewed with the requested shape. The resulting memory layout is
    channels last although PyTorch doesn't recognize it as such.
    """
    needs_emulation = memory_format is torch.channels_last and len(batch_dims) != 1
    if not needs_emulation:
        return make_image(*args, batch_dims=batch_dims, memory_format=memory_format, **kwargs)

    # Collapse the requested batch dimensions into a single one so that the
    # image is 4D when it is created.
    flat_image = make_image(
        *args,
        batch_dims=(math.prod(batch_dims),),
        memory_format=memory_format,
        **kwargs,
    )

    # Restore the requested batch dimensions as a view and re-wrap the result
    # like the freshly created image.
    reshaped = flat_image.view(*batch_dims, *flat_image.shape[-3:])
    return datapoints.wrap(reshaped, like=flat_image)

def _check_stride(self, image, *, memory_format):
    """Assert that the strides of ``image`` match the given ``memory_format``.

    Raises:
        ValueError: If ``memory_format`` is neither ``torch.contiguous_format``
            nor ``torch.channels_last``.
    """
    num_channels, height, width = F.get_dimensions(image)

    if memory_format is torch.channels_last:
        stride_for_format = (1, width * num_channels, num_channels)
    elif memory_format is torch.contiguous_format:
        stride_for_format = (height * width, width, 1)
    else:
        raise ValueError(f"Unknown memory_format: {memory_format}")

    assert image.stride() == stride_for_format

# TODO: We can remove this test and related torchvision workaround
# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
@pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
@pytest.mark.parametrize("antialias", [True, False])
@pytest.mark.parametrize("memory_format", [torch.contiguous_format, torch.channels_last])
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_memory_format_consistency(self, interpolation, antialias, memory_format, dtype, device):
    """Check that the resized image has the same memory layout as the input."""
    image = self._make_image(self.INPUT_SIZE, dtype=dtype, device=device, memory_format=memory_format)

    # Smoke test: the input has to have the requested layout to begin with,
    # otherwise the check on the output below is meaningless.
    self._check_stride(image, memory_format=memory_format)

    resized = F.resize_image(
        image,
        size=self.OUTPUT_SIZES[0],
        interpolation=interpolation,
        antialias=antialias,
    )

    self._check_stride(resized, memory_format=memory_format)

def test_float16_no_rounding(self):
    """Make sure resizing a float16 image does not round the output values.

    Non-regression test for https://github.com/pytorch/vision/issues/7667
    """
    image = make_image_tensor(self.INPUT_SIZE, dtype=torch.float16)
    resized = F.resize_image(image, size=self.OUTPUT_SIZES[0])

    assert resized.dtype is torch.float16

    # Total absolute distance of the output values to their nearest integers;
    # it would be zero if the output had been rounded.
    total_deviation = (resized.round() - resized).abs().sum()
    assert total_deviation > 0


class TestHorizontalFlip:
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
Expand Down