Skip to content

Commit 53a2658

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] add PermuteChannels transform (#7624)
Reviewed By: matteobettini Differential Revision: D48642277 fbshipit-source-id: 44b1154ab894869f014e2ad2e8367832085d948d
1 parent 051b78f commit 53a2658

File tree

7 files changed

+151
-18
lines changed

7 files changed

+151
-18
lines changed

docs/source/transforms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ Color
155155

156156
ColorJitter
157157
v2.ColorJitter
158+
v2.RandomChannelPermutation
158159
v2.RandomPhotometricDistort
159160
Grayscale
160161
v2.Grayscale

test/test_transforms_v2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ class TestSmoke:
124124
(transforms.RandomEqualize(p=1.0), None),
125125
(transforms.RandomGrayscale(p=1.0), None),
126126
(transforms.RandomInvert(p=1.0), None),
127+
(transforms.RandomChannelPermutation(), None),
127128
(transforms.RandomPhotometricDistort(p=1.0), None),
128129
(transforms.RandomPosterize(bits=4, p=1.0), None),
129130
(transforms.RandomSolarize(threshold=0.5, p=1.0), None),

test/test_transforms_v2_refactored.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2280,3 +2280,61 @@ def resize_my_datapoint():
22802280
_register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint)
22812281

22822282
assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint
2283+
2284+
2285+
class TestPermuteChannels:
2286+
_DEFAULT_PERMUTATION = [2, 0, 1]
2287+
2288+
@pytest.mark.parametrize(
2289+
("kernel", "make_input"),
2290+
[
2291+
(F.permute_channels_image_tensor, make_image_tensor),
2292+
# FIXME
2293+
# check_kernel does not support PIL kernel, but it should
2294+
(F.permute_channels_image_tensor, make_image),
2295+
(F.permute_channels_video, make_video),
2296+
],
2297+
)
2298+
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
2299+
@pytest.mark.parametrize("device", cpu_and_cuda())
2300+
def test_kernel(self, kernel, make_input, dtype, device):
2301+
check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION)
2302+
2303+
@pytest.mark.parametrize(
2304+
("kernel", "make_input"),
2305+
[
2306+
(F.permute_channels_image_tensor, make_image_tensor),
2307+
(F.permute_channels_image_pil, make_image_pil),
2308+
(F.permute_channels_image_tensor, make_image),
2309+
(F.permute_channels_video, make_video),
2310+
],
2311+
)
2312+
def test_dispatcher(self, kernel, make_input):
2313+
check_dispatcher(F.permute_channels, kernel, make_input(), permutation=self._DEFAULT_PERMUTATION)
2314+
2315+
@pytest.mark.parametrize(
2316+
("kernel", "input_type"),
2317+
[
2318+
(F.permute_channels_image_tensor, torch.Tensor),
2319+
(F.permute_channels_image_pil, PIL.Image.Image),
2320+
(F.permute_channels_image_tensor, datapoints.Image),
2321+
(F.permute_channels_video, datapoints.Video),
2322+
],
2323+
)
2324+
def test_dispatcher_signature(self, kernel, input_type):
2325+
check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)
2326+
2327+
def reference_image_correctness(self, image, permutation):
2328+
channel_images = image.split(1, dim=-3)
2329+
permuted_channel_images = [channel_images[channel_idx] for channel_idx in permutation]
2330+
return datapoints.Image(torch.concat(permuted_channel_images, dim=-3))
2331+
2332+
@pytest.mark.parametrize("permutation", [[2, 0, 1], [1, 2, 0], [2, 0, 1], [0, 1, 2]])
2333+
@pytest.mark.parametrize("batch_dims", [(), (2,), (2, 1)])
2334+
def test_image_correctness(self, permutation, batch_dims):
2335+
image = make_image(batch_dims=batch_dims)
2336+
2337+
actual = F.permute_channels(image, permutation=permutation)
2338+
expected = self.reference_image_correctness(image, permutation=permutation)
2339+
2340+
torch.testing.assert_close(actual, expected)

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Grayscale,
1212
RandomAdjustSharpness,
1313
RandomAutocontrast,
14+
RandomChannelPermutation,
1415
RandomEqualize,
1516
RandomGrayscale,
1617
RandomInvert,

torchvision/transforms/v2/_color.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,27 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
177177
return output
178178

179179

