@@ -454,16 +454,12 @@ def union(self, other: DistributedBatchMemory) -> DistributedBatchMemory:
454454 """
455455 # Both are in local data mode
456456 if self .dataset is not None and other .dataset is not None :
457- merged = self ._union_local_data (other )
458- self .dataset = merged .dataset
459- self .metadata = merged .metadata
457+ self ._union_local_data (other )
460458 return self
461459
462460 # Both are in metadata mode
463461 if self .metadata is not None and other .metadata is not None :
464- merged = self ._union_metadata (other )
465- self .dataset = merged .dataset
466- self .metadata = merged .metadata
462+ self ._union_metadata (other )
467463 return self
468464
469465 # Mixed mode: not supported
@@ -473,53 +469,44 @@ def union(self, other: DistributedBatchMemory) -> DistributedBatchMemory:
473469 "Cannot union batches in different modes (metadata vs local data)" ,
474470 )
475471
476- def _union_metadata (self , other : DistributedBatchMemory ) -> DistributedBatchMemory :
477- """Merge two batches in metadata mode."""
472+ def _union_metadata (self , other : DistributedBatchMemory ) -> None :
473+ """Merge two batches in metadata mode by modifying self in-place ."""
478474 # Combine shards from both batches
479475 all_shards = self .metadata .shards + other .metadata .shards
480476 max_global_step = max (self .metadata .global_step , other .metadata .global_step )
481477
482478 # Calculate total_batch_size (validates different fields have same total)
483479 _ , total_batch_size = self ._group_shards_by_keys (all_shards )
484480
485- # Create new metadata
486- new_metadata = BatchMetadata (
481+ # Update self. metadata directly
482+ self . metadata = BatchMetadata (
487483 batch_id = str (uuid .uuid4 ()),
488484 global_step = max_global_step ,
489485 total_batch_size = total_batch_size ,
490486 shards = all_shards ,
491487 )
488+ self .dataset = None
492489
493- batch = self .__class__ .__new__ (self .__class__ )
494- batch .dataset = None
495- batch .metadata = new_metadata
496- return batch
497-
498- def _union_local_data (
499- self , other : DistributedBatchMemory
500- ) -> DistributedBatchMemory :
501- """Merge two batches in local data mode."""
502- merged_data = {k : v for k , v in self .dataset .items ()}
490+ def _union_local_data (self , other : DistributedBatchMemory ) -> None :
491+ """Merge two batches in local data mode by modifying self in-place."""
492+ # Merge data directly into self.dataset
503493 for k , v in other .dataset .items ():
504- if k in merged_data :
505- if isinstance (merged_data [k ], torch .Tensor ) and isinstance (
494+ if k in self . dataset :
495+ if isinstance (self . dataset [k ], torch .Tensor ) and isinstance (
506496 v , torch .Tensor
507497 ):
508- merged_data [k ] = torch .cat ([merged_data [k ], v ], dim = 0 )
509- elif isinstance (merged_data [k ], list ) and isinstance (v , list ):
510- merged_data [k ] = merged_data [k ] + v
498+ self . dataset [k ] = torch .cat ([self . dataset [k ], v ], dim = 0 )
499+ elif isinstance (self . dataset [k ], list ) and isinstance (v , list ):
500+ self . dataset [k ] = self . dataset [k ] + v
511501 else :
512502 # Handle mixed types or scalar values
513- if isinstance (merged_data [k ], list ):
514- merged_data [k ].append (v )
503+ if isinstance (self . dataset [k ], list ):
504+ self . dataset [k ].append (v )
515505 else :
516- merged_data [k ] = [merged_data [k ], v ]
506+ self . dataset [k ] = [self . dataset [k ], v ]
517507 else :
518- merged_data [k ] = v
519- batch = self .__class__ .__new__ (self .__class__ )
520- batch .dataset = merged_data
521- batch .metadata = None
522- return batch
508+ self .dataset [k ] = v
509+ self .metadata = None
523510
524511 def _get_total_size (self ) -> int :
525512 """Get the total size of the dataset, supporting both tensor and scalar types.