Skip to content

Commit 9e2d381

Browse files
authored
6777 fixes boolean indexing of batched metatensor (#6781)
Fixes #6777 ### Description - skip metadata handling if the index is a tensor variable - make batched tensor handling a separate internal function for clarity ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li <[email protected]>
1 parent c62fac1 commit 9e2d381

2 files changed

Lines changed: 55 additions & 45 deletions

File tree

monai/data/meta_tensor.py

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
202202
:py:func:`MetaTensor._copy_meta`).
203203
"""
204204
out = []
205-
metas = None
205+
metas = None # optional output metadicts for each of the return value in `rets`
206206
is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch"))
207207
for idx, ret in enumerate(rets):
208208
# if not `MetaTensor`, nothing to do.
@@ -219,55 +219,61 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
219219
# the following is not implemented but the network arch may run into this case:
220220
# if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
221221
# raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
222-
223-
# If we have a batch of data, then we need to be careful if a slice of
224-
# the data is returned. Depending on how the data are indexed, we return
225-
# some or all of the metadata, and the return object may or may not be a
226-
# batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
227222
if is_batch:
228-
# if indexing e.g., `batch[0]`
229-
if func == torch.Tensor.__getitem__:
230-
batch_idx = args[1]
231-
if isinstance(batch_idx, Sequence):
232-
batch_idx = batch_idx[0]
233-
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
234-
# first element will be `slice(None, None, None)` and `Ellipsis`,
235-
# respectively. Don't need to do anything with the metadata.
236-
if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0:
237-
ret_meta = decollate_batch(args[0], detach=False)[batch_idx]
238-
if isinstance(ret_meta, list) and ret_meta: # e.g. batch[0:2], re-collate
239-
try:
240-
ret_meta = list_data_collate(ret_meta)
241-
except (TypeError, ValueError, RuntimeError, IndexError) as e:
242-
raise ValueError(
243-
"Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
244-
"please convert it into a torch Tensor using `x.as_tensor()` or "
245-
"a numpy array using `x.array`."
246-
) from e
247-
elif isinstance(ret_meta, MetaObj): # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
248-
ret_meta.is_batch = False
249-
if hasattr(ret_meta, "__dict__"):
250-
ret.__dict__ = ret_meta.__dict__.copy()
251-
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
252-
# But we only want to split the batch if the `unbind` is along the 0th
253-
# dimension.
254-
elif func == torch.Tensor.unbind:
255-
if len(args) > 1:
256-
dim = args[1]
257-
elif "dim" in kwargs:
258-
dim = kwargs["dim"]
259-
else:
260-
dim = 0
261-
if dim == 0:
262-
if metas is None:
263-
metas = decollate_batch(args[0], detach=False)
264-
ret.__dict__ = metas[idx].__dict__.copy()
265-
ret.is_batch = False
266-
223+
ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs)
267224
out.append(ret)
268225
# if the input was a tuple, then return it as a tuple
269226
return tuple(out) if isinstance(rets, tuple) else out
270227

228+
@classmethod
229+
def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
230+
"""utility function to handle batched MetaTensors."""
231+
# If we have a batch of data, then we need to be careful if a slice of
232+
# the data is returned. Depending on how the data are indexed, we return
233+
# some or all of the metadata, and the return object may or may not be a
234+
# batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
235+
# if indexing e.g., `batch[0]`
236+
if func == torch.Tensor.__getitem__:
237+
if idx > 0 or len(args) < 2 or len(args[0]) < 1:
238+
return ret
239+
batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1]
240+
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
241+
# first element will be `slice(None, None, None)` and `Ellipsis`,
242+
# respectively. Don't need to do anything with the metadata.
243+
if batch_idx in (slice(None, None, None), Ellipsis, None) or isinstance(batch_idx, torch.Tensor):
244+
return ret
245+
dec_batch = decollate_batch(args[0], detach=False)
246+
ret_meta = dec_batch[batch_idx]
247+
if isinstance(ret_meta, list) and ret_meta: # e.g. batch[0:2], re-collate
248+
try:
249+
ret_meta = list_data_collate(ret_meta)
250+
except (TypeError, ValueError, RuntimeError, IndexError) as e:
251+
raise ValueError(
252+
"Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
253+
"please consider converting it into a torch Tensor using `x.as_tensor()` or "
254+
"a numpy array using `x.array`."
255+
) from e
256+
elif isinstance(ret_meta, MetaObj): # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
257+
ret_meta.is_batch = False
258+
if hasattr(ret_meta, "__dict__"):
259+
ret.__dict__ = ret_meta.__dict__.copy()
260+
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
261+
# But we only want to split the batch if the `unbind` is along the 0th dimension.
262+
elif func == torch.Tensor.unbind:
263+
if len(args) > 1:
264+
dim = args[1]
265+
elif "dim" in kwargs:
266+
dim = kwargs["dim"]
267+
else:
268+
dim = 0
269+
if dim == 0:
270+
if metas is None:
271+
metas = decollate_batch(args[0], detach=False)
272+
if hasattr(metas[idx], "__dict__"):
273+
ret.__dict__ = metas[idx].__dict__.copy()
274+
ret.is_batch = False
275+
return ret
276+
271277
@classmethod
272278
def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
273279
"""Wraps all torch functions."""

tests/test_meta_tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,10 @@ def test_slicing(self):
413413
x.is_batch = True
414414
with self.assertRaises(ValueError):
415415
x[slice(0, 8)]
416+
x = MetaTensor(np.zeros((3, 3, 4)))
417+
x.is_batch = True
418+
self.assertEqual(x[torch.tensor([True, False, True])].shape, (2, 3, 4))
419+
self.assertEqual(x[[True, False, True]].shape, (2, 3, 4))
416420

417421
@parameterized.expand(DTYPES)
418422
@SkipIfBeforePyTorchVersion((1, 8))

0 commit comments

Comments
 (0)