180-
# TODO: This class seems to be untested
180+
class RandomChannelPermutation(Transform):
181+
"""[BETA] Randomly permute the channels of an image or video
182+
183+
.. v2betastatus:: RandomChannelPermutation transform
184+
"""
185+
186+
_transformed_types = (
187+
datapoints.Image,
188+
PIL.Image.Image,
189+
is_simple_tensor,
190+
datapoints.Video,
191+
)
192+
193+
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
194+
num_channels, *_ = query_chw(flat_inputs)
195+
return dict(permutation=torch.randperm(num_channels))
196+
197+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
198+
return F.permute_channels(inpt, params["permutation"])
199+
200+
181201
class RandomPhotometricDistort(Transform):
182202
"""[BETA] Randomly distorts the image or video as used in `SSD: Single Shot
183203
MultiBox Detector <https://arxiv.org/abs/1512.02325>`_.
@@ -241,21 +261,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
241261
params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
242262
return params
243263

244-
def _permute_channels(
245-
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor
246-
) -> Union[datapoints._ImageType, datapoints._VideoType]:
247-
orig_inpt = inpt
248-
if isinstance(orig_inpt, PIL.Image.Image):
249-
inpt = F.pil_to_tensor(inpt)
250-
251-
# TODO: Find a better fix than as_subclass???
252-
output = inpt[..., permutation, :, :].as_subclass(type(inpt))
253-
254-
if isinstance(orig_inpt, PIL.Image.Image):
255-
output = F.to_image_pil(output)
256-
257-
return output
258-
259264
def _transform(
260265
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
261266
) -> Union[datapoints._ImageType, datapoints._VideoType]:
@@ -270,7 +275,7 @@ def _transform(
270275
if params["contrast_factor"] is not None and not params["contrast_before"]:
271276
inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
272277
if params["channel_permutation"] is not None:
273-
inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
278+
inpt = F.permute_channels(inpt, permutation=params["channel_permutation"])
274279
return inpt
275280

276281

torchvision/transforms/v2/functional/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@
6262
invert_image_pil,
6363
invert_image_tensor,
6464
invert_video,
65+
permute_channels,
66+
permute_channels_image_pil,
67+
permute_channels_image_tensor,
68+
permute_channels_video,
6569
posterize,
6670
posterize_image_pil,
6771
posterize_image_tensor,

torchvision/transforms/v2/functional/_color.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import List, Union
22

33
import PIL.Image
44
import torch
@@ -10,6 +10,8 @@
1010
from torchvision.utils import _log_api_usage_once
1111

1212
from ._misc import _num_value_bits, to_dtype_image_tensor
13+
14+
from ._type_conversion import pil_to_tensor, to_image_pil
1315
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
1416

1517

@@ -641,3 +643,64 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
641643
@_register_kernel_internal(invert, datapoints.Video)
642644
def invert_video(video: torch.Tensor) -> torch.Tensor:
643645
return invert_image_tensor(video)
646+
647+
648+
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
649+
def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT:
650+
"""Permute the channels of the input according to the given permutation.
651+
652+
This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
653+
:class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`.
654+
655+
Example:
656+
>>> rgb_image = torch.rand(3, 256, 256)
657+
>>> bgr_image = F.permutate_channels(rgb_image, permutation=[2, 1, 0])
658+
659+
Args:
660+
permutation (List[int]): Valid permutation of the input channel indices. The index of the element determines the
661+
channel index in the input and the value determines the channel index in the output. For example,
662+
``permutation=[2, 0 , 1]``
663+
664+
- takes ``ìnpt[..., 0, :, :]`` and puts it at ``output[..., 2, :, :]``,
665+
- takes ``ìnpt[..., 1, :, :]`` and puts it at ``output[..., 0, :, :]``, and
666+
- takes ``ìnpt[..., 2, :, :]`` and puts it at ``output[..., 1, :, :]``.
667+
668+
Raises:
669+
ValueError: If ``len(permutation)`` doesn't match the number of channels in the input.
670+
"""
671+
if torch.jit.is_scripting():
672+
return permute_channels_image_tensor(inpt, permutation=permutation)
673+
674+
_log_api_usage_once(permute_channels)
675+
676+
kernel = _get_kernel(permute_channels, type(inpt))
677+
return kernel(inpt, permutation=permutation)
678+
679+
680+
@_register_kernel_internal(permute_channels, torch.Tensor)
681+
@_register_kernel_internal(permute_channels, datapoints.Image)
682+
def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
683+
shape = image.shape
684+
num_channels, height, width = shape[-3:]
685+
686+
if len(permutation) != num_channels:
687+
raise ValueError(
688+
f"Length of permutation does not match number of channels: " f"{len(permutation)} != {num_channels}"
689+
)
690+
691+
if image.numel() == 0:
692+
return image
693+
694+
image = image.reshape(-1, num_channels, height, width)
695+
image = image[:, permutation, :, :]
696+
return image.reshape(shape)
697+
698+
699+
@_register_kernel_internal(permute_channels, PIL.Image.Image)
700+
def permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image:
701+
return to_image_pil(permute_channels_image_tensor(pil_to_tensor(image), permutation=permutation))
702+
703+
704+
@_register_kernel_internal(permute_channels, datapoints.Video)
705+
def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
706+
return permute_channels_image_tensor(video, permutation=permutation)

0 commit comments

Comments
 (0)