Skip to content

Commit a65ef4a

Browse files
Merge branch 'Project-MONAI:dev' into 4980-get-wsi-at-mpp
2 parents ae704f3 + e5bebfc commit a65ef4a

File tree

9 files changed

+436
-5
lines changed

9 files changed

+436
-5
lines changed

docs/source/transforms.rst

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,27 @@ Post-processing
661661
:members:
662662
:special-members: __call__
663663

664+
Regularization
665+
^^^^^^^^^^^^^^
666+
667+
`CutMix`
668+
""""""""
669+
.. autoclass:: CutMix
670+
:members:
671+
:special-members: __call__
672+
673+
`CutOut`
674+
""""""""
675+
.. autoclass:: CutOut
676+
:members:
677+
:special-members: __call__
678+
679+
`MixUp`
680+
"""""""
681+
.. autoclass:: MixUp
682+
:members:
683+
:special-members: __call__
684+
664685
Signal
665686
^^^^^^^
666687

@@ -1707,6 +1728,27 @@ Post-processing (Dict)
17071728
:members:
17081729
:special-members: __call__
17091730

1731+
Regularization (Dict)
1732+
^^^^^^^^^^^^^^^^^^^^^
1733+
1734+
`CutMixd`
1735+
"""""""""
1736+
.. autoclass:: CutMixd
1737+
:members:
1738+
:special-members: __call__
1739+
1740+
`CutOutd`
1741+
"""""""""
1742+
.. autoclass:: CutOutd
1743+
:members:
1744+
:special-members: __call__
1745+
1746+
`MixUpd`
1747+
""""""""
1748+
.. autoclass:: MixUpd
1749+
:members:
1750+
:special-members: __call__
1751+
17101752
Signal (Dict)
17111753
^^^^^^^^^^^^^
17121754

docs/source/transforms_idx.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,16 @@ Post-processing
7474
post.array
7575
post.dictionary
7676

77+
Regularization
78+
^^^^^^^^^^^^^^
79+
80+
.. autosummary::
81+
:toctree: _gen
82+
:nosignatures:
83+
84+
regularization.array
85+
regularization.dictionary
86+
7787
Signal
7888
^^^^^^
7989

monai/bundle/config_item.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,7 @@ def instantiate(self, **kwargs: Any) -> object:
289289
mode = self.get_config().get("_mode_", CompInitMode.DEFAULT)
290290
args = self.resolve_args()
291291
args.update(kwargs)
292-
try:
293-
return instantiate(modname, mode, **args)
294-
except Exception as e:
295-
raise RuntimeError(f"Failed to instantiate {self}") from e
292+
return instantiate(modname, mode, **args)
296293

297294

298295
class ConfigExpression(ConfigItem):

