From ec2099826ff5d0789059e5506e6b15ed17d2764c Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 12 Jan 2023 11:16:48 +0100
Subject: [PATCH 1/4] fix MNIST byte flipping

---
 torchvision/datasets/mnist.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py
index fd742544935..1b5c8431c9e 100644
--- a/torchvision/datasets/mnist.py
+++ b/torchvision/datasets/mnist.py
@@ -519,13 +519,20 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
     torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
     s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]

-    num_bytes_per_value = torch.iinfo(torch_type).bits // 8
-    # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
-    # we need to reverse the bytes before we can read them with torch.frombuffer().
-    needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
     parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
-    if needs_byte_reversal:
-        parsed = parsed.flip(0)
+
+    # The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
+    # that is little endian and the dtype has more than one byte, we need to flip them.
+    num_bytes_per_value = parsed.element_size()
+    if sys.byteorder == "little" and num_bytes_per_value > 1:
+        parsed = (
+            parsed.contiguous()
+            .view(torch.uint8)
+            .view(parsed.numel(), num_bytes_per_value)
+            .flip(1)
+            .flatten()
+            .view(torch_type)
+        )

     assert parsed.shape[0] == np.prod(s) or not strict
     return parsed.view(*s)

From 54c89864a79240967ae7f231b39bb4a20dfda663 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 12 Jan 2023 13:27:33 +0100
Subject: [PATCH 2/4] add test

---
 test/test_datasets.py         | 22 ++++++++++++++++++++++
 torchvision/datasets/mnist.py | 18 ++++++++----------
 2 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/test/test_datasets.py b/test/test_datasets.py
index bd6d1dcb259..d667c210661 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -21,6 +21,7 @@
 import pytest
 import torch
 import torch.nn.functional as F
+from common_utils import assert_equal
 from torchvision import datasets


@@ -1494,6 +1495,27 @@ def test_num_examples_test50k(self):
         assert len(dataset) == info["num_examples"] - 10000


