@@ -202,7 +202,7 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
202202 :py:func:`MetaTensor._copy_meta`).
203203 """
204204 out = []
205- metas = None
205+ metas = None # optional output metadicts for each of the return value in `rets`
206206 is_batch = any (x .is_batch for x in MetaObj .flatten_meta_objs (args , kwargs .values ()) if hasattr (x , "is_batch" ))
207207 for idx , ret in enumerate (rets ):
208208 # if not `MetaTensor`, nothing to do.
@@ -219,55 +219,61 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
219219 # the following is not implemented but the network arch may run into this case:
220220 # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
221221 # raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
222-
223- # If we have a batch of data, then we need to be careful if a slice of
224- # the data is returned. Depending on how the data are indexed, we return
225- # some or all of the metadata, and the return object may or may not be a
226- # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
227222 if is_batch :
228- # if indexing e.g., `batch[0]`
229- if func == torch .Tensor .__getitem__ :
230- batch_idx = args [1 ]
231- if isinstance (batch_idx , Sequence ):
232- batch_idx = batch_idx [0 ]
233- # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
234- # first element will be `slice(None, None, None)` and `Ellipsis`,
235- # respectively. Don't need to do anything with the metadata.
236- if batch_idx not in (slice (None , None , None ), Ellipsis , None ) and idx == 0 :
237- ret_meta = decollate_batch (args [0 ], detach = False )[batch_idx ]
238- if isinstance (ret_meta , list ) and ret_meta : # e.g. batch[0:2], re-collate
239- try :
240- ret_meta = list_data_collate (ret_meta )
241- except (TypeError , ValueError , RuntimeError , IndexError ) as e :
242- raise ValueError (
243- "Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
244- "please convert it into a torch Tensor using `x.as_tensor()` or "
245- "a numpy array using `x.array`."
246- ) from e
247- elif isinstance (ret_meta , MetaObj ): # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
248- ret_meta .is_batch = False
249- if hasattr (ret_meta , "__dict__" ):
250- ret .__dict__ = ret_meta .__dict__ .copy ()
251- # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
252- # But we only want to split the batch if the `unbind` is along the 0th
253- # dimension.
254- elif func == torch .Tensor .unbind :
255- if len (args ) > 1 :
256- dim = args [1 ]
257- elif "dim" in kwargs :
258- dim = kwargs ["dim" ]
259- else :
260- dim = 0
261- if dim == 0 :
262- if metas is None :
263- metas = decollate_batch (args [0 ], detach = False )
264- ret .__dict__ = metas [idx ].__dict__ .copy ()
265- ret .is_batch = False
266-
223+ ret = MetaTensor ._handle_batched (ret , idx , metas , func , args , kwargs )
267224 out .append (ret )
268225 # if the input was a tuple, then return it as a tuple
269226 return tuple (out ) if isinstance (rets , tuple ) else out
270227
228+ @classmethod
229+ def _handle_batched (cls , ret , idx , metas , func , args , kwargs ):
230+ """utility function to handle batched MetaTensors."""
231+ # If we have a batch of data, then we need to be careful if a slice of
232+ # the data is returned. Depending on how the data are indexed, we return
233+ # some or all of the metadata, and the return object may or may not be a
234+ # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
235+ # if indexing e.g., `batch[0]`
236+ if func == torch .Tensor .__getitem__ :
237+ if idx > 0 or len (args ) < 2 or len (args [0 ]) < 1 :
238+ return ret
239+ batch_idx = args [1 ][0 ] if isinstance (args [1 ], Sequence ) else args [1 ]
240+ # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
241+ # first element will be `slice(None, None, None)` and `Ellipsis`,
242+ # respectively. Don't need to do anything with the metadata.
243+ if batch_idx in (slice (None , None , None ), Ellipsis , None ) or isinstance (batch_idx , torch .Tensor ):
244+ return ret
245+ dec_batch = decollate_batch (args [0 ], detach = False )
246+ ret_meta = dec_batch [batch_idx ]
247+ if isinstance (ret_meta , list ) and ret_meta : # e.g. batch[0:2], re-collate
248+ try :
249+ ret_meta = list_data_collate (ret_meta )
250+ except (TypeError , ValueError , RuntimeError , IndexError ) as e :
251+ raise ValueError (
252+ "Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
253+ "please consider converting it into a torch Tensor using `x.as_tensor()` or "
254+ "a numpy array using `x.array`."
255+ ) from e
256+ elif isinstance (ret_meta , MetaObj ): # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
257+ ret_meta .is_batch = False
258+ if hasattr (ret_meta , "__dict__" ):
259+ ret .__dict__ = ret_meta .__dict__ .copy ()
260+ # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
261+ # But we only want to split the batch if the `unbind` is along the 0th dimension.
262+ elif func == torch .Tensor .unbind :
263+ if len (args ) > 1 :
264+ dim = args [1 ]
265+ elif "dim" in kwargs :
266+ dim = kwargs ["dim" ]
267+ else :
268+ dim = 0
269+ if dim == 0 :
270+ if metas is None :
271+ metas = decollate_batch (args [0 ], detach = False )
272+ if hasattr (metas [idx ], "__dict__" ):
273+ ret .__dict__ = metas [idx ].__dict__ .copy ()
274+ ret .is_batch = False
275+ return ret
276+
271277 @classmethod
272278 def __torch_function__ (cls , func , types , args = (), kwargs = None ) -> Any :
273279 """Wraps all torch functions."""
0 commit comments