Skip to content

Commit 61f67e4

Browse files
committed
Added tests; updated documentation
1 parent c5a73f6 commit 61f67e4

2 files changed

Lines changed: 149 additions & 19 deletions

File tree

monai/transforms/compose.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,20 +197,22 @@ def execute(
197197
``execute`` provides the implementation that Compose uses to execute a sequence
198198
of transforms. As well as being used by Compose, it can be used by subclasses of
199199
Compose and by code that doesn't have a Compose instance but needs to execute a
200-
sequence of transforms is if it were executed by Compose. For the most part, it
201-
is recommended to use Compose instances, however.
200+
sequence of transforms is if it were executed by Compose. It should only be used directly
201+
when it is not possible to use ``Compose.__call__`` to achieve the same goal.
202202
Args:
203203
input_: a tensor-like object to be transformed
204204
transforms: a sequence of transforms to be carried out
205-
map_items: whether to apply the transform to each item in ``data```.
205+
map_items: whether to apply the transform to each item in ``data``.
206206
Defaults to True if not set.
207207
unpack_items: whether to unpack parameters using '*'. Defaults to False if not set
208208
log_stats: whether to log detailed information about the application of ``transforms``
209209
to ``input_``. For NumPy ndarrays and PyTorch tensors, log only the data shape and
210210
value range. Defaults to False if not set.
211-
start:
212-
end:
213-
threading:
211+
start: the index of the first transform to be executed. If not set, this defaults to 0
212+
end: the index after the last transform to be exectued. If set, the transform at index-1
213+
is the last transform that is executed. If this is not set, it defaults to len(transforms)
214+
threading: whether executing is happening in a threaded environment. If set, copies are made
215+
of transforms that have the ``RandomizedTrait`` interface.
214216
215217
Returns:
216218

tests/test_compose.py

Lines changed: 141 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@
1313

1414
import sys
1515
import unittest
16+
from copy import deepcopy
17+
18+
from parameterized import parameterized
19+
20+
import numpy as np
21+
import torch
1622

1723
from monai.data import DataLoader, Dataset
18-
from monai.transforms import AddChannel, Compose
24+
from monai.transforms import AddChannel, Compose, Flip, Rotate90, Zoom, NormalizeIntensity, Rotate, Rotated
1925
from monai.transforms.transform import Randomizable
2026
from monai.utils import set_determinism
2127

@@ -56,8 +62,12 @@ def b(d):
5662
d["b"] += 1
5763
return d
5864

59-
c = Compose([a, b, a, b, a])
60-
self.assertDictEqual(c({"a": 0, "b": 0}), {"a": 3, "b": 2})
65+
transforms = [a, b, a, b, a]
66+
data = {"a": 0, "b": 0}
67+
expected = {"a": 3, "b": 2}
68+
69+
self.assertDictEqual(Compose(transforms)(data), expected)
70+
self.assertDictEqual(Compose.execute(data, transforms), expected)
6171

6272
def test_list_dict_compose(self):
6373
def a(d): # transform to handle dict data
@@ -76,10 +86,15 @@ def c(d): # transform to handle dict data
7686
d["c"] += 1
7787
return d
7888

79-
transforms = Compose([a, a, b, c, c])
80-
value = transforms({"a": 0, "b": 0, "c": 0})
89+
transforms = [a, a, b, c, c]
90+
data = {"a": 0, "b": 0, "c": 0}
91+
expected = {"a": 2, "b": 1, "c": 2}
92+
value = Compose(transforms)(data)
93+
for item in value:
94+
self.assertDictEqual(item, expected)
95+
value = Compose.execute(data, transforms)
8196
for item in value:
82-
self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2})
97+
self.assertDictEqual(item, expected)
8398

8499
def test_non_dict_compose_with_unpack(self):
85100
def a(i, i2):
@@ -88,8 +103,11 @@ def a(i, i2):
88103
def b(i, i2):
89104
return i + "b", i2 + "b2"
90105

91-
c = Compose([a, b, a, b], map_items=False, unpack_items=True)
92-
self.assertEqual(c(("", "")), ("abab", "a2b2a2b2"))
106+
transforms = [a, b, a, b]
107+
data = ("", "")
108+
expected = ("abab", "a2b2a2b2")
109+
self.assertEqual(Compose(transforms, map_items=False, unpack_items=True)(data), expected)
110+
self.assertEqual(Compose.execute(data, transforms, map_items=False, unpack_items=True), expected)
93111

94112
def test_list_non_dict_compose_with_unpack(self):
95113
def a(i, i2):
@@ -98,8 +116,11 @@ def a(i, i2):
98116
def b(i, i2):
99117
return i + "b", i2 + "b2"
100118

101-
c = Compose([a, b, a, b], unpack_items=True)
102-
self.assertEqual(c([("", ""), ("t", "t")]), [("abab", "a2b2a2b2"), ("tabab", "ta2b2a2b2")])
119+
transforms = [a, b, a, b]
120+
data = [("", ""), ("t", "t")]
121+
expected = [("abab", "a2b2a2b2"), ("tabab", "ta2b2a2b2")]
122+
self.assertEqual(Compose(transforms, unpack_items=True)(data), expected)
123+
self.assertEqual(Compose.execute(data, transforms, unpack_items=True), expected)
103124

104125
def test_list_dict_compose_no_map(self):
105126
def a(d): # transform to handle dict data
@@ -119,10 +140,16 @@ def c(d): # transform to handle dict data
119140
di["c"] += 1
120141
return d
121142

122-
transforms = Compose([a, a, b, c, c], map_items=False)
123-
value = transforms({"a": 0, "b": 0, "c": 0})
143+
transforms = [a, a, b, c, c]
144+
data = {"a": 0, "b": 0, "c": 0}
145+
expected = {"a": 2, "b": 1, "c": 2}
146+
value = Compose(transforms, map_items=False)(data)
124147
for item in value:
125-
self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2})
148+
self.assertDictEqual(item, expected)
149+
value = Compose.execute(data, transforms, map_items=False)
150+
for item in value:
151+
self.assertDictEqual(item, expected)
152+
126153

127154
def test_random_compose(self):
128155
class _Acc(Randomizable):
@@ -220,5 +247,106 @@ def test_backwards_compatible_imports(self):
220247
from monai.transforms.compose import MapTransform, RandomizableTransform, Transform # noqa: F401
221248

222249

250+
TEST_COMPOSE_EXECUTE_TEST_CASES = [
251+
[None, tuple()],
252+
[None, (Rotate(np.pi/8),)],
253+
[None, (Flip(0), Flip(1), Rotate90(1), Zoom(0.8), NormalizeIntensity())],
254+
[('a',), (Rotated(('a',), np.pi/8),)],
255+
]
256+
257+
258+
class TestComposeExecute(unittest.TestCase):
259+
260+
@parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES)
261+
def test_compose_execute_equivalence(self, keys, pipeline):
262+
263+
if keys is None:
264+
data = torch.unsqueeze(torch.tensor(np.arange(24*32).reshape(24, 32)), axis=0)
265+
else:
266+
data = {}
267+
for i_k, k in enumerate(keys):
268+
data[k] = torch.unsqueeze(torch.tensor(np.arange(24*32)).reshape(24, 32) + i_k * 768,
269+
axis=0)
270+
271+
expected = Compose(deepcopy(pipeline))(data)
272+
273+
for cutoff in range(len(pipeline)):
274+
275+
c = Compose(deepcopy(pipeline))
276+
actual = c(c(data, end=cutoff), start=cutoff)
277+
if isinstance(actual, dict):
278+
for k in actual.keys():
279+
self.assertTrue(torch.allclose(expected[k], actual[k]))
280+
else:
281+
self.assertTrue(torch.allclose(expected, actual))
282+
283+
p = deepcopy(pipeline)
284+
actual = Compose.execute(
285+
Compose.execute(data, p, start=0, end=cutoff), p, start=cutoff)
286+
if isinstance(actual, dict):
287+
for k in actual.keys():
288+
self.assertTrue(torch.allclose(expected[k], actual[k]))
289+
else:
290+
self.assertTrue(torch.allclose(expected, actual))
291+
292+
293+
class TestOps:
294+
295+
@staticmethod
296+
def concat(value):
297+
def _inner(data):
298+
return data + value
299+
300+
return _inner
301+
302+
@staticmethod
303+
def concatd(value):
304+
def _inner(data):
305+
return {k: v + value for k, v in data.items()}
306+
307+
return _inner
308+
309+
@staticmethod
310+
def concata(value):
311+
def _inner(data1, data2):
312+
return data1 + value, data2 + value
313+
314+
return _inner
315+
316+
317+
TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES = [
318+
[{}, ("",), (TestOps.concat('a'), TestOps.concat('b'))],
319+
[{"unpack_items": True}, ("x", "y"), (TestOps.concat('a'), TestOps.concat('b'))],
320+
[{"map_items": False}, {"x": "1", "y": "2"}, (TestOps.concatd('a'), TestOps.concatd('b'))],
321+
[{"unpack_items": True, "map_items": False}, ("x", "y"), (TestOps.concata('a'), TestOps.concata('b'))],
322+
]
323+
324+
325+
class TestComposeExecuteWithFlags(unittest.TestCase):
326+
327+
@parameterized.expand(TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES)
328+
def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline):
329+
expected = Compose(pipeline, **flags)(data)
330+
331+
for cutoff in range(len(pipeline)):
332+
333+
c = Compose(deepcopy(pipeline), **flags)
334+
actual = c(c(data, end=cutoff), start=cutoff)
335+
if isinstance(actual, dict):
336+
for k in actual.keys():
337+
self.assertEqual(expected[k], actual[k])
338+
else:
339+
self.assertTrue(expected, actual)
340+
341+
p = deepcopy(pipeline)
342+
actual = Compose.execute(
343+
Compose.execute(data, p, start=0, end=cutoff, **flags), p, start=cutoff, **flags)
344+
if isinstance(actual, dict):
345+
for k in actual.keys():
346+
self.assertTrue(expected[k], actual[k])
347+
else:
348+
self.assertTrue(expected, actual)
349+
350+
223351
if __name__ == "__main__":
224352
unittest.main()

0 commit comments

Comments
 (0)