Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,14 @@ def resize_image_tensor(
if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
# uint8 dtype can be included for cpu and cuda input if nearest mode
acceptable_dtypes.append(torch.uint8)
elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
elif (
interpolation == InterpolationMode.BILINEAR
and image.device.type == "cpu"
and "AVX2" in torch.backends.cpu.get_cpu_capability()
):
# uint8 dtype support for bilinear mode is limited to cpu and
# according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
if "AVX2" in torch.backends.cpu.get_cpu_capability():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could have been in the original elif already or am I missing something?

Copy link
Member Author

@NicolasHug NicolasHug May 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes there's nothing before or after that block so it's logically the same

acceptable_dtypes.append(torch.uint8)

# TODO: Remove when https://github.com/pytorch/pytorch/pull/101136 is landed
if dtype == torch.uint8 and not (
image.is_contiguous() or image.is_contiguous(memory_format=torch.channels_last)
):
image = image.contiguous(memory_format=torch.channels_last)
acceptable_dtypes.append(torch.uint8)

strides = image.stride()
if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
Expand Down