diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index c9551c9eea4..b9124f280bd 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -190,17 +190,14 @@ def resize_image_tensor( if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT: # uint8 dtype can be included for cpu and cuda input if nearest mode acceptable_dtypes.append(torch.uint8) - elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu": + elif ( + interpolation == InterpolationMode.BILINEAR + and image.device.type == "cpu" + and "AVX2" in torch.backends.cpu.get_cpu_capability() + ): # uint8 dtype support for bilinear mode is limited to cpu and # according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path - if "AVX2" in torch.backends.cpu.get_cpu_capability(): - acceptable_dtypes.append(torch.uint8) - - # TODO: Remove when https://github.com/pytorch/pytorch/pull/101136 is landed - if dtype == torch.uint8 and not ( - image.is_contiguous() or image.is_contiguous(memory_format=torch.channels_last) - ): - image = image.contiguous(memory_format=torch.channels_last) + acceptable_dtypes.append(torch.uint8) strides = image.stride() if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]: