Skip to content

Commit 8d3e5dd

Browse files
committed
Add customisable sampling weights
1 parent 4dbb3c5 commit 8d3e5dd

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

monai/transforms/compose.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import warnings
1818
from collections.abc import Callable, Mapping, Sequence
19-
from typing import Any
19+
from typing import Any, List, Optional
2020

2121
import numpy as np
2222

@@ -383,6 +383,7 @@ class SomeOf(Compose):
383383
max_num_transforms: maximum number of transforms to sample. Defaults to `3`.
384384
fixed: whether to sample exactly `max_num_transforms` transforms, or up to it. Defaults to `False`.
385385
replace: whether to sample with replacement. Defaults to `False`.
386+
weights: weights to use in for sampling transforms. Will be normalized to 1. Default: None (uniform).
386387
"""
387388

388389
def __init__(
@@ -395,6 +396,7 @@ def __init__(
395396
max_num_transforms: int = 3,
396397
fixed: bool = True,
397398
replace: bool = False,
399+
weights: Optional[List[int]] = None,
398400
) -> None:
399401
super().__init__(transforms, map_items, unpack_items, log_stats)
400402
if transforms is None:
@@ -404,6 +406,29 @@ def __init__(
404406
self.max_num_transforms = min(self.n_transforms, max_num_transforms)
405407
self.fixed = fixed
406408
self.replace = replace
409+
self.weights = self._normalize_probabilities(weights)
410+
411+
def _normalize_probabilities(self, weights):
412+
if weights is None or self.n_transforms == 0:
413+
return None
414+
415+
weights = np.array(weights)
416+
417+
n_weights = len(weights)
418+
if n_weights != self.n_transforms:
419+
raise ValueError(
420+
f"The number of weights specified must be equal to the number of transforms provided. Expected: {self.n_transforms}, got: {n_weights}."
421+
)
422+
423+
if np.any(weights < 0):
424+
raise ValueError(f"Probabilities must be greater than or equal to zero, got {weights}.")
425+
426+
if np.all(weights == 0):
427+
raise ValueError(f"At least one probability must be greater than zero, got {weights}.")
428+
429+
weights = weights / weights.sum()
430+
431+
return ensure_tuple(list(weights))
407432

408433
def __call__(self, data):
409434
if self.n_transforms == 0:
@@ -415,7 +440,7 @@ def __call__(self, data):
415440
else self.R.randint(self.min_num_transforms, self.max_num_transforms + 1)
416441
)
417442

418-
applied_order = self.R.choice(self.n_transforms, sample_size, replace=self.replace).tolist()
443+
applied_order = self.R.choice(self.n_transforms, sample_size, replace=self.replace, p=self.weights).tolist()
419444
for i in applied_order:
420445
data = apply_transform(self.transforms[i], data, self.map_items, self.unpack_items, self.log_stats)
421446

tests/test_some_of.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,22 @@ def test_inverse(self, transform, invertible, use_metatensor):
170170
# if not invertible, should not change the data
171171
self.assertDictEqual(fwd_data[i], _fwd_inv_data)
172172

173+
def test_normalize_weights(self):
174+
tr = SomeOf((A(), B(), C()), fixed=True, max_num_transforms=1, weights=(1, 2, 1))
175+
self.assertTupleEqual(tr.weights, (0.25, 0.5, 0.25))
176+
177+
tr = SomeOf((), fixed=True, max_num_transforms=1, weights=(1, 2, 1))
178+
self.assertIsNone(tr.weights)
179+
180+
def test_no_weights_arg(self):
181+
tr = SomeOf((A(), B(), C(), D()), fixed=True, max_num_transforms=1)
182+
self.assertIsNone(tr.weights)
183+
184+
def test_bad_weights(self):
185+
self.assertRaises(ValueError, SomeOf, (A(), B(), C()), fixed=True, max_num_transforms=1, weights=(1, 2))
186+
self.assertRaises(ValueError, SomeOf, (A(), B(), C()), fixed=True, max_num_transforms=1, weights=(0, 0, 0))
187+
self.assertRaises(ValueError, SomeOf, (A(), B(), C()), fixed=True, max_num_transforms=1, weights=(-1, 1, 1))
188+
173189

174190
if __name__ == "__main__":
175191
unittest.main()

0 commit comments

Comments
 (0)