1616
1717import warnings
1818from collections .abc import Callable , Mapping , Sequence
19- from typing import Any
19+ from typing import Any , List , Optional
2020
2121import numpy as np
2222
@@ -383,6 +383,7 @@ class SomeOf(Compose):
383383 max_num_transforms: maximum number of transforms to sample. Defaults to `3`.
384384 fixed: whether to sample exactly `max_num_transforms` transforms, or up to it. Defaults to `True`.
385385 replace: whether to sample with replacement. Defaults to `False`.
386+ weights: weights to use for sampling transforms. Will be normalized to sum to 1. Default: None (uniform).
386387 """
387388
388389 def __init__ (
@@ -395,6 +396,7 @@ def __init__(
395396 max_num_transforms : int = 3 ,
396397 fixed : bool = True ,
397398 replace : bool = False ,
399+ weights : Optional [List [int ]] = None ,
398400 ) -> None :
399401 super ().__init__ (transforms , map_items , unpack_items , log_stats )
400402 if transforms is None :
@@ -404,6 +406,29 @@ def __init__(
404406 self .max_num_transforms = min (self .n_transforms , max_num_transforms )
405407 self .fixed = fixed
406408 self .replace = replace
409+ self .weights = self ._normalize_probabilities (weights )
410+
def _normalize_probabilities(self, weights):
    """Validate the sampling weights and rescale them so that they sum to 1.

    Args:
        weights: one non-negative, finite number per transform, or ``None``.

    Returns:
        ``None`` when ``weights`` is ``None`` or ``self.n_transforms`` is 0;
        otherwise the normalized weights passed through ``ensure_tuple``.

    Raises:
        ValueError: if ``weights`` cannot be read as a flat numeric sequence,
            if its length differs from ``self.n_transforms``, if any entry is
            negative or non-finite, or if every entry is zero.
    """
    if weights is None or self.n_transforms == 0:
        return None

    try:
        weights = np.asarray(weights, dtype=float)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Weights must be a sequence of numbers, got {weights}.") from err

    # A scalar has no length and a nested sequence would be normalized
    # row-wise by accident, so insist on exactly one dimension up front.
    if weights.ndim != 1:
        raise ValueError(
            f"Weights must be a one-dimensional sequence, got an array of shape {weights.shape}."
        )

    n_weights = len(weights)
    if n_weights != self.n_transforms:
        raise ValueError(
            "The number of weights specified must be equal to the number of transforms provided. "
            f"Expected: {self.n_transforms}, got: {n_weights}."
        )

    # NaN compares False against both `< 0` and `== 0`, and inf / inf is NaN,
    # so non-finite entries must be rejected explicitly before normalizing.
    if not np.all(np.isfinite(weights)):
        raise ValueError(f"Probabilities must be finite, got {weights}.")

    if np.any(weights < 0):
        raise ValueError(f"Probabilities must be greater than or equal to zero, got {weights}.")

    if np.all(weights == 0):
        raise ValueError(f"At least one probability must be greater than zero, got {weights}.")

    weights = weights / weights.sum()

    return ensure_tuple(list(weights))
407432
408433 def __call__ (self , data ):
409434 if self .n_transforms == 0 :
@@ -415,7 +440,7 @@ def __call__(self, data):
415440 else self .R .randint (self .min_num_transforms , self .max_num_transforms + 1 )
416441 )
417442
418- applied_order = self .R .choice (self .n_transforms , sample_size , replace = self .replace ).tolist ()
443+ applied_order = self .R .choice (self .n_transforms , sample_size , replace = self .replace , p = self . weights ).tolist ()
419444 for i in applied_order :
420445 data = apply_transform (self .transforms [i ], data , self .map_items , self .unpack_items , self .log_stats )
421446
0 commit comments