1313
1414import unittest
1515
16- from monai .transforms import Transform
17- from monai .transforms .compose import SomeOf
16+ from parameterized import parameterized
17+
18+ from monai .data import MetaTensor
19+ from monai .transforms import SomeOf , TraceableTransform , Transform
20+ from monai .transforms .compose import Compose , SomeOf
1821from monai .utils import set_determinism
22+ from monai .utils .enums import TraceKeys
23+ from tests .test_one_of import A , B , C , Inv , NonInv , X , Y
1924
2025
2126class A (Transform ):
@@ -33,6 +38,19 @@ def __call__(self, x):
3338 return 5 * x
3439
3540
41+ class D (Transform ):
42+ def __call__ (self , x ):
43+ return 7 * x
44+
45+
46+ KEYS = ["x" , "y" ]
47+ TEST_COMPOUND = [
48+ (SomeOf ((A (), B (), C ()), fixed = True , max_num_transforms = 3 ), 2 * 3 * 5 ),
49+ (Compose ((SomeOf ((A (), B (), C ()), fixed = True , max_num_transforms = 3 ), D ())), 2 * 3 * 5 * 7 ),
50+ (SomeOf ((A (), B (), C (), Compose ((D ()))), fixed = True , max_num_transforms = 4 ), 2 * 3 * 5 * 7 ),
51+ ]
52+
53+
3654class TestSomeOf (unittest .TestCase ):
3755 def setUp (self ):
3856 set_determinism (seed = 0 )
@@ -90,6 +108,11 @@ def test_unfixed(self):
90108 for i in range (4 ):
91109 self .assertAlmostEqual (subset_size_counts [i ] / iterations , 0.25 , delta = 0.01 )
92110
111+ @parameterized .expand (TEST_COMPOUND )
112+ def test_compound_pipeline (self , transform , expected_value ):
113+ output = transform (1 )
114+ self .assertEqual (output , expected_value )
115+
93116
94117if __name__ == "__main__" :
95118 unittest .main ()
0 commit comments