From 92d75dfa1760ea128cc0f3ec264a30cb90c759e4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 26 Jan 2023 11:28:14 +0100 Subject: [PATCH 1/3] perform check for single values and two tuples --- test/test_prototype_transforms_consistency.py | 2 +- test/test_transforms.py | 6 ++++++ torchvision/transforms/transforms.py | 11 +++++++---- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 3b69b72dd4f..a77f42930bc 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -312,7 +312,7 @@ def __init__( ArgsKwargs(saturation=(0.8, 0.9)), ArgsKwargs(hue=0.3), ArgsKwargs(hue=(-0.1, 0.2)), - ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.6), + ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3), ], closeness_kwargs={"atol": 1e-5, "rtol": 1e-5}, ), diff --git a/test/test_transforms.py b/test/test_transforms.py index 0c388dbb519..0340f9f3f15 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1798,6 +1798,12 @@ def test_color_jitter(): color_jitter.__repr__() +@pytest.mark.parametrize("hue", [1, (-1, 1)]) +def test_color_jitter_hue_out_of_bounds(hue): + with pytest.raises(ValueError, match=re.escape("hue values should be between (-0.5, 0.5)")): + transforms.ColorJitter(hue=hue) + + @pytest.mark.parametrize("seed", range(10)) @pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_erasing(seed): diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index cb2bfdb92a8..88f709ee477 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1195,16 +1195,19 @@ def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_firs if clip_first_on_zero: value[0] = max(value[0], 0.0) elif isinstance(value, (tuple, list)) and len(value) == 2: - if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError(f"{name} values should be between {bound}") + value = [float(value[0]), float(value[1])] else: raise TypeError(f"{name} should be a single number or a list/tuple with length 2.") + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"{name} values should be between {bound}") + # if value is 0 or (1., 1.) for brightness/contrast/saturation # or (0., 0.) for hue, do nothing if value[0] == value[1] == center: - value = None - return value + return None + else: + return tuple(value) @staticmethod def get_params( From b3108a6e16c1360370470cd427df17e3ff9388e3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 26 Jan 2023 14:11:14 +0100 Subject: [PATCH 2/3] apply fix to prototype as well --- torchvision/prototype/transforms/_color.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 0eb20e57764..e439e666678 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -42,12 +42,12 @@ def _check_input( value = [center - value, center + value] if clip_first_on_zero: value[0] = max(value[0], 0.0) - elif isinstance(value, collections.abc.Sequence) and len(value) == 2: - if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError(f"{name} values should be between {bound}") - else: + elif not (isinstance(value, collections.abc.Sequence) and len(value) == 2): raise TypeError(f"{name} should be a single number or a sequence with length 2.") + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"{name} values should be between {bound}") + return None if value[0] == value[1] == center else (float(value[0]), float(value[1])) @staticmethod From 57edc69700de232c6fbd32b5e587b4f90610b58b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 26 Jan 2023 15:39:20 +0100 Subject: [PATCH 3/3] improve error message --- torchvision/prototype/transforms/_color.py | 2 +- torchvision/transforms/transforms.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index e439e666678..3c787a78de4 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -46,7 +46,7 @@ def _check_input( raise TypeError(f"{name} should be a single number or a sequence with length 2.") if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError(f"{name} values should be between {bound}") + raise ValueError(f"{name} values should be between {bound}, but got {value}.") return None if value[0] == value[1] == center else (float(value[0]), float(value[1])) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 88f709ee477..d667be17666 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1200,7 +1200,7 @@ def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_firs raise TypeError(f"{name} should be a single number or a list/tuple with length 2.") if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError(f"{name} values should be between {bound}") + raise ValueError(f"{name} values should be between {bound}, but got {value}.") # if value is 0 or (1., 1.) for brightness/contrast/saturation # or (0., 0.) for hue, do nothing