Skip to content

Commit a8302ec

Browse files
authored
6109 no mutate ratio /user inputs croppad (#6127)
Fixes #6109 ### Description - use tuples for user inputs to avoid changes - enhance the type checks - fixes issue of `ratios` in `RandCropByLabelClasses ` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent f754928 commit a8302ec

11 files changed

Lines changed: 77 additions & 70 deletions

monai/apps/detection/transforms/dictionary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1071,7 +1071,7 @@ def __init__(
10711071
if len(self.image_keys) != len(self.meta_keys):
10721072
raise ValueError("meta_keys should have the same length as keys.")
10731073
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.image_keys))
1074-
self.centers: list[list[int]] | None = None
1074+
self.centers: tuple[tuple] | None = None
10751075
self.allow_smaller = allow_smaller
10761076

10771077
def generate_fg_center_boxes_np(self, boxes: NdarrayOrTensor, image_size: Sequence[int]) -> np.ndarray:

monai/apps/detection/utils/detector_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def pad_images(
149149
max_spatial_size = compute_divisible_spatial_size(spatial_shape=list(max_spatial_size_t), k=size_divisible)
150150

151151
# allocate memory for the padded images
152-
images = torch.zeros([len(image_sizes), in_channels] + max_spatial_size, dtype=dtype, device=device)
152+
images = torch.zeros([len(image_sizes), in_channels] + list(max_spatial_size), dtype=dtype, device=device)
153153

154154
# Use `SpatialPad` to match sizes, padding in the end will not affect boxes
155155
padder = SpatialPad(spatial_size=max_spatial_size, method="end", mode=mode, **kwargs)

monai/transforms/croppad/array.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,13 @@ class Pad(InvertibleTransform, LazyTransform):
106106
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
107107

108108
def __init__(
109-
self, to_pad: list[tuple[int, int]] | None = None, mode: str = PytorchPadMode.CONSTANT, **kwargs
109+
self, to_pad: tuple[tuple[int, int]] | None = None, mode: str = PytorchPadMode.CONSTANT, **kwargs
110110
) -> None:
111111
self.to_pad = to_pad
112112
self.mode = mode
113113
self.kwargs = kwargs
114114

115-
def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
115+
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
116116
"""
117117
dynamically compute the pad width according to the spatial shape.
118118
the output is the amount of padding for all dimensions including the channel.
@@ -123,8 +123,8 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int
123123
"""
124124
raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.")
125125

126-
def __call__( # type: ignore
127-
self, img: torch.Tensor, to_pad: list[tuple[int, int]] | None = None, mode: str | None = None, **kwargs
126+
def __call__( # type: ignore[override]
127+
self, img: torch.Tensor, to_pad: tuple[tuple[int, int]] | None = None, mode: str | None = None, **kwargs
128128
) -> torch.Tensor:
129129
"""
130130
Args:
@@ -150,7 +150,7 @@ def __call__( # type: ignore
150150
kwargs_.update(kwargs)
151151

152152
img_t = convert_to_tensor(data=img, track_meta=get_track_meta())
153-
return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_) # type: ignore
153+
return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_)
154154

155155
def inverse(self, data: MetaTensor) -> MetaTensor:
156156
transform = self.pop_transform(data)
@@ -200,7 +200,7 @@ def __init__(
200200
self.method: Method = look_up_option(method, Method)
201201
super().__init__(mode=mode, **kwargs)
202202

203-
def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
203+
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
204204
"""
205205
dynamically compute the pad width according to the spatial shape.
206206
@@ -213,10 +213,10 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int
213213
pad_width = []
214214
for i, sp_i in enumerate(spatial_size):
215215
width = max(sp_i - spatial_shape[i], 0)
216-
pad_width.append((width // 2, width - (width // 2)))
216+
pad_width.append((int(width // 2), int(width - (width // 2))))
217217
else:
218-
pad_width = [(0, max(sp_i - spatial_shape[i], 0)) for i, sp_i in enumerate(spatial_size)]
219-
return [(0, 0)] + pad_width
218+
pad_width = [(0, int(max(sp_i - spatial_shape[i], 0))) for i, sp_i in enumerate(spatial_size)]
219+
return tuple([(0, 0)] + pad_width) # type: ignore
220220

221221

222222
class BorderPad(Pad):
@@ -249,24 +249,26 @@ def __init__(self, spatial_border: Sequence[int] | int, mode: str = PytorchPadMo
249249
self.spatial_border = spatial_border
250250
super().__init__(mode=mode, **kwargs)
251251

252-
def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
252+
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
253253
spatial_border = ensure_tuple(self.spatial_border)
254254
if not all(isinstance(b, int) for b in spatial_border):
255255
raise ValueError(f"self.spatial_border must contain only ints, got {spatial_border}.")
256256
spatial_border = tuple(max(0, b) for b in spatial_border)
257257

258258
if len(spatial_border) == 1:
259-
data_pad_width = [(spatial_border[0], spatial_border[0]) for _ in spatial_shape]
259+
data_pad_width = [(int(spatial_border[0]), int(spatial_border[0])) for _ in spatial_shape]
260260
elif len(spatial_border) == len(spatial_shape):
261-
data_pad_width = [(sp, sp) for sp in spatial_border[: len(spatial_shape)]]
261+
data_pad_width = [(int(sp), int(sp)) for sp in spatial_border[: len(spatial_shape)]]
262262
elif len(spatial_border) == len(spatial_shape) * 2:
263-
data_pad_width = [(spatial_border[2 * i], spatial_border[2 * i + 1]) for i in range(len(spatial_shape))]
263+
data_pad_width = [
264+
(int(spatial_border[2 * i]), int(spatial_border[2 * i + 1])) for i in range(len(spatial_shape))
265+
]
264266
else:
265267
raise ValueError(
266268
f"Unsupported spatial_border length: {len(spatial_border)}, available options are "
267269
f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]."
268270
)
269-
return [(0, 0)] + data_pad_width
271+
return tuple([(0, 0)] + data_pad_width) # type: ignore
270272

271273

272274
class DivisiblePad(Pad):
@@ -301,7 +303,7 @@ def __init__(
301303
self.method: Method = Method(method)
302304
super().__init__(mode=mode, **kwargs)
303305

304-
def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
306+
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
305307
new_size = compute_divisible_spatial_size(spatial_shape=spatial_shape, k=self.k)
306308
spatial_pad = SpatialPad(spatial_size=new_size, method=self.method)
307309
return spatial_pad.compute_pad_width(spatial_shape)
@@ -322,7 +324,7 @@ def compute_slices(
322324
roi_start: Sequence[int] | NdarrayOrTensor | None = None,
323325
roi_end: Sequence[int] | NdarrayOrTensor | None = None,
324326
roi_slices: Sequence[slice] | None = None,
325-
):
327+
) -> tuple[slice]:
326328
"""
327329
Compute the crop slices based on specified `center & size` or `start & end` or `slices`.
328330
@@ -340,8 +342,8 @@ def compute_slices(
340342

341343
if roi_slices:
342344
if not all(s.step is None or s.step == 1 for s in roi_slices):
343-
raise ValueError("only slice steps of 1/None are currently supported")
344-
return list(roi_slices)
345+
raise ValueError(f"only slice steps of 1/None are currently supported, got {roi_slices}.")
346+
return ensure_tuple(roi_slices) # type: ignore
345347
else:
346348
if roi_center is not None and roi_size is not None:
347349
roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu")
@@ -363,11 +365,12 @@ def compute_slices(
363365
roi_end_t = torch.maximum(roi_end_t, roi_start_t)
364366
# convert to slices (accounting for 1d)
365367
if roi_start_t.numel() == 1:
366-
return [slice(int(roi_start_t.item()), int(roi_end_t.item()))]
367-
else:
368-
return [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())]
368+
return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))]) # type: ignore
369+
return ensure_tuple( # type: ignore
370+
[slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())]
371+
)
369372

370-
def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor: # type: ignore
373+
def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor: # type: ignore[override]
371374
"""
372375
Apply the transform to `img`, assuming `img` is channel-first and
373376
slicing doesn't apply to the channel dim.
@@ -378,10 +381,10 @@ def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor
378381
if len(slices_) < sd:
379382
slices_ += [slice(None)] * (sd - len(slices_))
380383
# Add in the channel (no cropping)
381-
slices = tuple([slice(None)] + slices_[:sd])
384+
slices_ = list([slice(None)] + slices_[:sd])
382385

383386
img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta())
384-
return crop_func(img_t, slices, self.get_transform_info()) # type: ignore
387+
return crop_func(img_t, tuple(slices_), self.get_transform_info())
385388

386389
def inverse(self, img: MetaTensor) -> MetaTensor:
387390
transform = self.pop_transform(img)
@@ -429,13 +432,13 @@ def __init__(
429432
roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices
430433
)
431434

432-
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore
435+
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override]
433436
"""
434437
Apply the transform to `img`, assuming `img` is channel-first and
435438
slicing doesn't apply to the channel dim.
436439
437440
"""
438-
return super().__call__(img=img, slices=self.slices)
441+
return super().__call__(img=img, slices=ensure_tuple(self.slices))
439442

440443

441444
class CenterSpatialCrop(Crop):
@@ -456,12 +459,12 @@ class CenterSpatialCrop(Crop):
456459
def __init__(self, roi_size: Sequence[int] | int) -> None:
457460
self.roi_size = roi_size
458461

459-
def compute_slices(self, spatial_size: Sequence[int]): # type: ignore
462+
def compute_slices(self, spatial_size: Sequence[int]) -> tuple[slice]: # type: ignore[override]
460463
roi_size = fall_back_tuple(self.roi_size, spatial_size)
461464
roi_center = [i // 2 for i in spatial_size]
462465
return super().compute_slices(roi_center=roi_center, roi_size=roi_size)
463466

464-
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore
467+
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override]
465468
"""
466469
Apply the transform to `img`, assuming `img` is channel-first and
467470
slicing doesn't apply to the channel dim.
@@ -486,7 +489,7 @@ class CenterScaleCrop(Crop):
486489
def __init__(self, roi_scale: Sequence[float] | float):
487490
self.roi_scale = roi_scale
488491

489-
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore
492+
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override]
490493
img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
491494
ndim = len(img_size)
492495
roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]
@@ -771,7 +774,7 @@ def lazy_evaluation(self, _val: bool):
771774
self._lazy_evaluation = _val
772775
self.padder.lazy_evaluation = _val
773776

774-
def compute_bounding_box(self, img: torch.Tensor):
777+
def compute_bounding_box(self, img: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
775778
"""
776779
Compute the start points and end points of bounding box to crop.
777780
And adjust bounding box coords to be divisible by `k`.
@@ -794,7 +797,7 @@ def compute_bounding_box(self, img: torch.Tensor):
794797

795798
def crop_pad(
796799
self, img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, mode: str | None = None, **pad_kwargs
797-
):
800+
) -> torch.Tensor:
798801
"""
799802
Crop and pad based on the bounding box.
800803
@@ -817,7 +820,9 @@ def crop_pad(
817820
ret.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = ret.applied_operations.pop()
818821
return ret
819822

820-
def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs): # type: ignore
823+
def __call__( # type: ignore[override]
824+
self, img: torch.Tensor, mode: str | None = None, **pad_kwargs
825+
) -> torch.Tensor:
821826
"""
822827
Apply the transform to `img`, assuming `img` is channel-first and
823828
slicing doesn't change the channel dim.
@@ -826,7 +831,7 @@ def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs): #
826831
cropped = self.crop_pad(img, box_start, box_end, mode, **pad_kwargs)
827832

828833
if self.return_coords:
829-
return cropped, box_start, box_end
834+
return cropped, box_start, box_end # type: ignore[return-value]
830835
return cropped
831836

832837
def inverse(self, img: MetaTensor) -> MetaTensor:
@@ -995,7 +1000,7 @@ def __init__(
9951000
self.num_samples = num_samples
9961001
self.image = image
9971002
self.image_threshold = image_threshold
998-
self.centers: list[list[int]] | None = None
1003+
self.centers: tuple[tuple] | None = None
9991004
self.fg_indices = fg_indices
10001005
self.bg_indices = bg_indices
10011006
self.allow_smaller = allow_smaller
@@ -1173,7 +1178,7 @@ def __init__(
11731178
self.num_samples = num_samples
11741179
self.image = image
11751180
self.image_threshold = image_threshold
1176-
self.centers: list[list[int]] | None = None
1181+
self.centers: tuple[tuple] | None = None
11771182
self.indices = indices
11781183
self.allow_smaller = allow_smaller
11791184
self.warn = warn

monai/transforms/croppad/dictionary.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -698,9 +698,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
698698
self.cropper: CropForeground
699699
box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key])
700700
if self.start_coord_key is not None:
701-
d[self.start_coord_key] = box_start
701+
d[self.start_coord_key] = box_start # type: ignore
702702
if self.end_coord_key is not None:
703-
d[self.end_coord_key] = box_end
703+
d[self.end_coord_key] = box_end # type: ignore
704704
for key, m in self.key_iterator(d, self.mode):
705705
d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m)
706706
return d

monai/transforms/croppad/functional.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int,
147147
return img
148148

149149

150-
def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transform_info: dict, kwargs):
150+
def pad_func(
151+
img: torch.Tensor, to_pad: tuple[tuple[int, int]], mode: str, transform_info: dict, kwargs
152+
) -> torch.Tensor:
151153
"""
152154
Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according
153155
to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).
@@ -166,17 +168,17 @@ def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transf
166168
kwargs: other arguments for the `np.pad` or `torch.pad` function.
167169
note that `np.pad` treats channel dimension as the first dimension.
168170
"""
169-
extra_info = {"padded": to_pad, "mode": str(mode)}
171+
extra_info = {"padded": to_pad, "mode": f"{mode}"}
170172
img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
171173
spatial_rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else 3
172174
do_pad = np.asarray(to_pad).any()
173175
if do_pad:
174-
to_pad = list(to_pad)
175-
if len(to_pad) < len(img.shape):
176-
to_pad = list(to_pad) + [(0, 0)] * (len(img.shape) - len(to_pad))
177-
to_shift = [-s[0] for s in to_pad[1:]] # skipping the channel pad
176+
to_pad_list = [(int(p[0]), int(p[1])) for p in to_pad]
177+
if len(to_pad_list) < len(img.shape):
178+
to_pad_list += [(0, 0)] * (len(img.shape) - len(to_pad_list))
179+
to_shift = [-s[0] for s in to_pad_list[1:]] # skipping the channel pad
178180
xform = create_translate(spatial_rank, to_shift)
179-
shape = [d + s + e for d, (s, e) in zip(img_size, to_pad[1:])]
181+
shape = [d + s + e for d, (s, e) in zip(img_size, to_pad_list[1:])]
180182
else:
181183
shape = img_size
182184
xform = torch.eye(int(spatial_rank) + 1, device=torch.device("cpu"), dtype=torch.float64)
@@ -191,13 +193,13 @@ def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transf
191193
)
192194
out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta())
193195
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
194-
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
195-
out = pad_nd(out, to_pad, mode, **kwargs) if do_pad else out
196+
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore
197+
out = pad_nd(out, to_pad_list, mode, **kwargs) if do_pad else out
196198
out = convert_to_tensor(out, track_meta=get_track_meta())
197-
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
199+
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore
198200

199201

200-
def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict):
202+
def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict) -> torch.Tensor:
201203
"""
202204
Functional implementation of cropping a MetaTensor. This function operates eagerly or lazily according
203205
to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).
@@ -229,6 +231,6 @@ def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict
229231
)
230232
out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta())
231233
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
232-
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
234+
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore
233235
out = out[slices]
234-
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
236+
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore

monai/transforms/inverse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def track_transform_meta(
190190
return data
191191
return out_obj # return with data_t as tensor if get_track_meta() is False
192192

193-
info = transform_info
193+
info = transform_info.copy()
194194
# track the current spatial shape
195195
if orig_size is not None:
196196
info[TraceKeys.ORIG_SIZE] = orig_size

0 commit comments

Comments
 (0)