
Commit 9714c82

Author: daihao
Commit message: fix
1 parent 0670c88 · commit 9714c82

File tree: 1 file changed (+20 −33 lines)


areal/controller/batch.py

Lines changed: 20 additions & 33 deletions
@@ -454,16 +454,12 @@ def union(self, other: DistributedBatchMemory) -> DistributedBatchMemory:
         """
         # Both are in local data mode
         if self.dataset is not None and other.dataset is not None:
-            merged = self._union_local_data(other)
-            self.dataset = merged.dataset
-            self.metadata = merged.metadata
+            self._union_local_data(other)
             return self

         # Both are in metadata mode
         if self.metadata is not None and other.metadata is not None:
-            merged = self._union_metadata(other)
-            self.dataset = merged.dataset
-            self.metadata = merged.metadata
+            self._union_metadata(other)
             return self

         # Mixed mode: not supported
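For context, a minimal usage sketch of the behavior after this change: union() now merges the right-hand batch into self in place and returns the same object, instead of copying fields out of an intermediate batch. The setup below is an assumption made for illustration only; it builds the batches via __new__ the way the removed code did internally (the public constructor is not part of this diff), and the field name "input_ids" is likewise hypothetical.

    import torch

    # Hypothetical setup: two local-data batches, constructed the way the removed
    # code did internally (via __new__).
    a = DistributedBatchMemory.__new__(DistributedBatchMemory)
    a.dataset = {"input_ids": torch.zeros(2, 4, dtype=torch.long)}
    a.metadata = None

    b = DistributedBatchMemory.__new__(DistributedBatchMemory)
    b.dataset = {"input_ids": torch.ones(3, 4, dtype=torch.long)}
    b.metadata = None

    merged = a.union(b)
    assert merged is a                                 # union() now mutates and returns self
    assert merged.dataset["input_ids"].shape[0] == 5   # tensors are concatenated along dim 0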
@@ -473,53 +469,44 @@ def union(self, other: DistributedBatchMemory) -> DistributedBatchMemory:
             "Cannot union batches in different modes (metadata vs local data)",
         )

-    def _union_metadata(self, other: DistributedBatchMemory) -> DistributedBatchMemory:
-        """Merge two batches in metadata mode."""
+    def _union_metadata(self, other: DistributedBatchMemory) -> None:
+        """Merge two batches in metadata mode by modifying self in-place."""
         # Combine shards from both batches
         all_shards = self.metadata.shards + other.metadata.shards
         max_global_step = max(self.metadata.global_step, other.metadata.global_step)

         # Calculate total_batch_size (validates different fields have same total)
         _, total_batch_size = self._group_shards_by_keys(all_shards)

-        # Create new metadata
-        new_metadata = BatchMetadata(
+        # Update self.metadata directly
+        self.metadata = BatchMetadata(
             batch_id=str(uuid.uuid4()),
             global_step=max_global_step,
             total_batch_size=total_batch_size,
             shards=all_shards,
         )
+        self.dataset = None

-        batch = self.__class__.__new__(self.__class__)
-        batch.dataset = None
-        batch.metadata = new_metadata
-        return batch
-
-    def _union_local_data(
-        self, other: DistributedBatchMemory
-    ) -> DistributedBatchMemory:
-        """Merge two batches in local data mode."""
-        merged_data = {k: v for k, v in self.dataset.items()}
+    def _union_local_data(self, other: DistributedBatchMemory) -> None:
+        """Merge two batches in local data mode by modifying self in-place."""
+        # Merge data directly into self.dataset
         for k, v in other.dataset.items():
-            if k in merged_data:
-                if isinstance(merged_data[k], torch.Tensor) and isinstance(
+            if k in self.dataset:
+                if isinstance(self.dataset[k], torch.Tensor) and isinstance(
                     v, torch.Tensor
                 ):
-                    merged_data[k] = torch.cat([merged_data[k], v], dim=0)
-                elif isinstance(merged_data[k], list) and isinstance(v, list):
-                    merged_data[k] = merged_data[k] + v
+                    self.dataset[k] = torch.cat([self.dataset[k], v], dim=0)
+                elif isinstance(self.dataset[k], list) and isinstance(v, list):
+                    self.dataset[k] = self.dataset[k] + v
                 else:
                     # Handle mixed types or scalar values
-                    if isinstance(merged_data[k], list):
-                        merged_data[k].append(v)
+                    if isinstance(self.dataset[k], list):
+                        self.dataset[k].append(v)
                     else:
-                        merged_data[k] = [merged_data[k], v]
+                        self.dataset[k] = [self.dataset[k], v]
             else:
-                merged_data[k] = v
-        batch = self.__class__.__new__(self.__class__)
-        batch.dataset = merged_data
-        batch.metadata = None
-        return batch
+                self.dataset[k] = v
+        self.metadata = None

     def _get_total_size(self) -> int:
         """Get the total size of the dataset, supporting both tensor and scalar types.
