@@ -1,4 +1,3 @@
-import math
 from typing import Any, cast, Dict, List, Optional, Tuple, Union
 
 import PIL.Image
@@ -9,100 +8,8 @@
 from torchvision.prototype import datapoints as proto_datapoints
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
 
-from torchvision.transforms.v2._transform import _RandomApplyTransform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size
-
-
-class _BaseMixUpCutMix(_RandomApplyTransform):
-    def __init__(self, alpha: float, p: float = 0.5) -> None:
-        super().__init__(p=p)
-        self.alpha = alpha
-        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
-
-    def _check_inputs(self, flat_inputs: List[Any]) -> None:
-        if not (
-            has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor)
-            and has_any(flat_inputs, proto_datapoints.OneHotLabel)
-        ):
-            raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
-        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask, proto_datapoints.Label):
-            raise TypeError(
-                f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
-            )
-
-    def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> proto_datapoints.OneHotLabel:
-        if inpt.ndim < 2:
-            raise ValueError("Need a batch of one hot labels")
-        output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
-        return proto_datapoints.OneHotLabel.wrap_like(inpt, output)
-
-
-class RandomMixUp(_BaseMixUpCutMix):
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        lam = params["lam"]
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
-            expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
-            if inpt.ndim < expected_ndim:
-                raise ValueError("The transform expects a batched input")
-            output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
-
-            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
-
-            return output
-        elif isinstance(inpt, proto_datapoints.OneHotLabel):
-            return self._mixup_onehotlabel(inpt, lam)
-        else:
-            return inpt
-
-
-class RandomCutMix(_BaseMixUpCutMix):
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        lam = float(self._dist.sample(()))  # type: ignore[arg-type]
-
-        H, W = query_size(flat_inputs)
-
-        r_x = torch.randint(W, ())
-        r_y = torch.randint(H, ())
-
-        r = 0.5 * math.sqrt(1.0 - lam)
-        r_w_half = int(r * W)
-        r_h_half = int(r * H)
-
-        x1 = int(torch.clamp(r_x - r_w_half, min=0))
-        y1 = int(torch.clamp(r_y - r_h_half, min=0))
-        x2 = int(torch.clamp(r_x + r_w_half, max=W))
-        y2 = int(torch.clamp(r_y + r_h_half, max=H))
-        box = (x1, y1, x2, y2)
-
-        lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
-
-        return dict(box=box, lam_adjusted=lam_adjusted)
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
-            box = params["box"]
-            expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
-            if inpt.ndim < expected_ndim:
-                raise ValueError("The transform expects a batched input")
-            x1, y1, x2, y2 = box
-            rolled = inpt.roll(1, 0)
-            output = inpt.clone()
-            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
-
-            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
-
-            return output
-        elif isinstance(inpt, proto_datapoints.OneHotLabel):
-            lam_adjusted = params["lam_adjusted"]
-            return self._mixup_onehotlabel(inpt, lam_adjusted)
-        else:
-            return inpt
+from torchvision.transforms.v2.utils import is_simple_tensor
 
 
 class SimpleCopyPaste(Transform):
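
For readers of this diff: the removed RandomMixUp blends each sample with its neighbor in the batch using a weight lam drawn from Beta(alpha, alpha). Below is a minimal standalone sketch of that blend on plain tensors; the function name mixup_batch and the float one-hot labels are illustrative assumptions, not torchvision API.

import torch

def mixup_batch(images: torch.Tensor, one_hot_labels: torch.Tensor, alpha: float = 1.0):
    # Sample the mixing weight lam ~ Beta(alpha, alpha), as the removed transform did.
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    # Pair every sample with its neighbor by rolling the batch dimension, then
    # blend images and labels alike: out = lam * x + (1 - lam) * roll(x).
    # Assumes batched float tensors: images (N, C, H, W), labels (N, num_classes).
    mixed_images = images.roll(1, 0).mul_(1.0 - lam).add_(images.mul(lam))
    mixed_labels = one_hot_labels.roll(1, 0).mul_(1.0 - lam).add_(one_hot_labels.mul(lam))
    return mixed_images, mixed_labels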
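
The removed RandomCutMix instead pastes a random box from the rolled batch and re-weights the labels by the fraction of pixels actually kept. A comparable sketch under the same assumptions (cutmix_batch is an illustrative name):

import math
import torch

def cutmix_batch(images: torch.Tensor, one_hot_labels: torch.Tensor, alpha: float = 1.0):
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    H, W = images.shape[-2:]
    # Box centered at a uniform random point, sized so its area ratio is 1 - lam.
    r_x, r_y = int(torch.randint(W, ())), int(torch.randint(H, ()))
    r = 0.5 * math.sqrt(1.0 - lam)
    x1, y1 = max(r_x - int(r * W), 0), max(r_y - int(r * H), 0)
    x2, y2 = min(r_x + int(r * W), W), min(r_y + int(r * H), H)
    # Paste the box from the rolled batch, then recompute lam from the clipped
    # box area so the label weights match the pixels that were replaced.
    output = images.clone()
    output[..., y1:y2, x1:x2] = images.roll(1, 0)[..., y1:y2, x1:x2]
    lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
    mixed_labels = one_hot_labels.roll(1, 0).mul_(1.0 - lam_adjusted).add_(one_hot_labels.mul(lam_adjusted))
    return output, mixed_labels

Note that this diff only drops the prototype classes; equivalent functionality exists in the stable v2 API as torchvision.transforms.v2.MixUp and torchvision.transforms.v2.CutMix, which is presumably why the prototype versions are being removed.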