Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def __init__(
ArgsKwargs(saturation=(0.8, 0.9)),
ArgsKwargs(hue=0.3),
ArgsKwargs(hue=(-0.1, 0.2)),
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.6),
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
],
closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
),
Expand Down
6 changes: 6 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,6 +1798,12 @@ def test_color_jitter():
color_jitter.__repr__()


@pytest.mark.parametrize("hue", [1, (-1, 1)])
def test_color_jitter_hue_out_of_bounds(hue):
with pytest.raises(ValueError, match=re.escape("hue values should be between (-0.5, 0.5)")):
transforms.ColorJitter(hue=hue)


@pytest.mark.parametrize("seed", range(10))
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_erasing(seed):
Expand Down
8 changes: 4 additions & 4 deletions torchvision/prototype/transforms/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ def _check_input(
value = [center - value, center + value]
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}")
else:
elif not (isinstance(value, collections.abc.Sequence) and len(value) == 2):
raise TypeError(f"{name} should be a single number or a sequence with length 2.")

if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}, but got {value}.")

return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))

@staticmethod
Expand Down
11 changes: 7 additions & 4 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,16 +1195,19 @@ def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_firs
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}")
value = [float(value[0]), float(value[1])]
else:
raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}, but got {value}.")

# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
value = None
return value
return None
else:
return tuple(value)

@staticmethod
def get_params(
Expand Down