From a296f11911284adb4ea0aaed8e9a16f925a27a03 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Thu, 25 Jan 2024 12:23:30 -0800
Subject: [PATCH 01/14] WIP. Tests failing.

---
 test/test_transforms_v2.py                   | 26 +++++++++++++++++
 torchvision/transforms/_functional_pil.py    |  7 +++++
 torchvision/transforms/_functional_tensor.py | 13 +++++++++
 torchvision/transforms/functional.py         | 20 +++++++++++++
 torchvision/transforms/transforms.py         | 28 +++++++++++++++++++
 torchvision/transforms/v2/__init__.py        |  1 +
 torchvision/transforms/v2/_color.py          | 16 +++++++++++
 .../transforms/v2/functional/__init__.py     |  2 ++
 .../transforms/v2/functional/_color.py       | 22 +++++++++++++++
 9 files changed, 135 insertions(+)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 831a7e3b570..f7acb22eccb 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -4963,6 +4963,32 @@ def test_random_transform_correctness(self, num_input_channels):
         assert_equal(actual, expected, rtol=0, atol=1)
 
 
+class TestGrayscaleToRgb:
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_image(self, dtype, device):
+        check_kernel(F.grayscale_to_rgb_image, make_image(dtype=dtype, device=device))
+
+    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    def test_functional(self, make_input):
+        check_functional(F.grayscale_to_rgb, make_input())
+
+    @pytest.mark.parametrize(
+        ("kernel", "input_type"),
+        [
+            (F.rgb_to_grayscale_image, torch.Tensor),
+            (F._rgb_to_grayscale_image_pil, PIL.Image.Image),
+            (F.rgb_to_grayscale_image, tv_tensors.Image),
+        ],
+    )
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.grayscale_to_rgb, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    def test_transform(self, make_input):
+        check_transform(transforms.GrayscaleToRgb(), make_input(color_space="GRAY"))
+
+
 class TestRandomZoomOut:
     # Tests are light because this largely relies on the already tested `pad` kernels.
 
diff --git a/torchvision/transforms/_functional_pil.py b/torchvision/transforms/_functional_pil.py
index 277848224ac..f82684fafb9 100644
--- a/torchvision/transforms/_functional_pil.py
+++ b/torchvision/transforms/_functional_pil.py
@@ -348,6 +348,13 @@ def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
     return img
 
 
+def grayscale_to_rgb(img: Image.Image) -> Image.Image:
+    if not _is_pil_image(img):
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
+
+    return img.convert("RGB")
+
+
 @torch.jit.unused
 def invert(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py
index 88dc9ca21cc..9bccdfe5d9f 100644
--- a/torchvision/transforms/_functional_tensor.py
+++ b/torchvision/transforms/_functional_tensor.py
@@ -166,6 +166,19 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
     return l_img
 
 
+def grayscale_to_rgb(img: Tensor) -> Tensor:
+    if img.ndim < 3:
+        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
+    _assert_channels(img, [1, 3])
+
+    if img.shape[-3] != 1:
+        raise ValueError("Expected last dimension of input image tensor to be 1, got {}".format(img.shape[-3]))
+
+    s = [-1] * len(img.shape)
+    s[-3] = 3
+    return img.expand(s)
+
+
 def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     if brightness_factor < 0:
         raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 698942c56af..034466d1316 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1276,6 +1276,26 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
     return F_t.rgb_to_grayscale(img, num_output_channels)
 
 
+def grayscale_to_rgb(img: Tensor) -> Tensor:
+    """Converts grayscale image to 3-channel image by duplicating the value in all three channels.
+    If the image is torch Tensor, it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
+
+    Note:
+        Please, note that this method supports only grayscale images as input. For inputs in other color spaces,
+        please, consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
+
+    Returns:
+        PIL Image or Tensor: RGB Image.
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(grayscale_to_rgb)
+    if not isinstance(img, torch.Tensor):
+        return F_pil.grayscale_to_rgb(img)
+    return F_t.grayscale_to_rgb(img)
+
+
+
 def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
     """Erase the input Tensor Image with given value. This transform does not support PIL Image.
 
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 2a6e0ce12c0..8ec3e3466af 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -42,6 +42,7 @@
     "RandomRotation",
     "RandomAffine",
     "Grayscale",
+    "GrayscaleToRgb",
     "RandomGrayscale",
     "RandomPerspective",
     "RandomErasing",
@@ -1576,6 +1577,33 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})"
 
 
+class GrayscaleToRgb(torch.nn.Module):
+    """Converts grayscale image to 3 channel RGB image.
+    If the image is torch Tensor, it is expected
+    to have [..., 1, H, W] shape, where ... means an arbitrary number of leading dimensions
+
+    Returns:
+        PIL Image: RGB version of the input image.
+    """
+
+    def __init__(self):
+        super().__init__()
+        _log_api_usage_once(self)
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be converted to grayscale.
+
+        Returns:
+            PIL Image or Tensor: Grayscaled image.
+        """
+        return F.grayscale_to_rgb(img)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
 class RandomGrayscale(torch.nn.Module):
     """Randomly convert image to grayscale with a probability of p (default 0.1).
     If the image is torch Tensor, it is expected
diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py
index dbc0474d307..2cb3ea3eeae 100644
--- a/torchvision/transforms/v2/__init__.py
+++ b/torchvision/transforms/v2/__init__.py
@@ -9,6 +9,7 @@
 from ._color import (
     ColorJitter,
     Grayscale,
+    GrayscaleToRgb,
     RandomAdjustSharpness,
     RandomAutocontrast,
     RandomChannelPermutation,
diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index d20953451ab..39605f46ed7 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -54,6 +54,22 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])
 
 
+class GrayscaleToRgb(Transform):
+    """Converts grayscale images to RGB images.
+
+    If the input is a :class:`torch.Tensor`, it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
+    """
+
+    _v1_transform_cls = _transforms.GrayscaleToRgb
+
+    def __init__(self):
+        super().__init__()
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return self._call_kernel(F.grayscale_to_rgb, inpt)
+
+
 class ColorJitter(Transform):
     """Randomly change the brightness, contrast, saturation and hue of an image or video.
 
diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py
index 81d5c1b9baf..82f1d69cbba 100644
--- a/torchvision/transforms/v2/functional/__init__.py
+++ b/torchvision/transforms/v2/functional/__init__.py
@@ -63,6 +63,8 @@
     equalize,
     equalize_image,
     equalize_video,
+    grayscale_to_rgb,
+    grayscale_to_rgb_image,
     invert,
     invert_image,
     invert_video,
diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index b0189fd95ef..36e0029a3c0 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -61,6 +61,28 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int
     return _FP.to_grayscale(image, num_output_channels=num_output_channels)
 
 
+def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
+    """See :class:`~torchvision.transforms.v2.GrayscaleToRgb` for details."""
+    if torch.jit.is_scripting():
+        return grayscale_to_rgb_image(inpt)
+
+    _log_api_usage_once(grayscale_to_rgb)
+
+    kernel = _get_kernel(rgb_to_grayscale, type(inpt))
+    return kernel(inpt)
+
+
+@_register_kernel_internal(grayscale_to_rgb, torch.Tensor)
+@_register_kernel_internal(grayscale_to_rgb, tv_tensors.Image)
+def grayscale_to_rgb_image(image: torch.Tensor) -> torch.Tensor:
+    return _rgb_to_grayscale_image(image, num_output_channels=3, preserve_dtype=True)
+
+
+@_register_kernel_internal(grayscale_to_rgb, PIL.Image.Image)
+def grayscale_to_rgb_image(image: PIL.Image.Image) -> PIL.Image.Image:
+    return _FP.to_grayscale(image, num_output_channels=3)
+
+
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
From ff7a49c12bc788fa588feaf79dd4607cb159d5ec Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Thu, 1 Feb 2024 15:16:46 -0800
Subject: [PATCH 02/14] Fixed typo

---
 torchvision/transforms/v2/functional/_color.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index 3e6ac96640e..855c387e21e 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -83,7 +83,7 @@ def grayscale_to_rgb_image(image: torch.Tensor) -> torch.Tensor:
 
 
 @_register_kernel_internal(grayscale_to_rgb, PIL.Image.Image)
-def grayscale_to_rgb_image(image: PIL.Image.Image) -> PIL.Image.Image:
+def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
     return _FP.to_grayscale(image, num_output_channels=3)
 
 

From 3f1e6114162a3b6828051655d52b9cc33e2b09ab Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Thu, 1 Feb 2024 15:25:43 -0800
Subject: [PATCH 03/14] Fixed typo

---
 test/test_transforms_v2.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index ef5f32844ba..1334f0dbff2 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -4997,6 +4997,15 @@ def test_functional_signature(self, kernel, input_type):
     def test_transform(self, make_input):
         check_transform(transforms.GrayscaleToRgb(), make_input(color_space="GRAY"))
 
+    @pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.GrayscaleToRgb)])
+    def test_image_correctness(self, fn):
+        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
+
+        actual = fn(image, num_output_channels=num_output_channels)
+        expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))
+
+        assert_equal(actual, expected, rtol=0, atol=1)
+
 
 class TestRandomZoomOut:
     # Tests are light because this largely relies on the already tested `pad` kernels.

From e0e056f74c14118b591be58936c70807d3b645b5 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Thu, 1 Feb 2024 15:45:11 -0800
Subject: [PATCH 04/14] Fixed bug.

---
 test/test_transforms_v2.py                     | 3 ++-
 torchvision/transforms/v2/functional/_color.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 1334f0dbff2..7f092c8115a 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -5001,9 +5001,10 @@ def test_transform(self, make_input):
     def test_image_correctness(self, fn):
         image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
 
-        actual = fn(image, num_output_channels=num_output_channels)
+        actual = fn(image)
         expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))
 
+        print(f"ahmad: {expected.shape=} {actual.shape=}")
         assert_equal(actual, expected, rtol=0, atol=1)
 
 
diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index 855c387e21e..f2e1d9ad06a 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -72,7 +72,7 @@ def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
 
     _log_api_usage_once(grayscale_to_rgb)
 
-    kernel = _get_kernel(rgb_to_grayscale, type(inpt))
+    kernel = _get_kernel(grayscale_to_rgb, type(inpt))
     return kernel(inpt)
 
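Note (illustration, not part of the patch series): the one-line fix in PATCH 04 matters because the v2 functionals resolve their kernel from a registry keyed on (functional, input type), so registering kernels under `grayscale_to_rgb` but looking them up under `rgb_to_grayscale` dispatches to the wrong kernel. Below is a minimal sketch of that registry pattern; `register_kernel` and `dispatch` are hypothetical stand-ins for torchvision's private `_register_kernel_internal` and `_get_kernel`:

    from typing import Any, Callable, Dict, Tuple

    # Maps (functional, input type) -> kernel, mirroring how a v2 functional
    # picks a tensor kernel vs. a PIL kernel behind one public entry point.
    _KERNEL_REGISTRY: Dict[Tuple[Callable, type], Callable] = {}

    def register_kernel(functional: Callable, input_type: type):
        def decorator(kernel: Callable) -> Callable:
            _KERNEL_REGISTRY[(functional, input_type)] = kernel
            return kernel
        return decorator

    def dispatch(functional: Callable, inpt: Any) -> Any:
        # Looking up under the wrong functional either raises KeyError or,
        # worse, silently finds a kernel with a different signature.
        return _KERNEL_REGISTRY[(functional, type(inpt))](inpt)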
From 3547f83c4ac6eaf62fd731713c53ac1a1f72f997 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Thu, 1 Feb 2024 15:45:43 -0800
Subject: [PATCH 05/14] Fixed bug.

---
 test/test_transforms_v2.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 7f092c8115a..fb8b5be2c48 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -5007,6 +5007,14 @@ def test_image_correctness(self, fn):
         print(f"ahmad: {expected.shape=} {actual.shape=}")
         assert_equal(actual, expected, rtol=0, atol=1)
 
+    def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
+        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
+
+        output_image = F.grayscale_to_rgb(image)
+        assert_equal(output_image[0][0][0], output_image[1][0][0])
+        output_image[0][0][0] = output_image[0][0][0] + 1
+        assert output_image[0][0][0] != output_image[1][0][0]
+
 
 class TestRandomZoomOut:
     # Tests are light because this largely relies on the already tested `pad` kernels.

From 133a6ba1b82e3c5d0c313cf637dc20adb2e1472e Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Thu, 1 Feb 2024 15:48:58 -0800
Subject: [PATCH 06/14] Add comments.

---
 torchvision/transforms/v2/functional/_color.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index f2e1d9ad06a..fb1f9a4d174 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -79,11 +79,13 @@ def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
 
 @_register_kernel_internal(grayscale_to_rgb, torch.Tensor)
 @_register_kernel_internal(grayscale_to_rgb, tv_tensors.Image)
 def grayscale_to_rgb_image(image: torch.Tensor) -> torch.Tensor:
+    # rgb_to_grayscale can be used to add channels so we reuse that function.
     return _rgb_to_grayscale_image(image, num_output_channels=3, preserve_dtype=True)
 
 
 @_register_kernel_internal(grayscale_to_rgb, PIL.Image.Image)
 def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
+    # to_grayscale can expand channels from 1 to 3 so we reuse that function.
     return _FP.to_grayscale(image, num_output_channels=3)
 

From 78db2f656dda508afbdc9a755c41ed9e5cb5cf86 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Thu, 1 Feb 2024 15:59:22 -0800
Subject: [PATCH 07/14] removed extra line

---
 torchvision/transforms/functional.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 034466d1316..c83bf018adc 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1295,7 +1295,6 @@ def grayscale_to_rgb(img: Tensor) -> Tensor:
     return F_t.grayscale_to_rgb(img)
 
 
-
 def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
     """Erase the input Tensor Image with given value. This transform does not support PIL Image.
 
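Note (illustration, not part of the patch series): the test added in PATCH 05 pins down why the v2 kernel must return a copy. The v1 tensor implementation from PATCH 01 uses `Tensor.expand`, which returns a view whose three "channels" alias a single buffer. A standalone demonstration of the difference, in plain PyTorch:

    import torch

    gray = torch.zeros(1, 2, 2, dtype=torch.uint8)  # [1, H, W] grayscale

    view = gray.expand(3, -1, -1)  # zero-copy: all 3 channels alias gray's storage
    copy = gray.repeat(3, 1, 1)    # materialized: 3 independent channels

    gray[0, 0, 0] = 7              # write to the underlying grayscale pixel
    assert view[2, 0, 0] == 7      # the expanded channels observe the write
    assert copy[2, 0, 0] == 0      # the repeated channels do not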
From 990fc317864bd9da433a5fe22d18f86dc373ddc7 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Mon, 12 Feb 2024 11:23:47 -0800
Subject: [PATCH 08/14] Addressed comments and added another test.

---
 test/test_transforms_v2.py                     | 14 +++++++++++++-
 torchvision/transforms/v2/functional/_color.py |  6 ++++--
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index fb8b5be2c48..2af4dd9d739 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -5004,7 +5004,6 @@ def test_image_correctness(self, fn):
         actual = fn(image)
         expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))
 
-        print(f"ahmad: {expected.shape=} {actual.shape=}")
         assert_equal(actual, expected, rtol=0, atol=1)
 
     def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
@@ -5015,6 +5014,19 @@ def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
         output_image[0][0][0] = output_image[0][0][0] + 1
         assert output_image[0][0][0] != output_image[1][0][0]
 
+    def test_rgb_image_is_unchanged(self):
+        image = make_image(dtype=torch.uint8, device="cpu", color_space="RGB")
+        assert_equal(image.shape[-3], 3)
+        image[0][0][0] = 0
+        image[1][0][0] = 100
+        image[2][0][0] = 200
+        output_image = F.grayscale_to_rgb(image)
+        assert output_image[0][0][0] == 0
+        assert output_image[1][0][0] == 100
+        assert output_image[2][0][0] == 200
+        print(image)
+        print(output_image)
+
 
 class TestRandomZoomOut:
     # Tests are light because this largely relies on the already tested `pad` kernels.
diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index fb1f9a4d174..3025f876dff 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -79,14 +79,16 @@ def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
 
 @_register_kernel_internal(grayscale_to_rgb, torch.Tensor)
 @_register_kernel_internal(grayscale_to_rgb, tv_tensors.Image)
 def grayscale_to_rgb_image(image: torch.Tensor) -> torch.Tensor:
+    if image.shape[-3] >= 3:
+        # Image already has RGB channels. We don't need to do anything.
+        return image
     # rgb_to_grayscale can be used to add channels so we reuse that function.
     return _rgb_to_grayscale_image(image, num_output_channels=3, preserve_dtype=True)
 
 
 @_register_kernel_internal(grayscale_to_rgb, PIL.Image.Image)
 def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
-    # to_grayscale can expand channels from 1 to 3 so we reuse that function.
-    return _FP.to_grayscale(image, num_output_channels=3)
+    return image.convert(mode="RGB")
 
From dde638e641cd106c13c9ce9c2c5fa631495beb5d Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Mon, 12 Feb 2024 13:07:30 -0800
Subject: [PATCH 09/14] Removed support for v1 transform class

---
 torchvision/transforms/_functional_pil.py    |  7 -----
 torchvision/transforms/_functional_tensor.py | 13 ---------
 torchvision/transforms/functional.py         | 19 -------------
 torchvision/transforms/transforms.py         | 28 --------------------
 torchvision/transforms/v2/_color.py          |  2 --
 5 files changed, 69 deletions(-)

diff --git a/torchvision/transforms/_functional_pil.py b/torchvision/transforms/_functional_pil.py
index f82684fafb9..277848224ac 100644
--- a/torchvision/transforms/_functional_pil.py
+++ b/torchvision/transforms/_functional_pil.py
@@ -348,13 +348,6 @@ def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
     return img
 
 
-def grayscale_to_rgb(img: Image.Image) -> Image.Image:
-    if not _is_pil_image(img):
-        raise TypeError(f"img should be PIL Image. Got {type(img)}")
-
-    return img.convert("RGB")
-
-
 @torch.jit.unused
 def invert(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py
index 9bccdfe5d9f..88dc9ca21cc 100644
--- a/torchvision/transforms/_functional_tensor.py
+++ b/torchvision/transforms/_functional_tensor.py
@@ -166,19 +166,6 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
     return l_img
 
 
-def grayscale_to_rgb(img: Tensor) -> Tensor:
-    if img.ndim < 3:
-        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
-    _assert_channels(img, [1, 3])
-
-    if img.shape[-3] != 1:
-        raise ValueError("Expected last dimension of input image tensor to be 1, got {}".format(img.shape[-3]))
-
-    s = [-1] * len(img.shape)
-    s[-3] = 3
-    return img.expand(s)
-
-
 def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     if brightness_factor < 0:
         raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index c83bf018adc..698942c56af 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1276,25 +1276,6 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
     return F_t.rgb_to_grayscale(img, num_output_channels)
 
 
-def grayscale_to_rgb(img: Tensor) -> Tensor:
-    """Converts grayscale image to 3-channel image by duplicating the value in all three channels.
-    If the image is torch Tensor, it is expected
-    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
-
-    Note:
-        Please, note that this method supports only grayscale images as input. For inputs in other color spaces,
-        please, consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
-
-    Returns:
-        PIL Image or Tensor: RGB Image.
-    """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(grayscale_to_rgb)
-    if not isinstance(img, torch.Tensor):
-        return F_pil.grayscale_to_rgb(img)
-    return F_t.grayscale_to_rgb(img)
-
-
 def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
     """Erase the input Tensor Image with given value. This transform does not support PIL Image.
 
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 8ec3e3466af..2a6e0ce12c0 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -42,7 +42,6 @@
     "RandomRotation",
     "RandomAffine",
     "Grayscale",
-    "GrayscaleToRgb",
     "RandomGrayscale",
     "RandomPerspective",
     "RandomErasing",
@@ -1577,33 +1576,6 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})"
 
 
-class GrayscaleToRgb(torch.nn.Module):
-    """Converts grayscale image to 3 channel RGB image.
-    If the image is torch Tensor, it is expected
-    to have [..., 1, H, W] shape, where ... means an arbitrary number of leading dimensions
-
-    Returns:
-        PIL Image: RGB version of the input image.
-    """
-
-    def __init__(self):
-        super().__init__()
-        _log_api_usage_once(self)
-
-    def forward(self, img):
-        """
-        Args:
-            img (PIL Image or Tensor): Image to be converted to grayscale.
-
-        Returns:
-            PIL Image or Tensor: Grayscaled image.
-        """
-        return F.grayscale_to_rgb(img)
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
-
-
 class RandomGrayscale(torch.nn.Module):
     """Randomly convert image to grayscale with a probability of p (default 0.1).
     If the image is torch Tensor, it is expected
diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index 39605f46ed7..0979afd5837 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -61,8 +61,6 @@ class GrayscaleToRgb(Transform):
     to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
     """
 
-    _v1_transform_cls = _transforms.GrayscaleToRgb
-
     def __init__(self):
         super().__init__()
 

From dd01a21e93f08f4ee14ead912cdcfc8361604d0b Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Mon, 4 Mar 2024 12:27:26 -0800
Subject: [PATCH 10/14] Accepted suggestion that simplifies test code

---
 test/test_transforms_v2.py | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 2af4dd9d739..42d5f87045e 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -5017,15 +5017,7 @@ def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
 
     def test_rgb_image_is_unchanged(self):
         image = make_image(dtype=torch.uint8, device="cpu", color_space="RGB")
         assert_equal(image.shape[-3], 3)
-        image[0][0][0] = 0
-        image[1][0][0] = 100
-        image[2][0][0] = 200
-        output_image = F.grayscale_to_rgb(image)
-        assert output_image[0][0][0] == 0
-        assert output_image[1][0][0] == 100
-        assert output_image[2][0][0] == 200
-        print(image)
-        print(output_image)
+        assert_equal(F.grayscale_to_rgb(image), image)
 
 

From 02526afc628f40f4100593450b2228e07cc9901f Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Fri, 15 Mar 2024 08:23:09 -0700
Subject: [PATCH 11/14] Changed names

---
 docs/source/transforms.rst            | 4 +++-
 test/test_transforms_v2.py            | 4 ++--
 torchvision/transforms/v2/__init__.py | 2 +-
 torchvision/transforms/v2/_color.py   | 2 +-
 4 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 056d1589e84..b52914a373c 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -347,6 +347,7 @@ Color
     v2.RandomChannelPermutation
     v2.RandomPhotometricDistort
     v2.Grayscale
+    v2.RGB
     v2.RandomGrayscale
     v2.GaussianBlur
     v2.RandomInvert
@@ -364,6 +365,7 @@ Functionals
 
     v2.functional.permute_channels
     v2.functional.rgb_to_grayscale
+    v2.functional.grayscale_to_rgb
     v2.functional.to_grayscale
     v2.functional.gaussian_blur
     v2.functional.invert
@@ -583,7 +585,7 @@ Conversion
 while performing the conversion, while some may not do any scaling. By scaling, we
 mean e.g. that a ``uint8`` -> ``float32`` would map the [0, 255] range into [0, 1]
 (and vice-versa). See :ref:`range_and_dtype`.
-
+
 .. autosummary::
    :toctree: generated/
    :template: class.rst
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 1a247625560..1dd3eae2526 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -5012,9 +5012,9 @@ def test_functional_signature(self, kernel, input_type):
 
     @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
     def test_transform(self, make_input):
-        check_transform(transforms.GrayscaleToRgb(), make_input(color_space="GRAY"))
+        check_transform(transforms.RGB(), make_input(color_space="GRAY"))
 
-    @pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.GrayscaleToRgb)])
+    @pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.RGB)])
     def test_image_correctness(self, fn):
         image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
 
diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py
index 2cb3ea3eeae..61972994639 100644
--- a/torchvision/transforms/v2/__init__.py
+++ b/torchvision/transforms/v2/__init__.py
@@ -9,7 +9,7 @@
 from ._color import (
     ColorJitter,
     Grayscale,
-    GrayscaleToRgb,
+    RGB,
     RandomAdjustSharpness,
     RandomAutocontrast,
     RandomChannelPermutation,
diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index 0979afd5837..ab049b6bf9e 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -54,7 +54,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])
 
 
-class GrayscaleToRgb(Transform):
+class RGB(Transform):
     """Converts grayscale images to RGB images.
 
     If the input is a :class:`torch.Tensor`, it is expected

From d24b3222078f1d8ea8d15848cd94d7fe2e49e8b3 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Fri, 15 Mar 2024 08:24:34 -0700
Subject: [PATCH 12/14] Tweaked documentation

---
 torchvision/transforms/v2/_color.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index ab049b6bf9e..8f334105911 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -55,7 +55,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 
 class RGB(Transform):
-    """Converts grayscale images to RGB images.
+    """Converts images or videos to RGB (if they are already not RGB)
 
     If the input is a :class:`torch.Tensor`, it is expected
     to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
From 7331023cbb35b4a716087e51564fce5d7adff962 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Fri, 15 Mar 2024 08:32:42 -0700
Subject: [PATCH 13/14] Sorted imports

---
 torchvision/transforms/v2/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py
index 61972994639..fea39d3cf20 100644
--- a/torchvision/transforms/v2/__init__.py
+++ b/torchvision/transforms/v2/__init__.py
@@ -9,7 +9,6 @@
 from ._color import (
     ColorJitter,
     Grayscale,
-    RGB,
     RandomAdjustSharpness,
     RandomAutocontrast,
     RandomChannelPermutation,
@@ -19,6 +18,7 @@
     RandomPhotometricDistort,
     RandomPosterize,
     RandomSolarize,
+    RGB,
 )
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
 from ._geometry import (

From b2daebd998c9d5ca2a47d054c3ba9133afa883c5 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Fri, 15 Mar 2024 08:43:54 -0700
Subject: [PATCH 14/14] Adhere to PEP0257

---
 torchvision/transforms/v2/_color.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index 8f334105911..49b4a8d8b10 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -55,7 +55,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 
 class RGB(Transform):
-    """Converts images or videos to RGB (if they are already not RGB)
+    """Convert images or videos to RGB (if they are already not RGB).
 
     If the input is a :class:`torch.Tensor`, it is expected
     to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
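Note (illustration, not part of the patch series): end-to-end usage of the API as it stands after PATCH 14 — the `v2.RGB` transform plus the `grayscale_to_rgb` functional. This sketch assumes a torchvision build that includes this series:

    import torch
    from torchvision.transforms import v2
    from torchvision.transforms.v2 import functional as F

    gray = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)  # [1, H, W]

    rgb_from_transform = v2.RGB()(gray)             # transform class, e.g. inside a Compose
    rgb_from_functional = F.grayscale_to_rgb(gray)  # functional form
    assert rgb_from_transform.shape == rgb_from_functional.shape == (3, 32, 32)

    # Per test_rgb_image_is_unchanged above, RGB inputs pass through untouched.
    rgb = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    assert torch.equal(v2.RGB()(rgb), rgb)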