Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def __init__(
ArgsKwargs(saturation=(0.8, 0.9)),
ArgsKwargs(hue=0.3),
ArgsKwargs(hue=(-0.1, 0.2)),
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.6),
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
],
closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
),
Expand Down
6 changes: 6 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,6 +1798,12 @@ def test_color_jitter():
color_jitter.__repr__()


@pytest.mark.parametrize("hue", [1, (-1, 1)])
def test_color_jitter_hue_out_of_bounds(hue):
with pytest.raises(ValueError, match=re.escape("hue values should be between (-0.5, 0.5)")):
transforms.ColorJitter(hue=hue)


@pytest.mark.parametrize("seed", range(10))
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_erasing(seed):
Expand Down
8 changes: 4 additions & 4 deletions torchvision/prototype/transforms/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ def _check_input(
value = [center - value, center + value]
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}")
else:
elif not (isinstance(value, collections.abc.Sequence) and len(value) == 2):
raise TypeError(f"{name} should be a single number or a sequence with length 2.")

if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}")

return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))

@staticmethod
Expand Down
11 changes: 7 additions & 4 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,16 +1195,19 @@ def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_firs
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}")
value = [float(value[0]), float(value[1])]
else:
raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}")

# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
value = None
return value
return None
else:
return tuple(value)

@staticmethod
def get_params(
Expand Down