1313
1414import sys
1515import unittest
16+ from copy import deepcopy
17+
18+ from parameterized import parameterized
19+
20+ import numpy as np
21+ import torch
1622
1723from monai .data import DataLoader , Dataset
18- from monai .transforms import AddChannel , Compose
24+ from monai .transforms import AddChannel , Compose , Flip , Rotate90 , Zoom , NormalizeIntensity , Rotate , Rotated
1925from monai .transforms .transform import Randomizable
2026from monai .utils import set_determinism
2127
@@ -56,8 +62,12 @@ def b(d):
5662 d ["b" ] += 1
5763 return d
5864
59- c = Compose ([a , b , a , b , a ])
60- self .assertDictEqual (c ({"a" : 0 , "b" : 0 }), {"a" : 3 , "b" : 2 })
65+ transforms = [a , b , a , b , a ]
66+ data = {"a" : 0 , "b" : 0 }
67+ expected = {"a" : 3 , "b" : 2 }
68+
69+ self .assertDictEqual (Compose (transforms )(data ), expected )
70+ self .assertDictEqual (Compose .execute (data , transforms ), expected )
6171
6272 def test_list_dict_compose (self ):
6373 def a (d ): # transform to handle dict data
@@ -76,10 +86,15 @@ def c(d): # transform to handle dict data
7686 d ["c" ] += 1
7787 return d
7888
79- transforms = Compose ([a , a , b , c , c ])
80- value = transforms ({"a" : 0 , "b" : 0 , "c" : 0 })
89+ transforms = [a , a , b , c , c ]
90+ data = {"a" : 0 , "b" : 0 , "c" : 0 }
91+ expected = {"a" : 2 , "b" : 1 , "c" : 2 }
92+ value = Compose (transforms )(data )
93+ for item in value :
94+ self .assertDictEqual (item , expected )
95+ value = Compose .execute (data , transforms )
8196 for item in value :
82- self .assertDictEqual (item , { "a" : 2 , "b" : 1 , "c" : 2 } )
97+ self .assertDictEqual (item , expected )
8398
8499 def test_non_dict_compose_with_unpack (self ):
85100 def a (i , i2 ):
@@ -88,8 +103,11 @@ def a(i, i2):
88103 def b (i , i2 ):
89104 return i + "b" , i2 + "b2"
90105
91- c = Compose ([a , b , a , b ], map_items = False , unpack_items = True )
92- self .assertEqual (c (("" , "" )), ("abab" , "a2b2a2b2" ))
106+ transforms = [a , b , a , b ]
107+ data = ("" , "" )
108+ expected = ("abab" , "a2b2a2b2" )
109+ self .assertEqual (Compose (transforms , map_items = False , unpack_items = True )(data ), expected )
110+ self .assertEqual (Compose .execute (data , transforms , map_items = False , unpack_items = True ), expected )
93111
94112 def test_list_non_dict_compose_with_unpack (self ):
95113 def a (i , i2 ):
@@ -98,8 +116,11 @@ def a(i, i2):
98116 def b (i , i2 ):
99117 return i + "b" , i2 + "b2"
100118
101- c = Compose ([a , b , a , b ], unpack_items = True )
102- self .assertEqual (c ([("" , "" ), ("t" , "t" )]), [("abab" , "a2b2a2b2" ), ("tabab" , "ta2b2a2b2" )])
119+ transforms = [a , b , a , b ]
120+ data = [("" , "" ), ("t" , "t" )]
121+ expected = [("abab" , "a2b2a2b2" ), ("tabab" , "ta2b2a2b2" )]
122+ self .assertEqual (Compose (transforms , unpack_items = True )(data ), expected )
123+ self .assertEqual (Compose .execute (data , transforms , unpack_items = True ), expected )
103124
104125 def test_list_dict_compose_no_map (self ):
105126 def a (d ): # transform to handle dict data
@@ -119,10 +140,16 @@ def c(d): # transform to handle dict data
119140 di ["c" ] += 1
120141 return d
121142
122- transforms = Compose ([a , a , b , c , c ], map_items = False )
123- value = transforms ({"a" : 0 , "b" : 0 , "c" : 0 })
143+ transforms = [a , a , b , c , c ]
144+ data = {"a" : 0 , "b" : 0 , "c" : 0 }
145+ expected = {"a" : 2 , "b" : 1 , "c" : 2 }
146+ value = Compose (transforms , map_items = False )(data )
124147 for item in value :
125- self .assertDictEqual (item , {"a" : 2 , "b" : 1 , "c" : 2 })
148+ self .assertDictEqual (item , expected )
149+ value = Compose .execute (data , transforms , map_items = False )
150+ for item in value :
151+ self .assertDictEqual (item , expected )
152+
126153
127154 def test_random_compose (self ):
128155 class _Acc (Randomizable ):
@@ -220,5 +247,106 @@ def test_backwards_compatible_imports(self):
220247 from monai .transforms .compose import MapTransform , RandomizableTransform , Transform # noqa: F401
221248
222249
250+ TEST_COMPOSE_EXECUTE_TEST_CASES = [
251+ [None , tuple ()],
252+ [None , (Rotate (np .pi / 8 ),)],
253+ [None , (Flip (0 ), Flip (1 ), Rotate90 (1 ), Zoom (0.8 ), NormalizeIntensity ())],
254+ [('a' ,), (Rotated (('a' ,), np .pi / 8 ),)],
255+ ]
256+
257+
258+ class TestComposeExecute (unittest .TestCase ):
259+
260+ @parameterized .expand (TEST_COMPOSE_EXECUTE_TEST_CASES )
261+ def test_compose_execute_equivalence (self , keys , pipeline ):
262+
263+ if keys is None :
264+ data = torch .unsqueeze (torch .tensor (np .arange (24 * 32 ).reshape (24 , 32 )), axis = 0 )
265+ else :
266+ data = {}
267+ for i_k , k in enumerate (keys ):
268+ data [k ] = torch .unsqueeze (torch .tensor (np .arange (24 * 32 )).reshape (24 , 32 ) + i_k * 768 ,
269+ axis = 0 )
270+
271+ expected = Compose (deepcopy (pipeline ))(data )
272+
273+ for cutoff in range (len (pipeline )):
274+
275+ c = Compose (deepcopy (pipeline ))
276+ actual = c (c (data , end = cutoff ), start = cutoff )
277+ if isinstance (actual , dict ):
278+ for k in actual .keys ():
279+ self .assertTrue (torch .allclose (expected [k ], actual [k ]))
280+ else :
281+ self .assertTrue (torch .allclose (expected , actual ))
282+
283+ p = deepcopy (pipeline )
284+ actual = Compose .execute (
285+ Compose .execute (data , p , start = 0 , end = cutoff ), p , start = cutoff )
286+ if isinstance (actual , dict ):
287+ for k in actual .keys ():
288+ self .assertTrue (torch .allclose (expected [k ], actual [k ]))
289+ else :
290+ self .assertTrue (torch .allclose (expected , actual ))
291+
292+
293+ class TestOps :
294+
295+ @staticmethod
296+ def concat (value ):
297+ def _inner (data ):
298+ return data + value
299+
300+ return _inner
301+
302+ @staticmethod
303+ def concatd (value ):
304+ def _inner (data ):
305+ return {k : v + value for k , v in data .items ()}
306+
307+ return _inner
308+
309+ @staticmethod
310+ def concata (value ):
311+ def _inner (data1 , data2 ):
312+ return data1 + value , data2 + value
313+
314+ return _inner
315+
316+
317+ TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES = [
318+ [{}, ("" ,), (TestOps .concat ('a' ), TestOps .concat ('b' ))],
319+ [{"unpack_items" : True }, ("x" , "y" ), (TestOps .concat ('a' ), TestOps .concat ('b' ))],
320+ [{"map_items" : False }, {"x" : "1" , "y" : "2" }, (TestOps .concatd ('a' ), TestOps .concatd ('b' ))],
321+ [{"unpack_items" : True , "map_items" : False }, ("x" , "y" ), (TestOps .concata ('a' ), TestOps .concata ('b' ))],
322+ ]
323+
324+
325+ class TestComposeExecuteWithFlags (unittest .TestCase ):
326+
327+ @parameterized .expand (TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES )
328+ def test_compose_execute_equivalence_with_flags (self , flags , data , pipeline ):
329+ expected = Compose (pipeline , ** flags )(data )
330+
331+ for cutoff in range (len (pipeline )):
332+
333+ c = Compose (deepcopy (pipeline ), ** flags )
334+ actual = c (c (data , end = cutoff ), start = cutoff )
335+ if isinstance (actual , dict ):
336+ for k in actual .keys ():
337+ self .assertEqual (expected [k ], actual [k ])
338+ else :
339+ self .assertTrue (expected , actual )
340+
341+ p = deepcopy (pipeline )
342+ actual = Compose .execute (
343+ Compose .execute (data , p , start = 0 , end = cutoff , ** flags ), p , start = cutoff , ** flags )
344+ if isinstance (actual , dict ):
345+ for k in actual .keys ():
346+ self .assertTrue (expected [k ], actual [k ])
347+ else :
348+ self .assertTrue (expected , actual )
349+
350+
223351if __name__ == "__main__" :
224352 unittest .main ()
0 commit comments