diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index ec68fd72a5b..4e30dfab2cc 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -7,7 +7,9 @@ import zipfile import pytest +import torch import torchvision.datasets.utils as utils +from common_utils import assert_equal from torch._utils_internal import get_file_path_2 from torchvision.datasets.folder import make_dataset from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS @@ -215,6 +217,24 @@ def test_verify_str_arg(self): pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg") pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") + @pytest.mark.parametrize( + ("dtype", "actual_hex", "expected_hex"), + [ + (torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"), + (torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"), + (torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"), + (torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"), + ], + ) + def test_flip_byte_order(self, dtype, actual_hex, expected_hex): + def to_tensor(hex): + return torch.frombuffer(bytes.fromhex(hex), dtype=dtype) + + assert_equal( + utils._flip_byte_order(to_tensor(actual_hex)), + to_tensor(expected_hex), + ) + @pytest.mark.parametrize( ("kwargs", "expected_error_msg"), diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index a5b23cfe071..6953d1fc5c2 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -12,7 +12,7 @@ import torch from PIL import Image -from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg +from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg from .vision import VisionDataset @@ -519,13 +519,12 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso torch_type = SN3_PASCALVINCENT_TYPEMAP[ty] s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] - num_bytes_per_value = torch.iinfo(torch_type).bits // 8 - # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, - # we need to reverse the bytes before we can read them with torch.frombuffer(). - needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1 parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1))) - if needs_byte_reversal: - parsed = parsed.flip(0) + + # The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case + # that is little endian and the dtype has more than one byte, we need to flip them. + if sys.byteorder == "little" and parsed.element_size() > 1: + parsed = _flip_byte_order(parsed) assert parsed.shape[0] == np.prod(s) or not strict return parsed.view(*s) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index b8aaff3d773..fb9de2e445d 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -520,3 +520,9 @@ def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray: data = np.flip(data, axis=1) # flip on h dimension data = data[:slice_channels, :, :] return data.astype(np.float32) + + +def _flip_byte_order(t: torch.Tensor) -> torch.Tensor: + return ( + t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype) + )