Skip to content

Commit 3a69940

Browse files
committed
step 1 based on PR #4922, adding a lazy transform
Signed-off-by: Wenqi Li <[email protected]>
1 parent f407fcc commit 3a69940

File tree

8 files changed

+93
-9
lines changed

8 files changed

+93
-9
lines changed

docs/source/transforms.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ Generic Interfaces
6363
.. autoclass:: OneOf
6464
:members:
6565

66+
`LazyTransform`
67+
^^^^^^^^^^^^^^^
68+
.. autoclass:: LazyTransform
69+
:members:
70+
6671
Vanilla Transforms
6772
------------------
6873

monai/data/meta_obj.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class MetaObj:
8282
def __init__(self):
8383
self._meta: dict = MetaObj.get_default_meta()
8484
self._applied_operations: list = MetaObj.get_default_applied_operations()
85+
self._pending_operations: list = MetaObj.get_default_applied_operations()
8586
self._is_batch: bool = False
8687

8788
@staticmethod
@@ -199,6 +200,19 @@ def push_applied_operation(self, t: Any) -> None:
199200
def pop_applied_operation(self) -> Any:
200201
return self._applied_operations.pop()
201202

203+
@property
204+
def pending_operations(self) -> list[dict]:
205+
"""Get the pending operations. Defaults to ``[]``."""
206+
if hasattr(self, "_pending_operations"):
207+
return self._pending_operations
208+
return MetaObj.get_default_applied_operations()
209+
210+
def push_pending_operation(self, t: Any) -> None:
211+
self._pending_operations.append(t)
212+
213+
def pop_pending_operation(self) -> Any:
214+
return self._pending_operations.pop()
215+
202216
@property
203217
def is_batch(self) -> bool:
204218
"""Return whether object is part of batch or not."""

monai/data/meta_tensor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from monai.data.meta_obj import MetaObj, get_track_meta
2424
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
2525
from monai.utils import look_up_option
26-
from monai.utils.enums import MetaKeys, PostFix, SpaceKeys
26+
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
2727
from monai.utils.type_conversion import convert_data_type, convert_to_tensor
2828

2929
__all__ = ["MetaTensor"]
@@ -445,6 +445,13 @@ def pixdim(self):
445445
return [affine_to_spacing(a) for a in self.affine]
446446
return affine_to_spacing(self.affine)
447447

448+
def peek_pending_shape(self):
449+
"""Get the currently expected spatial shape as if all the pending operations are executed."""
450+
return self.pending_operations[-1][LazyAttr.SHAPE] if self.pending_operations else self.array.shape[1:]
451+
452+
def peek_pending_affine(self):
453+
return self.pending_operations[-1][LazyAttr.AFFINE] if self.pending_operations else self.affine
454+
448455
def new_empty(self, size, dtype=None, device=None, requires_grad=False):
449456
"""
450457
must be defined for deepcopy to work

monai/transforms/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,15 @@
449449
ZoomD,
450450
ZoomDict,
451451
)
452-
from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform
452+
from .transform import (
453+
LazyTransform,
454+
MapTransform,
455+
Randomizable,
456+
RandomizableTransform,
457+
ThreadUnsafe,
458+
Transform,
459+
apply_transform,
460+
)
453461
from .utility.array import (
454462
AddChannel,
455463
AddCoordinateChannels,

monai/transforms/spatial/array.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop
3232
from monai.transforms.intensity.array import GaussianSmooth
3333
from monai.transforms.inverse import InvertibleTransform
34-
from monai.transforms.transform import Randomizable, RandomizableTransform, Transform
34+
from monai.transforms.transform import LazyTransform, Randomizable, RandomizableTransform, Transform
3535
from monai.transforms.utils import (
3636
convert_pad_mode,
3737
create_control_grid,
@@ -48,6 +48,7 @@
4848
GridSampleMode,
4949
GridSamplePadMode,
5050
InterpolateMode,
51+
LazyAttr,
5152
NdimageMode,
5253
NumpyPadMode,
5354
SplineMode,
@@ -751,7 +752,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
751752
return data
752753

753754

754-
class Flip(InvertibleTransform):
755+
class Flip(InvertibleTransform, LazyTransform):
755756
"""
756757
Reverses the order of elements along the given spatial axis. Preserves shape.
757758
See `torch.flip` documentation for additional details:
@@ -771,14 +772,13 @@ class Flip(InvertibleTransform):
771772
def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None:
772773
self.spatial_axis = spatial_axis
773774

