3131from monai .transforms .croppad .array import CenterSpatialCrop , ResizeWithPadOrCrop
3232from monai .transforms .intensity .array import GaussianSmooth
3333from monai .transforms .inverse import InvertibleTransform
34- from monai .transforms .transform import Randomizable , RandomizableTransform , Transform
34+ from monai .transforms .transform import LazyTransform , Randomizable , RandomizableTransform , Transform
3535from monai .transforms .utils import (
3636 convert_pad_mode ,
3737 create_control_grid ,
4848 GridSampleMode ,
4949 GridSamplePadMode ,
5050 InterpolateMode ,
51+ LazyAttr ,
5152 NdimageMode ,
5253 NumpyPadMode ,
5354 SplineMode ,
@@ -751,7 +752,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
751752 return data
752753
753754
754- class Flip (InvertibleTransform ):
755+ class Flip (InvertibleTransform , LazyTransform ):
755756 """
756757 Reverses the order of elements along the given spatial axis. Preserves shape.
757758 See `torch.flip` documentation for additional details:
@@ -771,14 +772,13 @@ class Flip(InvertibleTransform):
771772 def __init__ (self , spatial_axis : Optional [Union [Sequence [int ], int ]] = None ) -> None :
772773 self .spatial_axis = spatial_axis
773774
774- def update_meta (self , img , shape , axes ):
775+ def update_meta (self , affine , shape , axes ):
775776 # shape and axes include the channel dim
776- affine = img .affine
777777 mat = convert_to_dst_type (torch .eye (len (affine )), affine )[0 ]
778778 for axis in axes :
779779 sp = axis - 1
780780 mat [sp , sp ], mat [sp , - 1 ] = mat [sp , sp ] * - 1 , shape [axis ] - 1
781- img . affine = affine @ mat
781+ return affine @ mat
782782
783783 def forward_image (self , img , axes ) -> torch .Tensor :
784784 return torch .flip (img , axes )
@@ -790,9 +790,16 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor:
790790 """
791791 img = convert_to_tensor (img , track_meta = get_track_meta ())
792792 axes = map_spatial_axes (img .ndim , self .spatial_axis )
793+ if self .lazy_evaluation and isinstance (img , MetaTensor ):
794+ spatial_chn_shape = [1 , * convert_to_numpy (img .peek_pending_shape ()).tolist ()]
795+ affine = img .peek_pending_affine ()
796+ lazy_affine = self .update_meta (affine , spatial_chn_shape , axes )
797+ img .push_pending_operation ({LazyAttr .SHAPE : img .peek_pending_shape (), LazyAttr .AFFINE : lazy_affine })
798+ self .push_transform (img )
799+ return img
793800 out = self .forward_image (img , axes )
794801 if get_track_meta ():
795- self .update_meta (out , out .shape , axes )
802+ out . affine = self .update_meta (out . affine , out .shape , axes ) # type: ignore
796803 self .push_transform (out )
797804 return out
798805
0 commit comments