+@pytest.mark.parametrize(
+    ("dtype", "actual_hex", "expected_hex"),
+    [
+        (torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"),
+        (torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"),
+        (torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"),
+        (torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"),
+    ],
+)
+def test_flip_byte_order(dtype, actual_hex, expected_hex):
+    from torchvision.datasets.mnist import _flip_byte_order
+
+    def to_tensor(hex):
+        return torch.frombuffer(bytes.fromhex(hex), dtype=dtype)
+
+    assert_equal(
+        _flip_byte_order(to_tensor(actual_hex)),
+        to_tensor(expected_hex),
+    )
+
+
 class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
     DATASET_CLASS = datasets.MovingMNIST
     FEATURE_TYPES = (torch.Tensor,)
diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py
index 1b5c8431c9e..d2c34a031c4 100644
--- a/torchvision/datasets/mnist.py
+++ b/torchvision/datasets/mnist.py
@@ -503,6 +503,12 @@ def get_int(b: bytes) -> int:
 }


+def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
+    return (
+        t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
+    )
+
+
 def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
     """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
     Argument may be a filename, compressed filename, or file object.
@@ -523,16 +529,8 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso

     # The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
     # that is little endian and the dtype has more than one byte, we need to flip them.
-    num_bytes_per_value = parsed.element_size()
-    if sys.byteorder == "little" and num_bytes_per_value > 1:
-        parsed = (
-            parsed.contiguous()
-            .view(torch.uint8)
-            .view(parsed.numel(), num_bytes_per_value)
-            .flip(1)
-            .flatten()
-            .view(torch_type)
-        )
+    if sys.byteorder == "little" and parsed.element_size() > 1:
+        parsed = _flip_byte_order(parsed)

     assert parsed.shape[0] == np.prod(s) or not strict
     return parsed.view(*s)

From 096f2088b23e737dea31172fb80c99a6c217a963 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 12 Jan 2023 13:30:23 +0100
Subject: [PATCH 3/4] move to utils

---
 test/test_datasets.py         | 22 ----------------------
 test/test_datasets_utils.py   | 22 ++++++++++++++++++++++
 torchvision/datasets/mnist.py |  8 +-------
 torchvision/datasets/utils.py |  6 ++++++
 4 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/test/test_datasets.py b/test/test_datasets.py
index d667c210661..bd6d1dcb259 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -21,7 +21,6 @@
 import pytest
 import torch
 import torch.nn.functional as F
-from common_utils import assert_equal
 from torchvision import datasets


@@ -1495,27 +1494,6 @@ def test_num_examples_test50k(self):
         assert len(dataset) == info["num_examples"] - 10000


-@pytest.mark.parametrize(
-    ("dtype", "actual_hex", "expected_hex"),
-    [
-        (torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"),
-        (torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"),
-        (torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"),
-        (torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"),
-    ],
-)
-def test_flip_byte_order(dtype, actual_hex, expected_hex):
-    from torchvision.datasets.mnist import _flip_byte_order
-
-    def to_tensor(hex):
-        return torch.frombuffer(bytes.fromhex(hex), dtype=dtype)
-
-    assert_equal(
-        _flip_byte_order(to_tensor(actual_hex)),
-        to_tensor(expected_hex),
-    )
-
-
 class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
     DATASET_CLASS = datasets.MovingMNIST
     FEATURE_TYPES = (torch.Tensor,)
diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py
index ec68fd72a5b..992b0478583 100644
--- a/test/test_datasets_utils.py
+++ b/test/test_datasets_utils.py
@@ -7,7 +7,9 @@
 import zipfile

 import pytest
+import torch
 import torchvision.datasets.utils as utils
+from common_utils import assert_equal
 from torch._utils_internal import get_file_path_2
 from torchvision.datasets.folder import make_dataset
 from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
@@ -215,6 +217,26 @@ def test_verify_str_arg(self):
         pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
         pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")

+    @pytest.mark.parametrize(
+        ("dtype", "actual_hex", "expected_hex"),
+        [
+            (torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"),
+            (torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"),
"23 01 67 45 AB 89 EF CD"), + (torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"), + (torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"), + ], + ) + def test_flip_byte_order(self, dtype, actual_hex, expected_hex): + from torchvision.datasets.mnist import _flip_byte_order + + def to_tensor(hex): + return torch.frombuffer(bytes.fromhex(hex), dtype=dtype) + + assert_equal( + _flip_byte_order(to_tensor(actual_hex)), + to_tensor(expected_hex), + ) + @pytest.mark.parametrize( ("kwargs", "expected_error_msg"), diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index d2c34a031c4..4fa203462fb 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -12,7 +12,7 @@ import torch from PIL import Image -from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg +from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg from .vision import VisionDataset @@ -503,12 +503,6 @@ def get_int(b: bytes) -> int: } -def _flip_byte_order(t: torch.Tensor) -> torch.Tensor: - return ( - t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype) - ) - - def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor: """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). Argument may be a filename, compressed filename, or file object. diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index b8aaff3d773..fb9de2e445d 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -520,3 +520,9 @@ def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray: data = np.flip(data, axis=1) # flip on h dimension data = data[:slice_channels, :, :] return data.astype(np.float32) + + +def _flip_byte_order(t: torch.Tensor) -> torch.Tensor: + return ( + t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype) + ) From 32e6ea686245e5e3c84bf61c2897265c9182b4e5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Jan 2023 14:04:45 +0100 Subject: [PATCH 4/4] remove lazy import --- test/test_datasets_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 992b0478583..4e30dfab2cc 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -227,13 +227,11 @@ def test_verify_str_arg(self): ], ) def test_flip_byte_order(self, dtype, actual_hex, expected_hex): - from torchvision.datasets.mnist import _flip_byte_order - def to_tensor(hex): return torch.frombuffer(bytes.fromhex(hex), dtype=dtype) assert_equal( - _flip_byte_order(to_tensor(actual_hex)), + utils._flip_byte_order(to_tensor(actual_hex)), to_tensor(expected_hex), )