diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index b416dae20e0..f76c0f93d5e 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -317,7 +317,7 @@ def __init__( ArgsKwargs(saturation=(0.8, 0.9)), ArgsKwargs(hue=0.3), ArgsKwargs(hue=(-0.1, 0.2)), - ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.6), + ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3), ], closeness_kwargs={"atol": 1e-5, "rtol": 1e-5}, ), diff --git a/test/test_transforms.py b/test/test_transforms.py index 0c388dbb519..0340f9f3f15 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1798,6 +1798,12 @@ def test_color_jitter(): color_jitter.__repr__() +@pytest.mark.parametrize("hue", [1, (-1, 1)]) +def test_color_jitter_hue_out_of_bounds(hue): + with pytest.raises(ValueError, match=re.escape("hue values should be between (-0.5, 0.5)")): + transforms.ColorJitter(hue=hue) + + @pytest.mark.parametrize("seed", range(10)) @pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_erasing(seed): diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 6ab997b1e93..17b02e36953 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -77,12 +77,12 @@ def _check_input( value = [center - value, center + value] if clip_first_on_zero: value[0] = max(value[0], 0.0) - elif isinstance(value, collections.abc.Sequence) and len(value) == 2: - if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError(f"{name} values should be between {bound}") - else: + elif not (isinstance(value, collections.abc.Sequence) and len(value) == 2): raise TypeError(f"{name} should be a single number or a sequence with length 2.") + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"{name} values should be between {bound}, but got {value}.") + return None if value[0] == value[1] == center else (float(value[0]), float(value[1])) @staticmethod diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 62e36a06f98..573791b4151 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1195,16 +1195,19 @@ def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_firs if clip_first_on_zero: value[0] = max(value[0], 0.0) elif isinstance(value, (tuple, list)) and len(value) == 2: - if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError(f"{name} values should be between {bound}") + value = [float(value[0]), float(value[1])] else: raise TypeError(f"{name} should be a single number or a list/tuple with length 2.") + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"{name} values should be between {bound}, but got {value}.") + # if value is 0 or (1., 1.) for brightness/contrast/saturation # or (0., 0.) for hue, do nothing if value[0] == value[1] == center: - value = None - return value + return None + else: + return tuple(value) @staticmethod def get_params(