@@ -106,13 +106,13 @@ class Pad(InvertibleTransform, LazyTransform):
106106 backend = [TransformBackends .TORCH , TransformBackends .NUMPY ]
107107
108108 def __init__ (
109- self , to_pad : list [tuple [int , int ]] | None = None , mode : str = PytorchPadMode .CONSTANT , ** kwargs
109+ self , to_pad : tuple [tuple [int , int ]] | None = None , mode : str = PytorchPadMode .CONSTANT , ** kwargs
110110 ) -> None :
111111 self .to_pad = to_pad
112112 self .mode = mode
113113 self .kwargs = kwargs
114114
115- def compute_pad_width (self , spatial_shape : Sequence [int ]) -> list [tuple [int , int ]]:
115+ def compute_pad_width (self , spatial_shape : Sequence [int ]) -> tuple [tuple [int , int ]]:
116116 """
117117 dynamically compute the pad width according to the spatial shape.
118118 the output is the amount of padding for all dimensions including the channel.
@@ -123,8 +123,8 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int
123123 """
124124 raise NotImplementedError (f"subclass { self .__class__ .__name__ } must implement this method." )
125125
126- def __call__ ( # type: ignore
127- self , img : torch .Tensor , to_pad : list [tuple [int , int ]] | None = None , mode : str | None = None , ** kwargs
126+ def __call__ ( # type: ignore[override]
127+ self , img : torch .Tensor , to_pad : tuple [tuple [int , int ]] | None = None , mode : str | None = None , ** kwargs
128128 ) -> torch .Tensor :
129129 """
130130 Args:
@@ -150,7 +150,7 @@ def __call__( # type: ignore
150150 kwargs_ .update (kwargs )
151151
152152 img_t = convert_to_tensor (data = img , track_meta = get_track_meta ())
153- return pad_func (img_t , to_pad_ , mode_ , self .get_transform_info (), kwargs_ ) # type: ignore
153+ return pad_func (img_t , to_pad_ , mode_ , self .get_transform_info (), kwargs_ )
154154
155155 def inverse (self , data : MetaTensor ) -> MetaTensor :
156156 transform = self .pop_transform (data )
@@ -200,7 +200,7 @@ def __init__(
200200 self .method : Method = look_up_option (method , Method )
201201 super ().__init__ (mode = mode , ** kwargs )
202202
203- def compute_pad_width (self , spatial_shape : Sequence [int ]) -> list [tuple [int , int ]]:
203+ def compute_pad_width (self , spatial_shape : Sequence [int ]) -> tuple [tuple [int , int ]]:
204204 """
205205 dynamically compute the pad width according to the spatial shape.
206206
@@ -213,10 +213,10 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int
213213 pad_width = []
214214 for i , sp_i in enumerate (spatial_size ):
215215 width = max (sp_i - spatial_shape [i ], 0 )
216- pad_width .append ((width // 2 , width - (width // 2 )))
216+ pad_width .append ((int ( width // 2 ), int ( width - (width // 2 ) )))
217217 else :
218- pad_width = [(0 , max (sp_i - spatial_shape [i ], 0 )) for i , sp_i in enumerate (spatial_size )]
219- return [(0 , 0 )] + pad_width
218+ pad_width = [(0 , int ( max (sp_i - spatial_shape [i ], 0 ) )) for i , sp_i in enumerate (spatial_size )]
219+ return tuple ( [(0 , 0 )] + pad_width ) # type: ignore
220220
221221
222222class BorderPad (Pad ):
@@ -249,24 +249,26 @@ def __init__(self, spatial_border: Sequence[int] | int, mode: str = PytorchPadMo
249249 self .spatial_border = spatial_border
250250 super ().__init__ (mode = mode , ** kwargs )
251251
252- def compute_pad_width (self , spatial_shape : Sequence [int ]) -> list [tuple [int , int ]]:
252+ def compute_pad_width (self , spatial_shape : Sequence [int ]) -> tuple [tuple [int , int ]]:
253253 spatial_border = ensure_tuple (self .spatial_border )
254254 if not all (isinstance (b , int ) for b in spatial_border ):
255255 raise ValueError (f"self.spatial_border must contain only ints, got { spatial_border } ." )
256256 spatial_border = tuple (max (0 , b ) for b in spatial_border )
257257
258258 if len (spatial_border ) == 1 :
259- data_pad_width = [(spatial_border [0 ], spatial_border [0 ]) for _ in spatial_shape ]
259+ data_pad_width = [(int ( spatial_border [0 ]), int ( spatial_border [0 ]) ) for _ in spatial_shape ]
260260 elif len (spatial_border ) == len (spatial_shape ):
261- data_pad_width = [(sp , sp ) for sp in spatial_border [: len (spatial_shape )]]
261+ data_pad_width = [(int ( sp ), int ( sp ) ) for sp in spatial_border [: len (spatial_shape )]]
262262 elif len (spatial_border ) == len (spatial_shape ) * 2 :
263- data_pad_width = [(spatial_border [2 * i ], spatial_border [2 * i + 1 ]) for i in range (len (spatial_shape ))]
263+ data_pad_width = [
264+ (int (spatial_border [2 * i ]), int (spatial_border [2 * i + 1 ])) for i in range (len (spatial_shape ))
265+ ]
264266 else :
265267 raise ValueError (
266268 f"Unsupported spatial_border length: { len (spatial_border )} , available options are "
267269 f"[1, len(spatial_shape)={ len (spatial_shape )} , 2*len(spatial_shape)={ 2 * len (spatial_shape )} ]."
268270 )
269- return [(0 , 0 )] + data_pad_width
271+ return tuple ( [(0 , 0 )] + data_pad_width ) # type: ignore
270272
271273
272274class DivisiblePad (Pad ):
@@ -301,7 +303,7 @@ def __init__(
301303 self .method : Method = Method (method )
302304 super ().__init__ (mode = mode , ** kwargs )
303305
304- def compute_pad_width (self , spatial_shape : Sequence [int ]) -> list [tuple [int , int ]]:
306+ def compute_pad_width (self , spatial_shape : Sequence [int ]) -> tuple [tuple [int , int ]]:
305307 new_size = compute_divisible_spatial_size (spatial_shape = spatial_shape , k = self .k )
306308 spatial_pad = SpatialPad (spatial_size = new_size , method = self .method )
307309 return spatial_pad .compute_pad_width (spatial_shape )
@@ -322,7 +324,7 @@ def compute_slices(
322324 roi_start : Sequence [int ] | NdarrayOrTensor | None = None ,
323325 roi_end : Sequence [int ] | NdarrayOrTensor | None = None ,
324326 roi_slices : Sequence [slice ] | None = None ,
325- ):
327+ ) -> tuple [ slice ] :
326328 """
327329 Compute the crop slices based on specified `center & size` or `start & end` or `slices`.
328330
@@ -340,8 +342,8 @@ def compute_slices(
340342
341343 if roi_slices :
342344 if not all (s .step is None or s .step == 1 for s in roi_slices ):
343- raise ValueError ("only slice steps of 1/None are currently supported" )
344- return list (roi_slices )
345+ raise ValueError (f "only slice steps of 1/None are currently supported, got { roi_slices } . " )
346+ return ensure_tuple (roi_slices ) # type: ignore
345347 else :
346348 if roi_center is not None and roi_size is not None :
347349 roi_center_t = convert_to_tensor (data = roi_center , dtype = torch .int16 , wrap_sequence = True , device = "cpu" )
@@ -363,11 +365,12 @@ def compute_slices(
363365 roi_end_t = torch .maximum (roi_end_t , roi_start_t )
364366 # convert to slices (accounting for 1d)
365367 if roi_start_t .numel () == 1 :
366- return [slice (int (roi_start_t .item ()), int (roi_end_t .item ()))]
367- else :
368- return [slice (int (s ), int (e )) for s , e in zip (roi_start_t .tolist (), roi_end_t .tolist ())]
368+ return ensure_tuple ([slice (int (roi_start_t .item ()), int (roi_end_t .item ()))]) # type: ignore
369+ return ensure_tuple ( # type: ignore
370+ [slice (int (s ), int (e )) for s , e in zip (roi_start_t .tolist (), roi_end_t .tolist ())]
371+ )
369372
370- def __call__ (self , img : torch .Tensor , slices : tuple [slice , ...]) -> torch .Tensor : # type: ignore
373+ def __call__ (self , img : torch .Tensor , slices : tuple [slice , ...]) -> torch .Tensor : # type: ignore[override]
371374 """
372375 Apply the transform to `img`, assuming `img` is channel-first and
373376 slicing doesn't apply to the channel dim.
@@ -378,10 +381,10 @@ def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor
378381 if len (slices_ ) < sd :
379382 slices_ += [slice (None )] * (sd - len (slices_ ))
380383 # Add in the channel (no cropping)
381- slices = tuple ([slice (None )] + slices_ [:sd ])
384+ slices_ = list ([slice (None )] + slices_ [:sd ])
382385
383386 img_t : MetaTensor = convert_to_tensor (data = img , track_meta = get_track_meta ())
384- return crop_func (img_t , slices , self .get_transform_info ()) # type: ignore
387+ return crop_func (img_t , tuple ( slices_ ) , self .get_transform_info ())
385388
386389 def inverse (self , img : MetaTensor ) -> MetaTensor :
387390 transform = self .pop_transform (img )
@@ -429,13 +432,13 @@ def __init__(
429432 roi_center = roi_center , roi_size = roi_size , roi_start = roi_start , roi_end = roi_end , roi_slices = roi_slices
430433 )
431434
432- def __call__ (self , img : torch .Tensor ) -> torch .Tensor : # type: ignore
435+ def __call__ (self , img : torch .Tensor ) -> torch .Tensor : # type: ignore[override]
433436 """
434437 Apply the transform to `img`, assuming `img` is channel-first and
435438 slicing doesn't apply to the channel dim.
436439
437440 """
438- return super ().__call__ (img = img , slices = self .slices )
441+ return super ().__call__ (img = img , slices = ensure_tuple ( self .slices ) )
439442
440443
441444class CenterSpatialCrop (Crop ):
@@ -456,12 +459,12 @@ class CenterSpatialCrop(Crop):
456459 def __init__ (self , roi_size : Sequence [int ] | int ) -> None :
457460 self .roi_size = roi_size
458461
459- def compute_slices (self , spatial_size : Sequence [int ]): # type: ignore
462+ def compute_slices (self , spatial_size : Sequence [int ]) -> tuple [ slice ] : # type: ignore[override]
460463 roi_size = fall_back_tuple (self .roi_size , spatial_size )
461464 roi_center = [i // 2 for i in spatial_size ]
462465 return super ().compute_slices (roi_center = roi_center , roi_size = roi_size )
463466
464- def __call__ (self , img : torch .Tensor ) -> torch .Tensor : # type: ignore
467+ def __call__ (self , img : torch .Tensor ) -> torch .Tensor : # type: ignore[override]
465468 """
466469 Apply the transform to `img`, assuming `img` is channel-first and
467470 slicing doesn't apply to the channel dim.
@@ -486,7 +489,7 @@ class CenterScaleCrop(Crop):
486489 def __init__ (self , roi_scale : Sequence [float ] | float ):
487490 self .roi_scale = roi_scale
488491
489- def __call__ (self , img : torch .Tensor ) -> torch .Tensor : # type: ignore
492+ def __call__ (self , img : torch .Tensor ) -> torch .Tensor : # type: ignore[override]
490493 img_size = img .peek_pending_shape () if isinstance (img , MetaTensor ) else img .shape [1 :]
491494 ndim = len (img_size )
492495 roi_size = [ceil (r * s ) for r , s in zip (ensure_tuple_rep (self .roi_scale , ndim ), img_size )]
@@ -771,7 +774,7 @@ def lazy_evaluation(self, _val: bool):
771774 self ._lazy_evaluation = _val
772775 self .padder .lazy_evaluation = _val
773776
774- def compute_bounding_box (self , img : torch .Tensor ):
777+ def compute_bounding_box (self , img : torch .Tensor ) -> tuple [ np . ndarray , np . ndarray ] :
775778 """
776779 Compute the start points and end points of bounding box to crop.
777780 And adjust bounding box coords to be divisible by `k`.
@@ -794,7 +797,7 @@ def compute_bounding_box(self, img: torch.Tensor):
794797
795798 def crop_pad (
796799 self , img : torch .Tensor , box_start : np .ndarray , box_end : np .ndarray , mode : str | None = None , ** pad_kwargs
797- ):
800+ ) -> torch . Tensor :
798801 """
799802 Crop and pad based on the bounding box.
800803
@@ -817,7 +820,9 @@ def crop_pad(
817820 ret .applied_operations [- 1 ][TraceKeys .EXTRA_INFO ]["pad_info" ] = ret .applied_operations .pop ()
818821 return ret
819822
820- def __call__ (self , img : torch .Tensor , mode : str | None = None , ** pad_kwargs ): # type: ignore
823+ def __call__ ( # type: ignore[override]
824+ self , img : torch .Tensor , mode : str | None = None , ** pad_kwargs
825+ ) -> torch .Tensor :
821826 """
822827 Apply the transform to `img`, assuming `img` is channel-first and
823828 slicing doesn't change the channel dim.
@@ -826,7 +831,7 @@ def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs): #
826831 cropped = self .crop_pad (img , box_start , box_end , mode , ** pad_kwargs )
827832
828833 if self .return_coords :
829- return cropped , box_start , box_end
834+ return cropped , box_start , box_end # type: ignore[return-value]
830835 return cropped
831836
832837 def inverse (self , img : MetaTensor ) -> MetaTensor :
@@ -995,7 +1000,7 @@ def __init__(
9951000 self .num_samples = num_samples
9961001 self .image = image
9971002 self .image_threshold = image_threshold
998- self .centers : list [ list [ int ] ] | None = None
1003+ self .centers : tuple [ tuple ] | None = None
9991004 self .fg_indices = fg_indices
10001005 self .bg_indices = bg_indices
10011006 self .allow_smaller = allow_smaller
@@ -1173,7 +1178,7 @@ def __init__(
11731178 self .num_samples = num_samples
11741179 self .image = image
11751180 self .image_threshold = image_threshold
1176- self .centers : list [ list [ int ] ] | None = None
1181+ self .centers : tuple [ tuple ] | None = None
11771182 self .indices = indices
11781183 self .allow_smaller = allow_smaller
11791184 self .warn = warn
0 commit comments