monai/transforms/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,18 @@
336336
VoteEnsembled,
337337
VoteEnsembleDict,
338338
)
339+
from .regularization.array import CutMix, CutOut, MixUp
340+
from .regularization.dictionary import (
341+
CutMixd,
342+
CutMixD,
343+
CutMixDict,
344+
CutOutd,
345+
CutOutD,
346+
CutOutDict,
347+
MixUpd,
348+
MixUpD,
349+
MixUpDict,
350+
)
339351
from .signal.array import (
340352
SignalContinuousWavelet,
341353
SignalFillEmpty,
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from abc import abstractmethod
15+
from math import ceil, sqrt
16+
17+
import torch
18+
19+
from ..transform import RandomizableTransform
20+
21+
__all__ = ["MixUp", "CutMix", "CutOut", "Mixer"]
22+
23+
24+
class Mixer(RandomizableTransform):
25+
def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
26+
"""
27+
Mixer is a base class providing the basic logic for the mixup-class of
28+
augmentations. In all cases, we need to sample the mixing weights for each
29+
sample (lambda in the notation used in the papers). Also, pairs of samples
30+
being mixed are picked by randomly shuffling the batch samples.
31+
32+
Args:
33+
batch_size (int): number of samples per batch. That is, samples are expected tp
34+
be of size batchsize x channels [x depth] x height x width.
35+
alpha (float, optional): mixing weights are sampled from the Beta(alpha, alpha)
36+
distribution. Defaults to 1.0, the uniform distribution.
37+
"""
38+
super().__init__()
39+
if alpha <= 0:
40+
raise ValueError(f"Expected positive number, but got {alpha = }")
41+
self.alpha = alpha
42+
self.batch_size = batch_size
43+
44+
@abstractmethod
45+
def apply(self, data: torch.Tensor):
46+
raise NotImplementedError()
47+
48+
def randomize(self, data=None) -> None:
49+
"""
50+
Sometimes you need may to apply the same transform to different tensors.
51+
The idea is to get a sample and then apply it with apply() as often
52+
as needed. You need to call this method everytime you apply the transform to a new
53+
batch.
54+
"""
55+
self._params = (
56+
torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),
57+
self.R.permutation(self.batch_size),
58+
)
59+
60+
61+
class MixUp(Mixer):
62+
"""MixUp as described in:
63+
Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
64+
mixup: Beyond Empirical Risk Minimization, ICLR 2018
65+
66+
Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
67+
documentation for details on the constructor parameters.
68+
"""
69+
70+
def apply(self, data: torch.Tensor):
71+
weight, perm = self._params
72+
nsamples, *dims = data.shape
73+
if len(weight) != nsamples:
74+
raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}")
75+
76+
if len(dims) not in [3, 4]:
77+
raise ValueError("Unexpected number of dimensions")
78+
79+
mixweight = weight[(Ellipsis,) + (None,) * len(dims)]
80+
return mixweight * data + (1 - mixweight) * data[perm, ...]
81+
82+
def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
83+
self.randomize()
84+
if labels is None:
85+
return self.apply(data)
86+
return self.apply(data), self.apply(labels)
87+
88+
89+
class CutMix(Mixer):
90+
"""CutMix augmentation as described in:
91+
Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.
92+
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
93+
ICCV 2019
94+
95+
Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
96+
documentation for details on the constructor parameters. Here, alpha not only determines
97+
the mixing weight but also the size of the random rectangles used during for mixing.
98+
Please refer to the paper for details.
99+
100+
The most common use case is something close to:
101+
102+
.. code-block:: python
103+
104+
cm = CutMix(batch_size=8, alpha=0.5)
105+
for batch in loader:
106+
images, labels = batch
107+
augimg, auglabels = cm(images, labels)
108+
output = model(augimg)
109+
loss = loss_function(output, auglabels)
110+
...
111+
112+
"""
113+
114+
def apply(self, data: torch.Tensor):
115+
weights, perm = self._params
116+
nsamples, _, *dims = data.shape
117+
if len(weights) != nsamples:
118+
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
119+
120+
mask = torch.ones_like(data)
121+
for s, weight in enumerate(weights):
122+
coords = [torch.randint(0, d, size=(1,)) for d in dims]
123+
lengths = [d * sqrt(1 - weight) for d in dims]
124+
idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
125+
mask[s][idx] = 0
126+
127+
return mask * data + (1 - mask) * data[perm, ...]
128+
129+
def apply_on_labels(self, labels: torch.Tensor):
130+
weights, perm = self._params
131+
nsamples, *dims = labels.shape
132+
if len(weights) != nsamples:
133+
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
134+
135+
mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
136+
return mixweight * labels + (1 - mixweight) * labels[perm, ...]
137+
138+
def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
139+
self.randomize()
140+
augmented = self.apply(data)
141+
return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented
142+
143+
144+
class CutOut(Mixer):
145+
"""Cutout as described in the paper:
146+
Terrance DeVries, Graham W. Taylor.
147+
Improved Regularization of Convolutional Neural Networks with Cutout,
148+
arXiv:1708.04552
149+
150+
Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
151+
documentation for details on the constructor parameters. Here, alpha not only determines
152+
the mixing weight but also the size of the random rectangles being cut put.
153+
Please refer to the paper for details.
154+
"""
155+
156+
def apply(self, data: torch.Tensor):
157+
weights, _ = self._params
158+
nsamples, _, *dims = data.shape
159+
if len(weights) != nsamples:
160+
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
161+
162+
mask = torch.ones_like(data)
163+
for s, weight in enumerate(weights):
164+
coords = [torch.randint(0, d, size=(1,)) for d in dims]
165+
lengths = [d * sqrt(1 - weight) for d in dims]
166+
idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
167+
mask[s][idx] = 0
168+
169+
return mask * data
170+
171+
def __call__(self, data: torch.Tensor):
172+
self.randomize()
173+
return self.apply(data)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from monai.config import KeysCollection
15+
from monai.utils.misc import ensure_tuple
16+
17+
from ..transform import MapTransform
18+
from .array import CutMix, CutOut, MixUp
19+
20+
__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]
21+
22+
23+
class MixUpd(MapTransform):
24+
"""
25+
Dictionary-based version :py:class:`monai.transforms.MixUp`.
26+
27+
Notice that the mixup transformation will be the same for all entries
28+
for consistency, i.e. images and labels must be applied the same augmenation.
29+
"""
30+
31+
def __init__(
32+
self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
33+
) -> None:
34+
super().__init__(keys, allow_missing_keys)
35+
self.mixup = MixUp(batch_size, alpha)
36+
37+
def __call__(self, data):
38+
self.mixup.randomize()
39+
result = dict(data)
40+
for k in self.keys:
41+
result[k] = self.mixup.apply(data[k])
42+
return result
43+
44+
45+
class CutMixd(MapTransform):
46+
"""
47+
Dictionary-based version :py:class:`monai.transforms.CutMix`.
48+
49+
Notice that the mixture weights will be the same for all entries
50+
for consistency, i.e. images and labels must be aggregated with the same weights,
51+
but the random crops are not.
52+
"""
53+
54+
def __init__(
55+
self,
56+
keys: KeysCollection,
57+
batch_size: int,
58+
label_keys: KeysCollection | None = None,
59+
alpha: float = 1.0,
60+
allow_missing_keys: bool = False,
61+
) -> None:
62+
super().__init__(keys, allow_missing_keys)
63+
self.mixer = CutMix(batch_size, alpha)
64+
self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []
65+
66+
def __call__(self, data):
67+
self.mixer.randomize()
68+
result = dict(data)
69+
for k in self.keys:
70+
result[k] = self.mixer.apply(data[k])
71+
for k in self.label_keys:
72+
result[k] = self.mixer.apply_on_labels(data[k])
73+
return result
74+
75+
76+
class CutOutd(MapTransform):
77+
"""
78+
Dictionary-based version :py:class:`monai.transforms.CutOut`.
79+
80+
Notice that the cutout is different for every entry in the dictionary.
81+
"""
82+
83+
def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None:
84+
super().__init__(keys, allow_missing_keys)
85+
self.cutout = CutOut(batch_size)
86+
87+
def __call__(self, data):
88+
result = dict(data)
89+
self.cutout.randomize()
90+
for k in self.keys:
91+
result[k] = self.cutout(data[k])
92+
return result
93+
94+
95+
MixUpD = MixUpDict = MixUpd
96+
CutMixD = CutMixDict = CutMixd
97+
CutOutD = CutOutDict = CutOutd

monai/utils/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:
272272
return pdb.runcall(component, **kwargs)
273273
except Exception as e:
274274
raise RuntimeError(
275-
f"Failed to instantiate component '{__path}' with kwargs: {kwargs}"
275+
f"Failed to instantiate component '{__path}' with keywords: {','.join(kwargs.keys())}"
276276
f"\n set '_mode_={CompInitMode.DEBUG}' to enter the debugging mode."
277277
) from e
278278

0 commit comments

Comments
 (0)