774-
def update_meta(self, img, shape, axes):
775+
def update_meta(self, affine, shape, axes):
775776
# shape and axes include the channel dim
776-
affine = img.affine
777777
mat = convert_to_dst_type(torch.eye(len(affine)), affine)[0]
778778
for axis in axes:
779779
sp = axis - 1
780780
mat[sp, sp], mat[sp, -1] = mat[sp, sp] * -1, shape[axis] - 1
781-
img.affine = affine @ mat
781+
return affine @ mat
782782

783783
def forward_image(self, img, axes) -> torch.Tensor:
784784
return torch.flip(img, axes)
@@ -790,9 +790,16 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor:
790790
"""
791791
img = convert_to_tensor(img, track_meta=get_track_meta())
792792
axes = map_spatial_axes(img.ndim, self.spatial_axis)
793+
if self.lazy_evaluation and isinstance(img, MetaTensor):
794+
spatial_chn_shape = [1, *convert_to_numpy(img.peek_pending_shape()).tolist()]
795+
affine = img.peek_pending_affine()
796+
lazy_affine = self.update_meta(affine, spatial_chn_shape, axes)
797+
img.push_pending_operation({LazyAttr.SHAPE: img.peek_pending_shape(), LazyAttr.AFFINE: lazy_affine})
798+
self.push_transform(img)
799+
return img
793800
out = self.forward_image(img, axes)
794801
if get_track_meta():
795-
self.update_meta(out, out.shape, axes)
802+
out.affine = self.update_meta(out.affine, out.shape, axes) # type: ignore
796803
self.push_transform(out)
797804
return out
798805

monai/transforms/transform.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,15 @@
2626
from monai.utils.enums import TransformBackends
2727
from monai.utils.misc import MONAIEnvVars
2828

29-
__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"]
29+
__all__ = [
30+
"ThreadUnsafe",
31+
"apply_transform",
32+
"LazyTransform",
33+
"Randomizable",
34+
"RandomizableTransform",
35+
"Transform",
36+
"MapTransform",
37+
]
3038

3139
ReturnType = TypeVar("ReturnType")
3240

@@ -131,6 +139,26 @@ class ThreadUnsafe:
131139
pass
132140

133141

142+
class LazyTransform:
143+
"""
144+
An interface to denote whether a transform can be applied lazily. It is designed as part of lazy resampling of
145+
multiple transforms. Classes inheriting this interface should be able to operate in two modes:
146+
147+
- ``set_lazy_eval(False)`` (eagerly evaluating), the transform should output the finalized transform
148+
results without any pending operations. Both primary data and metadata of the outputs should be up-to-date.
149+
- ``set_lazy_eval(True)`` (lazily evaluating), the transform should only execute necessary/lightweight/lossless
150+
metadata updates to track any pending operations. The goal is that, in a later stage, the pending operations
151+
can be grouped together and evaluated more efficiently and accurately -- each transforms when evaluated
152+
independently may cause some information losses.
153+
154+
"""
155+
156+
lazy_evaluation: bool = False
157+
158+
def set_lazy_eval(self, value: bool):
159+
self.lazy_evaluation = value
160+
161+
134162
class Randomizable(ThreadUnsafe):
135163
"""
136164
An interface for handling random state locally, currently based on a class

monai/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
InterpolateMode,
3535
InverseKeys,
3636
JITMetadataKeys,
37+
LazyAttr,
3738
LossReduction,
3839
MetaKeys,
3940
Method,

monai/utils/enums.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"AlgoEnsembleKeys",
5555
"HoVerNetMode",
5656
"HoVerNetBranch",
57+
"LazyAttr",
5758
]
5859

5960

@@ -616,3 +617,16 @@ class HoVerNetBranch(StrEnum):
616617
HV = "horizontal_vertical"
617618
NP = "nucleus_prediction"
618619
NC = "type_prediction"
620+
621+
622+
class LazyAttr(StrEnum):
623+
"""
624+
MetaTensor with pending operations requires some key attributes tracked especially when the primary array
625+
is not up-to-date due to lazy evaluation.
626+
This class specifies the set of key attributes to be tracked for each MetaTensor.
627+
"""
628+
629+
SHAPE = "lazy_shape" # spatial shape
630+
AFFINE = "lazy_affine"
631+
PADDING_MODE = "lazy_padding_mode"
632+
INTERP_MODE = "lazy_interpolation_mode"

0 commit comments

Comments
 (0)