diff --git a/areal/api/controller_api.py b/areal/api/controller_api.py index efe1c0e64..a8e73f266 100644 --- a/areal/api/controller_api.py +++ b/areal/api/controller_api.py @@ -84,18 +84,18 @@ def chunk_by_ffd(self, group_size: int, dp_size: int) -> list["DistributedBatch" """ raise NotImplementedError() - def union(self, other: "DistributedBatch") -> "DistributedBatch": - """Merge another batch with this one. + def union_(self, other: "DistributedBatch") -> "DistributedBatch": + """In-place merge another batch into this one. Parameters ---------- other : DistributedBatch - Another batch to merge with + Another batch to merge into this batch Returns ------- DistributedBatch - Merged batch + Merged batch (self) """ raise NotImplementedError() diff --git a/areal/controller/batch.py b/areal/controller/batch.py index c5b6f3d48..f19e14055 100644 --- a/areal/controller/batch.py +++ b/areal/controller/batch.py @@ -1,9 +1,24 @@ -from typing import Any +from __future__ import annotations + +import asyncio +import uuid +from collections import defaultdict +from enum import Enum, auto +from typing import TYPE_CHECKING, Any, ClassVar import torch from torch import Tensor from areal.api.controller_api import DistributedBatch +from areal.controller.batch_metadata import ( + BatchMetadata, + ShardId, + ShardMetadata, +) + +if TYPE_CHECKING: + from areal.controller.batch_client import BatchDataClient +from areal.utils import logging from areal.utils.batch_utils import ( convert_dict_to_list, convert_list_to_dict, @@ -13,14 +28,142 @@ from areal.utils.datapack import ffd_allocate from areal.utils.errors import FrameworkError +logger = logging.getLogger("DistributedBatchMemory") + + +class BatchStatus(Enum): + """Explicit status enum for DistributedBatchMemory. + + Attributes + ---------- + LOCAL : auto + Data stored locally in memory + REMOTE : auto + Only metadata; data fetched on-demand via HTTP + EMPTY : auto + Neither data nor metadata present (invalid/empty state) + """ + + LOCAL = auto() # Data stored locally in memory + REMOTE = auto() # Only metadata; data fetched on-demand + EMPTY = auto() # Neither present (invalid state) + class DistributedBatchMemory(DistributedBatch): - dataset = None + """Distributed batch memory with metadata-driven data access. + + This class separates metadata (data shape, location) from actual data. + The control plane only passes metadata, and actual data is fetched on-demand + via HTTP from distributed nodes. + + The class supports two statuses: + - LOCAL status: Data is stored locally in memory (dataset is not None) + - REMOTE status: Only metadata is present; data fetched on-demand via HTTP + + Use the `status` property to check the current status, and `is_local`/`is_remote` + for convenience checks. + + Attributes + ---------- + dataset : dict[str, torch.Tensor | Any] | None + The actual data (lazy-loaded, None until get_data() is called) + metadata : BatchMetadata | None + Metadata describing the distributed batch + """ + + # Shared client for fetching data (singleton pattern) + _client: ClassVar[BatchDataClient | None] = None + + def __init__( + self, + dataset: dict[str, torch.Tensor | Any] | None = None, + metadata: BatchMetadata | None = None, + ): + """Initialize a DistributedBatchMemory instance. + + Parameters + ---------- + dataset : dict[str, torch.Tensor | Any] | None + The actual data stored locally. If provided, batch is in LOCAL status. + metadata : BatchMetadata | None + Metadata describing the distributed batch. 
If provided without dataset, + batch is in REMOTE status. + """ + self.dataset = dataset + self.metadata = metadata + + @property + def status(self) -> BatchStatus: + """Get the current status of this batch (LOCAL, REMOTE, or EMPTY).""" + has_data = self.dataset is not None and len(self.dataset) > 0 + has_meta = self.metadata is not None + + if has_data: + return BatchStatus.LOCAL + if has_meta: + return BatchStatus.REMOTE + return BatchStatus.EMPTY + + @property + def is_local(self) -> bool: + """Check if data is available locally (no fetch needed).""" + return self.status == BatchStatus.LOCAL + + @property + def is_remote(self) -> bool: + """Check if this batch is in metadata-only status. + + Returns + ------- + bool + True if batch is in REMOTE status + """ + return self.status == BatchStatus.REMOTE + + def _require_status(self, *allowed_statuses: BatchStatus, operation: str) -> None: + """Assert that current status is one of the allowed status.""" + if self.status not in allowed_statuses: + raise FrameworkError( + "FrameworkError", + "BatchStatusError", + f"Operation '{operation}' requires status {[m.name for m in allowed_statuses]}, " + f"but current status is {self.status.name}", + ) + + def _require_same_status( + self, other: DistributedBatchMemory, operation: str + ) -> None: + """Assert that both batches are in the same status.""" + if self.status != other.status: + raise FrameworkError( + "FrameworkError", + "BatchStatusError", + f"Operation '{operation}' requires both batches in same status. " + f"Self is {self.status.name}, other is {other.status.name}", + ) + + @classmethod + def get_client(cls) -> BatchDataClient: + """Get or create the shared batch data client. + + Returns + ------- + BatchDataClient + Shared client instance + """ + if cls._client is None: + # Import here to avoid circular dependency + from areal.controller.batch_client import BatchDataClient + + cls._client = BatchDataClient() + return cls._client @classmethod def from_dict(cls, dict_dataset: dict[str, Tensor | Any]): """Create a DistributedBatchMemory from dictionary format dataset. + This creates a LOCAL status batch with data stored in memory. + Parameters ---------- dict_dataset : Dict[str, Union[Tensor, Any]] @@ -29,12 +172,29 @@ def from_dict(cls, dict_dataset: dict[str, Tensor | Any]): Returns ------- DistributedBatchMemory - New DistributedBatchMemory instance + New DistributedBatchMemory instance in LOCAL status """ validate_dict_dataset(dict_dataset) - instance = cls.__new__(cls) - instance.dataset = dict_dataset - return instance + return cls(dataset=dict_dataset, metadata=None) + + @classmethod + def from_metadata(cls, metadata: BatchMetadata) -> DistributedBatchMemory: + """Create a DistributedBatchMemory from metadata (without actual data). + + This creates a REMOTE status batch. The data will be fetched lazily + when get_data() is called. 
+ + Parameters + ---------- + metadata : BatchMetadata + Metadata describing the distributed batch + + Returns + ------- + DistributedBatchMemory + New DistributedBatchMemory instance in REMOTE status + """ + return cls(dataset=None, metadata=metadata) @classmethod def from_list(cls, list_dataset: list[dict[str, Tensor | Any]]): @@ -53,18 +213,36 @@ def from_list(cls, list_dataset: list[dict[str, Tensor | Any]]): dict_dataset = convert_list_to_dict(list_dataset) return cls.from_dict(dict_dataset) - def chunk(self, dp_size: int) -> list["DistributedBatchMemory"]: + def chunk(self, dp_size: int) -> list[DistributedBatchMemory]: """Split the dataset across data parallel processes. This function preserves the original order of data, ensuring that the sequence of samples in the concatenated result matches the - original dataset order.""" - if not self.dataset: - raise FrameworkError( - "FrameworkError", - "DistributedBatchMemoryError", - "Cannot split empty dataset", - ) + original dataset order. + + Supports both REMOTE status (metadata) and LOCAL status (data). + + Parameters + ---------- + dp_size : int + Number of data parallel processes + + Returns + ------- + list[DistributedBatchMemory] + List of chunked batches + + Raises + ------ + FrameworkError + If batch is in EMPTY status + """ + # REMOTE status: split shards across dp_size groups + if self.is_remote: + return self._chunk_metadata(dp_size) + + # LOCAL status: split actual data + self._require_status(BatchStatus.LOCAL, operation="chunk") total = self._get_total_size() part_size = (total + dp_size - 1) // dp_size @@ -81,14 +259,74 @@ def chunk(self, dp_size: int) -> list["DistributedBatchMemory"]: else: # For scalar values, keep as-is split_data[k] = v - batch = self.__class__.__new__(self.__class__) - batch.dataset = split_data - batches.append(batch) + batches.append(self.__class__(dataset=split_data, metadata=None)) + return batches + + def _group_shards_by_task_id( + self, shards: list[ShardMetadata] + ) -> dict[str, list[ShardMetadata]]: + """Group shards by task_id.""" + task_id_to_shards: dict[str, list[ShardMetadata]] = defaultdict(list) + for shard in shards: + task_id = shard.shard_id.task_id + task_id_to_shards[task_id].append(shard) + return task_id_to_shards + + def _chunk_metadata(self, dp_size: int) -> list[DistributedBatchMemory]: + """Split metadata across data parallel processes. + + Groups shards by task_id and distributes task groups across dp_size processes. 
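+
+        For example (illustrative), splitting five task_ids with ``dp_size=2``
+        assigns three task groups to rank 0 and two to rank 1, since each group
+        goes to the rank with the fewest groups so far; all shards of a given
+        task_id always land on the same rank.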
+ """ + + if self.metadata is None: + raise FrameworkError( + "FrameworkError", + "DistributedBatchMemoryError", + "No metadata to split", + ) + + if dp_size <= 0: + raise FrameworkError( + "FrameworkError", + "DistributedBatchMemoryError", + "dp_size must be positive", + ) + + task_id_to_shards = self._group_shards_by_task_id(self.metadata.shards) + task_groups = list(task_id_to_shards.items()) + + if not task_groups: + batches = [] + for i in range(dp_size): + new_metadata = BatchMetadata( + batch_id=f"{self.metadata.batch_id}_chunk_{i}", + shards=[], + ) + batches.append(self.__class__(dataset=None, metadata=new_metadata)) + return batches + + # Distribute task_id groups across dp_size processes + shards_per_dp_rank: list[list[ShardMetadata]] = [[] for _ in range(dp_size)] + task_id_counts_per_rank = [0] * dp_size + + for task_id, shards in task_groups: + target_rank = min(range(dp_size), key=lambda i: task_id_counts_per_rank[i]) + shards_per_dp_rank[target_rank].extend(shards) + task_id_counts_per_rank[target_rank] += 1 + + batches = [] + for i in range(dp_size): + new_metadata = BatchMetadata( + batch_id=f"{self.metadata.batch_id}_chunk_{i}", + shards=shards_per_dp_rank[i], + ) + batches.append(self.__class__(dataset=None, metadata=new_metadata)) + return batches def chunk_by_ffd( self, group_size: int, dp_size: int - ) -> list["DistributedBatchMemory"]: + ) -> list[DistributedBatchMemory]: """Split data by sequence length using First Fit Decreasing algorithm Parameters @@ -102,7 +340,20 @@ def chunk_by_ffd( ------- list[DistributedBatchMemory] List of DistributedBatchMemory objects + + Notes + ----- + For REMOTE status, this method will fall back to simple chunking + since we cannot determine sequence lengths without fetching the data. """ + # REMOTE status: fall back to simple chunking + # TODO: FFD requires sequence length information which is not available in metadata yet + if self.is_remote: + return self.chunk(dp_size) + + # LOCAL status: use FFD algorithm + self._require_status(BatchStatus.LOCAL, operation="chunk_by_ffd") + total_size = self._get_total_size() if total_size % group_size != 0: raise FrameworkError( @@ -165,42 +416,67 @@ def chunk_by_ffd( else: # For scalar values, keep as-is (they represent single sample) split_data[k] = v - batch = self.__class__.__new__(self.__class__) - batch.dataset = split_data - batches.append(batch) + batches.append(self.__class__(dataset=split_data, metadata=None)) return batches - def union(self, other: "DistributedBatchMemory") -> "DistributedBatchMemory": - """Merge another batch with this one""" - merged_data = {k: v for k, v in self.dataset.items()} + def union_(self, other: DistributedBatchMemory) -> DistributedBatchMemory: + """In-place merge. 
Mutates ``self`` and returns it.""" + self._require_same_status(other, operation="union_") + + if self.is_remote: + self._union_metadata(other) + else: + self._union_local_data(other) + + return self + + def _union_metadata(self, other: DistributedBatchMemory) -> None: + """Merge two batches in metadata status by modifying self in-place.""" + all_shards = self.metadata.shards + other.metadata.shards + self.metadata = BatchMetadata( + batch_id=str(uuid.uuid4()), + shards=all_shards, + ) + self.dataset = None + + def _union_local_data(self, other: DistributedBatchMemory) -> None: + """Merge two batches in local data status by modifying self in-place.""" for k, v in other.dataset.items(): - if k in merged_data: - if isinstance(merged_data[k], torch.Tensor) and isinstance( + if k in self.dataset: + if isinstance(self.dataset[k], torch.Tensor) and isinstance( v, torch.Tensor ): - merged_data[k] = torch.cat([merged_data[k], v], dim=0) - elif isinstance(merged_data[k], list) and isinstance(v, list): - merged_data[k] = merged_data[k] + v + self.dataset[k] = torch.cat([self.dataset[k], v], dim=0) + elif isinstance(self.dataset[k], list) and isinstance(v, list): + self.dataset[k] = self.dataset[k] + v else: # Handle mixed types or scalar values - if isinstance(merged_data[k], list): - merged_data[k].append(v) + if isinstance(self.dataset[k], list): + self.dataset[k].append(v) else: - merged_data[k] = [merged_data[k], v] + self.dataset[k] = [self.dataset[k], v] else: - merged_data[k] = v - batch = self.__class__.__new__(self.__class__) - batch.dataset = merged_data - return batch + self.dataset[k] = v + self.metadata = None def _get_total_size(self) -> int: - """Get the total size of the dataset, supporting both tensor and scalar types. + """Get the total size of the dataset.""" + if self.metadata is not None: + if not self.metadata.shards: + return 0 + task_id_to_shards = self._group_shards_by_task_id(self.metadata.shards) + + total_size = 0 + for shards in task_id_to_shards.values(): + if ( + shards + and shards[0].tensor_metadata + and shards[0].tensor_metadata.shape + ): + total_size += shards[0].tensor_metadata.shape[0] + + return total_size - Returns - ------- - int - The total size (batch size) of the dataset - """ if not self.dataset: return 0 @@ -210,65 +486,243 @@ def _get_total_size(self) -> int: elif isinstance(first_value, list): return len(first_value) else: - # For scalar values, assume it's a single sample return 1 + def _merge_shards(self, shard_data_list: list[dict[str, Any]]) -> dict[str, Any]: + """Merge shard data into a complete dataset.""" + if not shard_data_list: + return {} + + all_keys = set() + for shard_data in shard_data_list: + all_keys.update(shard_data.keys()) + + same_keys = all( + set(shard_data.keys()) == all_keys for shard_data in shard_data_list + ) + + if same_keys and "attention_mask" in all_keys: + return concat_padded_tensors(shard_data_list) + else: + return self._merge_shards_with_different_keys(shard_data_list, all_keys) + + def _merge_shards_with_different_keys( + self, + shard_data_list: list[dict[str, Any]], + all_keys: set[str], + ) -> dict[str, Any]: + """Merge shards that may have different keys. + + Handles padding for tensors with different sequence lengths. 
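+
+        For example (illustrative), ``input_ids`` tensors of shapes ``[2, 5]``
+        and ``[3, 8]`` are merged into a single ``[5, 8]`` tensor with the
+        shorter one right-padded with zeros, while 1-D tensors are simply
+        concatenated along dimension 0.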
+ """ + result = {} + + for key in sorted(all_keys): + values_to_concat = [] + + for shard_data in shard_data_list: + if key in shard_data: + values_to_concat.append(shard_data[key]) + + if not values_to_concat: + continue + + first_value = values_to_concat[0] + if first_value.ndim > 1: + max_length = max(tensor.shape[1] for tensor in values_to_concat) + need_padding = any( + tensor.shape[1] < max_length for tensor in values_to_concat + ) + + if need_padding: + pad_value = 0 if key == "attention_mask" else 0.0 + padded_tensors = [] + for tensor in values_to_concat: + if tensor.shape[1] < max_length: + pad_width = max_length - tensor.shape[1] + n_dim = tensor.ndim + pad_mode = (0,) * (2 * (n_dim - 2)) + (0, pad_width) + padded_tensors.append( + torch.nn.functional.pad( + tensor, pad_mode, value=pad_value + ) + ) + else: + padded_tensors.append(tensor) + result[key] = torch.cat(padded_tensors, dim=0) + else: + result[key] = torch.cat(values_to_concat, dim=0) + else: + result[key] = torch.cat(values_to_concat, dim=0) + + return result + def get_data(self) -> dict[str, torch.Tensor | Any]: """Get all data from the DistributedBatchMemory. + If data is stored locally, returns it directly. + If data is distributed (has metadata), fetches it from remote nodes + via HTTP and assembles the complete dataset. + Returns ------- - Dict[str, torch.Tensor] - Dictionary where keys are field names and values are 1D tensors - containing all values for that field across the entire batch. - Each tensor is formed by concatenating all individual values - for the corresponding field. + Dict[str, torch.Tensor | Any] + Dictionary where keys are field names and values are tensors or + other data types containing all values for that field across the + entire batch. """ - if not self.dataset: - return {} + if self.dataset is not None: + return self.dataset - # Get all attributes from the first sample - first_item = self[0] - attrs = list(first_item.keys()) + if self.metadata is not None: + client = self.get_client() - # Construct attr -> tensor mapping - batch_data = {} - for attr in attrs: - batch_data[attr] = self[attr] + async def _fetch_shards(): + shard_data_list = await client.fetch_shards(self.metadata) + return self._merge_shards(shard_data_list) - return batch_data + try: + self.dataset = asyncio.run(_fetch_shards()) + except RuntimeError as exc: + raise RuntimeError( + "get_data() cannot be called from within an async context when " + "fetching remote data. Please call aget_data() instead." + ) from exc + return self.dataset - @staticmethod - def concat(data: list["DistributedBatchMemory"]) -> "DistributedBatchMemory": - """Concatenate multiple DistributedBatchMemory objects + return {} - Parameters - ---------- - data : list[DistributedBatchMemory] - List of DistributedBatchMemory objects to concatenate + async def aget_data(self) -> dict[str, torch.Tensor | Any]: + """Async version of get_data(). + + This is useful when calling from async contexts to avoid blocking. Returns ------- - DistributedBatchMemory - Single concatenated DistributedBatchMemory object + Dict[str, torch.Tensor | Any] + Dictionary where keys are field names and values are tensors or + other data types. 
""" - if not data: - batch = DistributedBatchMemory.__new__(DistributedBatchMemory) - batch.dataset = {} - return batch - - merged_data = concat_padded_tensors([k.dataset for k in data]) - result = DistributedBatchMemory.__new__(DistributedBatchMemory) - result.dataset = merged_data - return result + if self.dataset is not None: + return self.get_data() + + if self.metadata is not None: + client = self.get_client() + shard_data_list = await client.fetch_shards(self.metadata) + self.dataset = self._merge_shards(shard_data_list) + return self.dataset + + return {} + + @classmethod + def concat(cls, data: list[DistributedBatchMemory]) -> DistributedBatchMemory: + """Concatenate multiple DistributedBatchMemory objects.""" + assert data is not None and len(data) != 0 + + has_metadata = [item.is_remote for item in data] + if not all(has_metadata) and any(has_metadata): + raise FrameworkError( + "FrameworkError", + "DistributedBatchMemoryError", + "Cannot concatenate batches with mixed statuses. " + "All batches must be either in REMOTE status or LOCAL status.", + ) + + if all(item.is_remote for item in data): + all_shards = [] + for item in data: + all_shards.extend(item.metadata.shards) + + return cls( + dataset=None, + metadata=BatchMetadata( + batch_id=str(uuid.uuid4()), + shards=all_shards, + ), + ) + + datasets = [k.dataset for k in data] + if not datasets: + merged_data = {} + else: + all_keys = set() + for dataset in datasets: + all_keys.update(dataset.keys()) + + same_keys = all(set(dataset.keys()) == all_keys for dataset in datasets) + if not same_keys: + key_sets = [set(dataset.keys()) for dataset in datasets] + raise FrameworkError( + "FrameworkError", + "DistributedBatchMemoryError", + f"All datasets must have the same keys. " + f"Found key sets: {[sorted(ks) for ks in key_sets]}", + ) + + merged_data = concat_padded_tensors(datasets) + + return cls(dataset=merged_data, metadata=None) + + @classmethod + async def aclear( + cls, + target: DistributedBatchMemory | list[str], + node_addrs: set[str], + ): + """Clear old batch data from distributed nodes.""" + if not node_addrs: + return + + client = cls.get_client() + + # Extract shard_ids from target + if isinstance(target, DistributedBatchMemory): + if target.metadata is None or not target.metadata.shards: + return + shard_ids = [shard.shard_id for shard in target.metadata.shards] + elif isinstance(target, list): + shard_ids = [ShardId.from_string(s) for s in target] + else: + raise TypeError( + f"target must be DistributedBatchMemory or list[str], got {type(target)}" + ) + + await client.clear_batches(node_addrs, shard_ids) + + @classmethod + def clear( + cls, + target: DistributedBatchMemory | list[str], + node_addrs: set[str], + ): + """Synchronous version of clear().""" + asyncio.run(cls.aclear(target, node_addrs)) def __getstate__(self): - return {"dataset": self.dataset} + return { + "dataset": self.dataset, + "metadata": self.metadata, + } def __setstate__(self, state): - self.dataset = state["dataset"] + self.dataset = state.get("dataset") + self.metadata = state.get("metadata") def __getitem__(self, key): + if self.is_remote: + raise FrameworkError( + "FrameworkError", + "DistributedBatchMemoryError", + "Cannot access items in REMOTE status. 
Call get_data() first to fetch data.", + ) + + if self.dataset is None: + raise FrameworkError( + "FrameworkError", + "DistributedBatchMemoryError", + "Dataset is empty.", + ) + if isinstance(key, int): return {k: v[key] for k, v in self.dataset.items()} elif isinstance(key, str): @@ -281,46 +735,66 @@ def __getitem__(self, key): ) def __setitem__(self, key, value): - """Support two assignment methods: - - str key: update entire attribute tensor - - int index: requires converting data to list format for update (less efficient, avoid if possible) - """ - if isinstance(key, str): - # Update entire attribute tensor or scalar/list value - if self.dataset: - expected_total_size = self._get_total_size() - if isinstance(value, torch.Tensor): - if value.shape[0] != expected_total_size: - raise FrameworkError( - "FrameworkError", - "DistributedBatchMemoryError", - f"The batch size of the tensor does not match. Expected {expected_total_size}, actual {value.shape[0]}", - ) - elif isinstance(value, list): - if len(value) != expected_total_size: - raise FrameworkError( - "FrameworkError", - "DistributedBatchMemoryError", - f"The batch size of the list does not match. Expected {expected_total_size}, actual {len(value)}", - ) - self.dataset[key] = value - else: + if not isinstance(key, str): raise FrameworkError( "FrameworkError", "DistributedBatchMemoryError", "key must be str" ) + if isinstance(value, DistributedBatchMemory): + # Merge using in-place union + self.union_(value) + return + + if self.metadata is not None: + raise FrameworkError( + "FrameworkError", + "DistributedBatchMemoryError", + "Cannot assign regular value to metadata-status batch. " + "Use union() with a DistributedBatchMemory object, or get_data() first.", + ) + + if self.dataset: + expected_total_size = self._get_total_size() + if isinstance(value, torch.Tensor): + if value.shape[0] != expected_total_size: + raise FrameworkError( + "FrameworkError", + "DistributedBatchMemoryError", + f"The batch size of the tensor does not match. Expected {expected_total_size}, actual {value.shape[0]}", + ) + elif isinstance(value, list): + if len(value) != expected_total_size: + raise FrameworkError( + "FrameworkError", + "DistributedBatchMemoryError", + f"The batch size of the list does not match. Expected {expected_total_size}, actual {len(value)}", + ) + + if self.dataset is None: + self.dataset = {} + self.dataset[key] = value + def __delitem__(self, key): - """Support two deletion methods: - - int index: delete sample at specified position - - str key: delete entire attribute - """ + """Delete item by int index or str key.""" + if self.is_remote: + raise FrameworkError( + "FrameworkError", + "DistributedBatchMemoryError", + "Cannot delete items in REMOTE status. 
Call get_data() first to fetch data.", + ) + + if self.dataset is None: + raise FrameworkError( + "FrameworkError", + "DistributedBatchMemoryError", + "Dataset is empty.", + ) + if isinstance(key, int): - # Convert to list format for deletion list_dataset = convert_dict_to_list(self.dataset) del list_dataset[key] self.dataset = convert_list_to_dict(list_dataset) elif isinstance(key, str): - # Delete entire attribute directly if key in self.dataset: del self.dataset[key] else: @@ -331,10 +805,21 @@ def __delitem__(self, key): ) def __str__(self): - if not self.dataset: - return "DistributedBatchMemory" - + status_name = self.status.name total_size = self._get_total_size() + + if self.status == BatchStatus.EMPTY: + return f"DistributedBatchMemory" + + if self.is_remote: + return ( + f"DistributedBatchMemory" + ) + keys = list(self.dataset.keys()) shapes = {} for k, v in self.dataset.items(): @@ -344,7 +829,10 @@ def __str__(self): shapes[k] = f"list[{len(v)}]" else: shapes[k] = f"scalar({type(v).__name__})" - return f"DistributedBatchMemory" + return ( + f"DistributedBatchMemory" + ) def __len__(self): """Return the total size.""" diff --git a/areal/controller/batch_client.py b/areal/controller/batch_client.py new file mode 100644 index 000000000..56f686b9f --- /dev/null +++ b/areal/controller/batch_client.py @@ -0,0 +1,212 @@ +"""HTTP client for distributed batch memory retrieval.""" + +import asyncio +from collections import defaultdict +from typing import Any + +import aiohttp +import orjson +import torch + +from areal.controller.batch_metadata import BatchMetadata, ShardId, ShardMetadata +from areal.scheduler.rpc.serialization import deserialize_value, serialize_value +from areal.utils import logging + +logger = logging.getLogger("BatchClient") + +# Default connection limit for batch data fetching +DEFAULT_CONNECTION_LIMIT = 1000 + + +class BatchDataClient: + """HTTP client for fetching distributed batch data.""" + + def __init__( + self, + timeout: float = 300.0, + connection_limit: int = DEFAULT_CONNECTION_LIMIT, + connect_timeout: float | None = None, + read_timeout: float | None = None, + retries: int = 2, + backoff_factor: float = 0.5, + ): + # Split timeout so we can surface slow connects vs slow reads. 
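+        # ClientTimeout(total=...) bounds the whole request, while the optional
+        # connect/sock_read fields add per-phase limits, e.g. (illustrative)
+        # BatchDataClient(timeout=300.0, connect_timeout=5.0, read_timeout=60.0).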
+ self.timeout = aiohttp.ClientTimeout( + total=timeout, connect=connect_timeout, sock_read=read_timeout + ) + self.connection_limit = connection_limit + self.retries = retries + self.backoff_factor = backoff_factor + + async def fetch_shard( + self, session: aiohttp.ClientSession, shard: ShardMetadata + ) -> dict[str, Any]: + """Fetch a shard from a node.""" + url = f"http://{shard.node_addr}/data/{shard.shard_id}" + params = {} + + last_exc: Exception | None = None + for attempt in range(self.retries + 1): + try: + async with session.get( + url, params=params, timeout=self.timeout + ) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError( + f"Failed to fetch shard {shard.shard_id} from {shard.node_addr}: " + f"HTTP {response.status} - {error_text}" + ) + + data_bytes = await response.read() + serialized_data = orjson.loads(data_bytes) + data = deserialize_value(serialized_data) + + assert isinstance(data, torch.Tensor) + result = {shard.shard_id.key: data} + + logger.info( + f"Fetched shard {shard.shard_id} from {shard.node_addr} " + f"({len(data_bytes)} bytes)" + ) + return result + + except asyncio.TimeoutError: + last_exc = RuntimeError( + f"Timeout fetching shard {shard.shard_id} from {shard.node_addr}" + ) + except Exception as e: + last_exc = RuntimeError( + f"Error fetching shard {shard.shard_id} from {shard.node_addr}: {e}" + ) + + if attempt < self.retries: + delay = self.backoff_factor * (2**attempt) + logger.warning( + f"Retrying shard {shard.shard_id} from {shard.node_addr} after attempt " + f"{attempt + 1}/{self.retries + 1} failed: {last_exc}; sleep {delay}s" + ) + await asyncio.sleep(delay) + + assert last_exc is not None + raise last_exc + + async def fetch_shards(self, metadata: BatchMetadata) -> list[dict[str, Any]]: + """Fetch all shards for a batch. + + Shards with the same task_id are grouped together into a single dict. + Returns a list of dicts, where each dict contains all data for one task_id. 
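+
+        For example (illustrative), shards ``task0:input_ids``,
+        ``task0:rewards`` and ``task1:input_ids`` yield two dicts: one holding
+        ``input_ids`` and ``rewards`` for ``task0``, and one holding
+        ``input_ids`` for ``task1``.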
+ """ + if not metadata.shards: + return [] + + connector = aiohttp.TCPConnector(limit=self.connection_limit) + async with aiohttp.ClientSession( + timeout=self.timeout, connector=connector + ) as session: + logger.info( + f"Fetching {len(metadata.shards)} shards for batch {metadata.batch_id}" + ) + tasks = [self.fetch_shard(session, shard) for shard in metadata.shards] + shard_results = await asyncio.gather(*tasks, return_exceptions=True) + + task_data_map: dict[str, dict[str, Any]] = defaultdict(dict) + failures: list[str] = [] + + for shard, shard_result in zip(metadata.shards, shard_results): + if isinstance(shard_result, Exception): + failures.append( + f"{shard.shard_id}@{shard.node_addr}: {shard_result}" + ) + continue + + task_id = shard.shard_id.task_id + task_data_map[task_id].update(shard_result) + + if failures: + raise RuntimeError( + "Failed to fetch shards: " + "; ".join(sorted(failures)) + ) + + return list(task_data_map.values()) + + async def store_shard( + self, + session: aiohttp.ClientSession, + node_addr: str, + shard_id: ShardId, + data: torch.Tensor, + ) -> None: + """Store a shard on a node.""" + url = f"http://{node_addr}/data/{shard_id}" + + serialized_data = serialize_value(data) + data_bytes = orjson.dumps(serialized_data) + + try: + async with session.put( + url, + data=data_bytes, + headers={"Content-Type": "application/octet-stream"}, + timeout=self.timeout, + ) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError( + f"Failed to store shard {shard_id} to {node_addr}: " + f"HTTP {response.status} - {error_text}" + ) + + logger.debug( + f"Stored shard {shard_id} to {node_addr} ({len(data_bytes)} bytes)" + ) + + except asyncio.TimeoutError as e: + raise RuntimeError( + f"Timeout storing shard {shard_id} to {node_addr}" + ) from e + except Exception as e: + raise RuntimeError( + f"Error storing shard {shard_id} to {node_addr}: {e}" + ) from e + + async def clear_batches( + self, node_addrs: set[str], shard_ids: list[ShardId] + ) -> None: + """Clear specific shards on multiple nodes.""" + connector = aiohttp.TCPConnector(limit=self.connection_limit) + async with aiohttp.ClientSession( + timeout=self.timeout, connector=connector + ) as session: + tasks = [ + self._clear_node(session, node_addr, shard_ids) + for node_addr in node_addrs + ] + await asyncio.gather(*tasks, return_exceptions=True) + + async def _clear_node( + self, session: aiohttp.ClientSession, node_addr: str, shard_ids: list[ShardId] + ) -> None: + """Clear specific shards on a single node.""" + url = f"http://{node_addr}/data/clear" + + shard_id_strings = [str(shard_id) for shard_id in shard_ids] + + try: + async with session.delete( + url, json={"shard_ids": shard_id_strings}, timeout=self.timeout + ) as response: + if response.status != 200: + error_text = await response.text() + logger.warning( + f"Failed to clear data on {node_addr}: " + f"HTTP {response.status} - {error_text}" + ) + else: + result = await response.json() + logger.debug( + f"Cleared {result.get('cleared_count', 0)} shards on {node_addr}" + ) + + except Exception as e: + logger.warning(f"Error clearing data on {node_addr}: {e}") diff --git a/areal/controller/batch_metadata.py b/areal/controller/batch_metadata.py new file mode 100644 index 000000000..173bda378 --- /dev/null +++ b/areal/controller/batch_metadata.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class TensorMetadata: + """Metadata for a tensor field.""" 
+ + shape: tuple[int, ...] + dtype: str + device: str = "cpu" + + def __repr__(self) -> str: + return f"TensorMetadata(shape={self.shape}, dtype={self.dtype}, device={self.device})" + + +@dataclass +class ShardId: + """Identifier for a shard, composed of task_id and key.""" + + task_id: str + key: str + + def __str__(self) -> str: + return f"{self.task_id}:{self.key}" + + def __repr__(self) -> str: + return f"ShardId(task_id={self.task_id}, key={self.key})" + + def __hash__(self) -> int: + return hash((self.task_id, self.key)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ShardId): + return False + return self.task_id == other.task_id and self.key == other.key + + @classmethod + def from_string(cls, s: str, default_key: str = "default") -> ShardId: + if ":" in s: + parts = s.split(":", 1) + return cls(task_id=parts[0], key=parts[1]) + return cls(task_id=s, key=default_key) + + +@dataclass +class ShardMetadata: + """Metadata for a single (sub-)shard stored on one node. + + A logical batch can be composed of multiple shards, and a single physical + shard can be split into multiple logical sub-shards via offset and batch_size. + """ + + node_id: str + node_addr: str + shard_id: ShardId + tensor_metadata: TensorMetadata | None = None + + def __repr__(self) -> str: + return ( + f"ShardMetadata(node_id={self.node_id}, node_addr={self.node_addr}, " + f"shard_id={self.shard_id}, tensor_metadata={self.tensor_metadata})" + ) + + +@dataclass +class BatchMetadata: + """Metadata for a distributed batch sharded across multiple nodes.""" + + batch_id: str + shards: list[ShardMetadata] = field(default_factory=list) + + def __repr__(self) -> str: + return ( + f"BatchMetadata(batch_id={self.batch_id}, num_shards={len(self.shards)}, " + f"shards={self.shards})" + ) + + def get_all_node_addrs(self) -> set[str]: + """Get all unique node addresses in this batch.""" + return {shard.node_addr for shard in self.shards} diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py index 6e214603e..195515891 100644 --- a/areal/controller/rollout_controller.py +++ b/areal/controller/rollout_controller.py @@ -2,6 +2,7 @@ import asyncio import time +import uuid from collections import defaultdict from collections.abc import Callable from dataclasses import asdict, dataclass @@ -20,7 +21,7 @@ from areal.core.staleness_manager import StalenessManager from areal.core.workflow_executor import BatchTaskDispatcher from areal.utils import logging, perf_tracer -from areal.utils.data import concat_padded_tensors, cycle_dataloader +from areal.utils.data import cycle_dataloader from areal.utils.dynamic_import import import_from_string from areal.utils.perf_tracer import trace_perf @@ -35,13 +36,13 @@ class _RemoteRolloutTaskInput: workflow: str workflow_kwargs: dict[str, Any] should_accept_fn: str - task_id: int | None = None + task_id: str | None = None @dataclass class _RemoteRolloutResult: - trajectory: dict[str, Any] - task_id: int | None = None + trajectory: dict[str, Any] | DistributedBatchMemory + task_id: str | None = None class RolloutController: @@ -294,6 +295,8 @@ async def _submit_then_wait() -> _RemoteRolloutResult | None: timeout=0.1, # A short time to prevent blocking other requests raise_timeout=False, http_timeout=self.config.request_timeout, + return_distributed_batch=True, + task_id=task_id, ) # TimeourError will be catched below @@ -352,8 +355,8 @@ def submit( workflow=workflow_str, workflow_kwargs=workflow_kwargs, should_accept_fn=should_accept_fn, - # 
NOTE: For now we don't trace tasks at the controller level - task_id=None, + # Generate a UUID for tracing task lifecycle + task_id=uuid.uuid4().hex, ) # Delegate to dispatcher @@ -361,12 +364,13 @@ def submit( def wait( self, count: int, timeout: float | None = None, raise_timeout: bool = True - ) -> list[dict[str, Any] | None]: + ) -> list[dict[str, Any] | DistributedBatchMemory | None]: # Delegate to dispatcher and extract trajectories results = self.dispatcher.wait_results(count, timeout, raise_timeout) # Log and trace if self.config.enable_rollout_tracing: logger.info("Rollout results are ready!") + return [r.trajectory if r is not None else None for r in results] @trace_perf("rollout_controller.rollout_batch", category="scheduler") @@ -376,7 +380,7 @@ def rollout_batch( workflow: RolloutWorkflow | type[RolloutWorkflow] | str, workflow_kwargs: dict[str, Any] | None = None, should_accept_fn: str | None = None, - ) -> dict[str, Any]: + ) -> DistributedBatchMemory: perf_tracer.instant( "rollout_controller.rollout_batch", category="scheduler", @@ -390,12 +394,11 @@ def rollout_batch( should_accept_fn=should_accept_fn, ) results = self.wait(count=len(data)) - # Concatenate into batch tensor format - batch = concat_padded_tensors([r for r in results if r is not None]) + batches = [b for b in results if isinstance(b, DistributedBatchMemory)] + if not batches: + return DistributedBatchMemory.from_dict({}) - # NOTE: DistributedBatchMemory.from_dict does nothing for now - # Just for sync with internal code - return DistributedBatchMemory.from_dict(batch) + return DistributedBatchMemory.concat(batches) @trace_perf("rollout_controller.prepare_batch", category="scheduler") def prepare_batch( @@ -404,7 +407,7 @@ def prepare_batch( workflow: RolloutWorkflow | type[RolloutWorkflow] | str, workflow_kwargs: dict[str, Any] | None = None, should_accept_fn: str | None = None, - ): + ) -> DistributedBatchMemory: """Prepare a batch with controlled staleness. Continuously submits from dataloader and waits for results, ensuring at least @@ -425,7 +428,7 @@ def task_input_generator(): workflow=workflow_str, workflow_kwargs=workflow_kwargs, should_accept_fn=should_accept_fn, - task_id=None, + task_id=uuid.uuid4().hex, ) if not hasattr(self, "data_generator"): @@ -437,13 +440,15 @@ def task_input_generator(): self.data_generator, batch_size=dataloader.batch_size ) - # Extract trajectories and concatenate + # Extract trajectories trajectories = [r.trajectory if r is not None else None for r in results] - batch = concat_padded_tensors([t for t in trajectories if t is not None]) - # NOTE: DistributedBatchMemory.from_dict does nothing for now - # Just for sync with internal code - return DistributedBatchMemory.from_dict(batch) + # Filter out None and only keep DistributedBatchMemory instances + batches = [t for t in trajectories if isinstance(t, DistributedBatchMemory)] + if not batches: + return DistributedBatchMemory.from_dict({}) + + return DistributedBatchMemory.concat(batches) async def agenerate(self, req: ModelRequest) -> ModelResponse: """Asynchronously generate a response for the given request. @@ -540,3 +545,11 @@ def dispatcher( def runner(self): """For backward compatibility. 
The runner is now owned by the dispatcher.""" return self.dispatcher.runner + + # ==================== DISTRIBUTED BATCH RPC WRAPPERS ==================== + def clear_batches(self, target: DistributedBatchMemory | list[str]): + """Clear shard data on workers.""" + server_addrs = { + f"{worker.ip}:{worker.worker_ports[0]}" for worker in self.workers + } + DistributedBatchMemory.clear(target, server_addrs) diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index 5a6188363..d37414ebf 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -149,7 +149,6 @@ def initialize( # Identify DP head workers self._identify_dp_heads() - logger.info("TrainController initialization complete") def _run_async_task(self, task): @@ -190,7 +189,7 @@ async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): worker_id=worker.id, method="create_process_group", parallel_strategy=self.parallel_strategy, - _should_bcast=False, + should_broadcast=False, ) for worker in self.workers ] @@ -201,7 +200,7 @@ async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): worker_id=worker.id, method="initialize", ft_spec=ft_spec, - _should_bcast=False, + should_broadcast=False, **kwargs, ) for worker in self.workers @@ -318,16 +317,9 @@ async def _call_with_dispatched_inputs( k: splits[dp_idx] for k, splits in dp_worker_kwargs.items() } - # Convert DistributedBatch to dict for RPC serialization - # TODO: Consider passing metadata instead of full tensors to reduce - # network overhead, especially for large batches - worker_args = [ - arg.get_data() if isinstance(arg, DistributedBatch) else arg - for arg in worker_args - ] + worker_args = [self._serialize_arg_for_rpc(arg) for arg in worker_args] worker_kwargs = { - k: v.get_data() if isinstance(v, DistributedBatch) else v - for k, v in worker_kwargs.items() + k: self._serialize_arg_for_rpc(v) for k, v in worker_kwargs.items() } dp_idx += 1 else: @@ -346,10 +338,29 @@ async def _call_with_dispatched_inputs( ) return await asyncio.gather(*tasks) + def _serialize_arg_for_rpc(self, arg: Any) -> Any: + """Serialize argument for RPC transmission.""" + if isinstance(arg, DistributedBatch): + # If batch has metadata and batch server is enabled, pass metadata + if hasattr(arg, "metadata") and arg.metadata is not None: + # Return a special dict that indicates metadata mode + return { + "__distributed_batch_metadata__": True, + "metadata": arg.metadata, + } + else: + # Legacy mode: get actual data + return arg.get_data() + return arg + def _merge_results(self, results, method): """Merge results from DP heads: pad tensors to max seq_len, concat dicts, return first for others.""" first_result = results[0] + # Handle DistributedBatchMemory + if isinstance(first_result, DistributedBatchMemory): + return DistributedBatchMemory.concat(results) + if isinstance(first_result, torch.Tensor): # Pad tensors to max sequence length and concatenate along batch dimension # Assumes tensor shape is [batch_size, seq_len, ...] @@ -399,7 +410,8 @@ def _align_batches_with_dp( ) -> list[DistributedBatch]: """Split batch across DP groups. 
Uses chunk_by_ffd if rebalance=True, else simple chunking.""" # Handle empty batch by replicating to all DP groups - if len(input_.get_data()) == 0: + # Use _get_total_size() to avoid fetching data + if len(input_) == 0: return [input_] * self.alloc_mode.train.dp_size if rebalance: @@ -722,3 +734,13 @@ def update_weights(self, meta: WeightUpdateMeta): self._update_weights_from_disk(meta) else: raise ValueError(f"Unknown weight update type {meta.type}") + + # ==================== DISTRIBUTED BATCH RPC WRAPPERS ==================== + def clear_batches(self, target: DistributedBatchMemory | list[str] | None): + """Clear specified shard data on workers.""" + server_addrs = { + f"{worker.ip}:{worker.worker_ports[0]}" for worker in self.workers + } + if target is None: + return + DistributedBatchMemory.clear(target, server_addrs) diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index 9dfa22469..2ccc6e2e3 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -1,12 +1,31 @@ import argparse +import asyncio +import logging as stdlib_logging import os +import socket import traceback -from concurrent.futures import Future - -from flask import Flask, jsonify, request +import uuid +from collections import defaultdict +from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor +from queue import Queue +from threading import Lock, Thread +from typing import Any + +import orjson +import torch +from flask import Flask, Response, jsonify, request +from torch import Tensor from areal.api.cli_args import BaseExperimentConfig from areal.api.engine_api import InferenceEngine, TrainEngine +from areal.controller.batch import DistributedBatchMemory +from areal.controller.batch_metadata import ( + BatchMetadata, + ShardId, + ShardMetadata, + TensorMetadata, +) from areal.platforms import current_platform from areal.scheduler.rpc.serialization import deserialize_value, serialize_value from areal.utils import logging, name_resolve, seeding @@ -21,10 +40,80 @@ # Global engine instance - must be TrainEngine or InferenceEngine _engine: TrainEngine | InferenceEngine | None = None +# Global batch data storage for distributed batch memory +# Storage: shard_id -> dict[str, Tensor] +_batch_storage: dict[ShardId, Tensor] = {} +_batch_storage_lock = Lock() +_batch_storage_stats: dict[ShardId, int] = defaultdict(int) + +# NCCL worker thread for executing non-/data/ endpoints in a single thread +# This ensures NCCL compatibility while allowing /data/ requests to be processed concurrently +_nccl_worker_thread: Thread | None = None +_nccl_work_queue: Queue[tuple[Callable, tuple, dict, Future]] | None = None +_nccl_worker_lock = Lock() + +# Server address (set at startup) +_server_host: str = "0.0.0.0" +_server_port: int = 8000 + # Create Flask app app = Flask(__name__) +def _init_nccl_worker(): + """Initialize the NCCL worker thread for executing non-/data/ endpoints.""" + global _nccl_worker_thread, _nccl_work_queue + + with _nccl_worker_lock: + if _nccl_worker_thread is not None and _nccl_worker_thread.is_alive(): + return # Already initialized + + _nccl_work_queue = Queue() + + def nccl_worker(): + """Worker thread that executes non-/data/ endpoints sequentially.""" + logger.info("NCCL worker thread started") + while True: + try: + work_item = _nccl_work_queue.get() + if work_item is None: # Shutdown signal + logger.info("NCCL worker thread shutting down") + break + + func, args, kwargs, future = work_item + try: + result = 
func(*args, **kwargs) + future.set_result(result) + except Exception as e: + future.set_exception(e) + finally: + _nccl_work_queue.task_done() + except Exception as e: + logger.error(f"Error in NCCL worker thread: {e}") + if work_item and len(work_item) > 3: + work_item[3].set_exception(e) + + _nccl_worker_thread = Thread(target=nccl_worker, daemon=True, name="NCCLWorker") + _nccl_worker_thread.start() + logger.info("NCCL worker thread initialized") + + +def _submit_to_nccl_worker(func: Callable, *args, **kwargs) -> Any: + """Submit a function to the NCCL worker thread for execution. + + This ensures all non-/data/ endpoints (which may involve NCCL operations) + run in the same thread, maintaining NCCL compatibility while allowing + /data/ requests to be processed concurrently in other threads. + """ + global _nccl_work_queue + + _init_nccl_worker() + + future = Future() + _nccl_work_queue.put((func, args, kwargs, future)) + return future.result() # Block until result is available + + @app.route("/health", methods=["GET"]) def health_check(): """Health check endpoint to verify server is alive.""" @@ -34,7 +123,10 @@ def health_check(): @app.route("/configure", methods=["POST"]) def configure(): - """Configure worker with experiment config.""" + """Configure worker with experiment config. + + This endpoint is routed to the NCCL worker thread. + """ try: data = request.get_json() if data is None: @@ -55,16 +147,17 @@ def configure(): config = deserialize_value(config) config: BaseExperimentConfig - name_resolve.reconfigure(config.cluster.name_resolve) - seeding.set_random_seed(config.seed, key=f"{role}{rank}") - - return jsonify( - { + def execute_configure(): + name_resolve.reconfigure(config.cluster.name_resolve) + seeding.set_random_seed(config.seed, key=f"{role}{rank}") + return { "status": "success", "message": "Worker configured successful.", "result": None, } - ) + + result = _submit_to_nccl_worker(execute_configure) + return jsonify(result) except Exception as e: logger.error(f"Unexpected error in configure: {e}\n{traceback.format_exc()}") return jsonify({"error": f"Internal server error: {str(e)}"}), 500 @@ -72,7 +165,10 @@ def configure(): @app.route("/set_env", methods=["POST"]) def set_env(): - """Set environment variables for the worker process.""" + """Set environment variables for the worker process. + + This endpoint is routed to the NCCL worker thread. + """ try: data = request.get_json() if data is None: @@ -84,7 +180,7 @@ def set_env(): if not isinstance(env_payload, dict): return jsonify({"error": "'env' must be a dictionary"}), 400 - for key, value in env_payload.items(): + for key in env_payload.keys(): if not isinstance(key, str): return ( jsonify( @@ -96,10 +192,15 @@ def set_env(): ), 400, ) - os.environ[key] = str(value) - logger.info(f"Set {key}={value}") - return jsonify({"status": "success"}) + def execute_set_env(): + for key, value in env_payload.items(): + os.environ[key] = str(value) + logger.info(f"Set {key}={value}") + return {"status": "success"} + + result = _submit_to_nccl_worker(execute_set_env) + return jsonify(result) except Exception as e: logger.error(f"Unexpected error in set_env: {e}\n{traceback.format_exc()}") @@ -111,6 +212,8 @@ def create_engine(): """ Create and initialize a TrainEngine or InferenceEngine instance on this worker. + This endpoint is routed to the NCCL worker thread. 
+ Expected JSON payload: { "engine": "areal.engine.ppo.actor.FSDPPPOActor", # Import path @@ -127,6 +230,7 @@ def create_engine(): global _engine try: + # Parse request in main thread (has Flask request context) data = request.get_json() if data is None: return jsonify({"error": "Invalid JSON in request body"}), 400 @@ -139,7 +243,7 @@ def create_engine(): if not engine_path: return jsonify({"error": "Missing 'engine' field in request"}), 400 - # Dynamic import + # Dynamic import (can be done in main thread) try: engine_class = import_from_string(engine_path) @@ -163,10 +267,21 @@ def create_engine(): logger.error(f"Invalid engine type: {e}") return jsonify({"error": str(e)}), 400 - # Instantiate engine + # Instantiate engine in NCCL worker thread (may involve NCCL initialization) + def create_engine_in_nccl_thread(): + """Create engine in NCCL worker thread.""" + try: + engine = engine_class(*init_args, **init_kwargs) + logger.info(f"Engine '{engine_path}' instantiated successfully") + return engine + except Exception as e: + logger.error( + f"Failed to instantiate engine: {e}\n{traceback.format_exc()}" + ) + raise + try: - _engine = engine_class(*init_args, **init_kwargs) - logger.info(f"Engine '{engine_path}' instantiated successfully") + _engine = _submit_to_nccl_worker(create_engine_in_nccl_thread) return jsonify( { "status": "success", @@ -175,7 +290,6 @@ def create_engine(): } ) except Exception as e: - logger.error(f"Failed to instantiate engine: {e}\n{traceback.format_exc()}") return jsonify({"error": f"Failed to instantiate engine: {str(e)}"}), 500 except Exception as e: @@ -190,6 +304,9 @@ def call_engine_method(): """ Call a method on the engine instance. + This endpoint is routed to the NCCL worker thread to ensure + all NCCL operations run in the same thread, preventing conflicts. 
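+
+    Reserved kwargs popped before dispatch: ``should_broadcast`` (default
+    True), ``return_distributed_batch`` (default False), ``result_key`` and
+    ``task_id``; they control tensor broadcasting and whether the result is
+    wrapped as DistributedBatchMemory metadata.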
+ Expected JSON payload: { "method": "train_batch", @@ -217,72 +334,553 @@ def call_engine_method(): if not method_name: return jsonify({"error": "Missing 'method' field in request"}), 400 - # Deserialize args and kwargs (convert SerializedTensor dicts to tensors) args = deserialize_value(args) kwargs = deserialize_value(kwargs) + should_broadcast = kwargs.pop("should_broadcast", True) + should_return_distributed_batch = kwargs.pop("return_distributed_batch", False) + result_key = kwargs.pop("result_key", None) + task_id = kwargs.pop("task_id", None) + + # Extract input batch metadata before resolving + input_metadata_list = [] + input_metadata_list.extend(_extract_input_batch_metadata(args)) + input_metadata_list.extend(_extract_input_batch_metadata(kwargs)) + assert len(input_metadata_list) <= 1, ( + "Only one input batch metadata is supported" + ) + input_batch_metadata = input_metadata_list[0] if input_metadata_list else None try: - should_bcast = kwargs.pop("_should_bcast", True) - if should_bcast and isinstance(_engine, TrainEngine): - logger.info(f"Broadcasting data for TrainEngine method: {method_name}") - - args = tensor_container_to(args, current_platform.current_device()) - args = broadcast_tensor_container( - args, - src_rank=_engine.current_data_parallel_head(), - group=_engine.context_and_model_parallel_group, - ) - kwargs = tensor_container_to(kwargs, current_platform.current_device()) - kwargs = broadcast_tensor_container( - kwargs, - src_rank=_engine.current_data_parallel_head(), - group=_engine.context_and_model_parallel_group, - ) - logger.info("Broadcasting data done.") + logger.info( + f"Resolving batch metadata for method '{method_name}', args: {args}, kwargs: {kwargs}" + ) + args = _resolve_batch_metadata(args) + kwargs = _resolve_batch_metadata(kwargs) except Exception as e: logger.error( - f"Broadcasting data for method '{method_name}' failed: {e}\n{traceback.format_exc()}" + f"Resolving batch metadata for method '{method_name}' failed: {e}\n{traceback.format_exc()}" ) return ( - jsonify({"error": f"Data broadcast '{method_name}' failed: {str(e)}"}), + jsonify( + {"error": f"Metadata resolution '{method_name}' failed: {str(e)}"} + ), 500, ) - # Call method directly - logger.info(f"Calling engine method: {method_name}") + def execute_in_nccl_thread(): + try: + if should_broadcast and isinstance(_engine, TrainEngine): + logger.info( + f"Broadcasting data for TrainEngine method: {method_name}" + ) + + args_bcast = tensor_container_to( + args, current_platform.current_device() + ) + args_bcast = broadcast_tensor_container( + args_bcast, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + kwargs_bcast = tensor_container_to( + kwargs, current_platform.current_device() + ) + kwargs_bcast = broadcast_tensor_container( + kwargs_bcast, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + logger.info("Broadcasting data done.") + else: + args_bcast = args + kwargs_bcast = kwargs + + logger.info(f"Calling engine method: {method_name}") + method = getattr(_engine, method_name) + result = method(*args_bcast, **kwargs_bcast) + + # Handle update weights future + if isinstance(result, Future): + logger.info("Waiting for update weights future") + result = result.result() + logger.info("Update weights future done") + + if should_return_distributed_batch: + result = _handle_distributed_batch_return( + result, + result_key, + task_id=task_id, + 
input_batch_metadata=input_batch_metadata, + ) + logger.debug("Handling distributed batch memory return") + + return result + except AttributeError as e: + logger.error(f"Method '{method_name}' not found on engine: {e}") + raise ValueError(f"Engine does not have method '{method_name}'") + except Exception as e: + logger.error( + f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + raise + + # Submit to NCCL worker thread try: - # Get the method - will raise AttributeError if it doesn't exist - method = getattr(_engine, method_name) - result = method(*args, **kwargs) - - # HACK: handle update weights future - if isinstance(result, Future): - logger.info("Waiting for update weights future") - result = result.result() - logger.info("Update weights future done") - - # Serialize result (convert tensors to SerializedTensor dicts) - serialized_result = serialize_value(result) - return jsonify({"status": "success", "result": serialized_result}) - - except AttributeError as e: - logger.error(f"Method '{method_name}' not found on engine: {e}") + result = _submit_to_nccl_worker(execute_in_nccl_thread) + except Exception as e: + error_msg = str(e) + if "Engine does not have method" in error_msg: + return ( + jsonify({"error": error_msg}), + 400, + ) return ( - jsonify({"error": f"Engine does not have method '{method_name}'"}), - 400, + jsonify( + {"error": f"Engine method '{method_name}' failed: {error_msg}"} + ), + 500, + ) + + serialized_result = serialize_value(result) + return jsonify({"status": "success", "result": serialized_result}) + + except Exception as e: + logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + +# ==================== Batch Related Functions ==================== +def _ensure_tensor_on_cpu(value: Any) -> Any: + """Recursively ensure all tensors in a value are on CPU.""" + if isinstance(value, torch.Tensor): + # Move to CPU if on GPU, detach to avoid gradient tracking + if value.is_cuda: + return value.detach().cpu() + return value.detach() if value.requires_grad else value + elif isinstance(value, dict): + return {k: _ensure_tensor_on_cpu(v) for k, v in value.items()} + elif isinstance(value, (list, tuple)): + converted = [_ensure_tensor_on_cpu(item) for item in value] + return type(value)(converted) + else: + return value + + +def _is_batch_metadata(data: Any) -> bool: + """Check if data is a DistributedBatchMemory metadata wrapper.""" + return isinstance(data, dict) and bool(data.get("__distributed_batch_metadata__")) + + +def _get_batch_metadata(data: Any) -> BatchMetadata | None: + """Extract BatchMetadata from a DistributedBatchMemory metadata wrapper.""" + if _is_batch_metadata(data): + return data.get("metadata") + return None + + +def _extract_input_batch_metadata(data: Any) -> list[BatchMetadata]: + """Extract all DistributedBatchMemory metadata from input data.""" + metadata_list = [] + + if isinstance(data, dict): + metadata = _get_batch_metadata(data) + if metadata is not None: + metadata_list.append(metadata) + else: + # Recursively check dict values + for v in data.values(): + metadata_list.extend(_extract_input_batch_metadata(v)) + elif isinstance(data, (list, tuple)): + # Recursively check list/tuple elements + for item in data: + metadata_list.extend(_extract_input_batch_metadata(item)) + + return metadata_list + + +async def _aresolve_batch_metadata(data: Any) -> Any: + """Resolve DistributedBatchMemory metadata to actual data. 
+ + Recursively traverses data structures and replaces DistributedBatchMemory + metadata with actual data fetched from remote nodes. + """ + if isinstance(data, dict): + metadata = _get_batch_metadata(data) + if metadata is not None: + batch = DistributedBatchMemory.from_metadata(metadata) + return await batch.aget_data() + else: + # Recursively resolve dict values + resolved_items = await asyncio.gather( + *[_aresolve_batch_metadata(v) for v in data.values()] ) + return {k: v for k, v in zip(data.keys(), resolved_items)} + elif isinstance(data, (list, tuple)): + # Recursively resolve list/tuple elements + resolved_items = await asyncio.gather( + *[_aresolve_batch_metadata(item) for item in data] + ) + return type(data)(resolved_items) + else: + return data + + +def _resolve_batch_metadata(data: Any) -> Any: + """Resolve DistributedBatch metadata to actual data.""" + + def run_in_thread(): + """Run async resolution in a new thread with a new event loop.""" + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete(_aresolve_batch_metadata(data)) except Exception as e: logger.error( - f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" + f"Error resolving batch metadata in thread: {e}\n{traceback.format_exc()}" + ) + raise + finally: + new_loop.close() + + with ThreadPoolExecutor() as executor: + future = executor.submit(run_in_thread) + return future.result() + + +def _create_matched_batch_metadata( + data_to_store: dict[str, Tensor], + input_batch_metadata: BatchMetadata, + node_id: str, + node_addr: str, +) -> DistributedBatchMemory: + """Create batch metadata matching input structure with tensor splitting. + + This function creates shards matching the input metadata structure: + 1. Groups input shards by task_id + 2. For each task_id, creates output shards for each key in data_to_store + 3. 
Splits tensors along dimension 0 if sizes don't match + + Parameters + ---------- + data_to_store : dict[str, Tensor] + Result data to split and store (keys become shard_id.key) + input_batch_metadata : BatchMetadata + Input batch metadata to match + node_id : str + Current node identifier + node_addr : str + Current node address + + Returns + ------- + DistributedBatchMemory + Batch with metadata matching input structure + """ + global _batch_storage, _batch_storage_lock, _batch_storage_stats + + # Group input shards by task_id and calculate offsets + task_id_to_shards: dict[str, list[ShardMetadata]] = {} + task_id_to_size: dict[str, int] = {} + task_id_offsets: dict[str, int] = {} + + current_offset = 0 + for shard in input_batch_metadata.shards: + task_id = shard.shard_id.task_id + if task_id not in task_id_to_shards: + task_id_to_shards[task_id] = [] + task_id_offsets[task_id] = current_offset + task_id_to_size[task_id] = 0 + + task_id_to_shards[task_id].append(shard) + shard_size = shard.tensor_metadata.shape[0] + # Only count the first shard of each task_id for offset calculation + if len(task_id_to_shards[task_id]) == 1: + task_id_to_size[task_id] = shard_size + current_offset += shard_size + + # Get result keys (these will become the new shard_id.key values) + result_keys = list(data_to_store.keys()) + if not result_keys: + raise ValueError("data_to_store must contain at least one tensor") + + # Create output shards + output_shards = [] + + for task_id, input_shards in task_id_to_shards.items(): + total_size = task_id_to_size[task_id] + offset = task_id_offsets[task_id] + + # For each result key, create a shard with the same task_id + for result_key in result_keys: + tensor = data_to_store[result_key] + + sliced_tensor = tensor[offset : offset + total_size].clone().cpu() + shard_id = ShardId(task_id=task_id, key=result_key) + with _batch_storage_lock: + _batch_storage[shard_id] = sliced_tensor + serialized_data = serialize_value(sliced_tensor) + data_bytes = orjson.dumps(serialized_data) + _batch_storage_stats[shard_id] = len(data_bytes) + + logger.debug( + f"Stored shard {shard_id} (size={len(data_bytes)} bytes, " + f"batch_size={total_size}, node_addr={node_addr})" + ) + + tensor_metadata = TensorMetadata( + shape=tuple(sliced_tensor.shape), + dtype=str(sliced_tensor.dtype), + device=str(sliced_tensor.device), + ) + + output_shard = ShardMetadata( + node_id=node_id, + node_addr=node_addr, + shard_id=shard_id, + tensor_metadata=tensor_metadata, + ) + output_shards.append(output_shard) + + batch_metadata = BatchMetadata( + batch_id=str(uuid.uuid4()), + shards=output_shards, + ) + + batch = DistributedBatchMemory.from_metadata(batch_metadata) + logger.info( + f"Created DistributedBatchMemory matching input: {batch_metadata.batch_id}, " + f"num_shards={len(output_shards)}, shard_ids={[s.shard_id for s in output_shards]}" + ) + + return batch + + +def _handle_distributed_batch_return( + result: Any, + result_key: str | None, + task_id: str | None = None, + input_batch_metadata: BatchMetadata | None = None, +) -> Any: + """Handle distributed batch memory return. 
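+    Invoked by the ``/call`` handler when the caller passes ``return_distributed_batch=True``.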
+
+    When the return value is one of the following types, it is automatically
+    written to the local `_batch_storage` and the corresponding
+    `DistributedBatchMemory` metadata (or a list of it) is returned instead:
+
+    - ``torch.Tensor``
+    - ``dict[str, torch.Tensor]``
+    - ``list[dict[str, torch.Tensor]]``
+    """
+    global _batch_storage, _batch_storage_lock, _batch_storage_stats
+
+    # Handle list: recursively process each element
+    if isinstance(result, list):
+        return [
+            _handle_distributed_batch_return(
+                r,
+                result_key,
+                task_id=task_id,
+                input_batch_metadata=input_batch_metadata,
+            )
+            for r in result
+        ]
+
+    # Check if result is Tensor or dict[str, Tensor]
+    data_to_store = None
+    if isinstance(result, torch.Tensor):
+        if result_key is None:
+            result_key = "default_key"
+        data_to_store = {result_key: result}
+    elif isinstance(result, dict) and all(
+        isinstance(v, torch.Tensor) for v in result.values()
+    ):
+        data_to_store = result
+
+    if data_to_store is None:
+        return result
+
+    # Get node info
+    node_id = os.environ.get("HOSTNAME", "unknown")
+    rank = int(os.environ.get("RANK", "0"))
+    node_id = f"{node_id}_rank{rank}"
+
+    # Get node address
+    global _server_host, _server_port
+    node_addr = f"{_server_host}:{_server_port}"
+
+    # If input_batch_metadata is not provided, create a fake one
+    if input_batch_metadata is None or not input_batch_metadata.shards:
+        task_id = task_id or str(uuid.uuid4())
+
+        first_tensor = next(
+            (v for v in data_to_store.values() if isinstance(v, torch.Tensor)), None
+        )
+        if first_tensor is None:
+            return result
+
+        batch_size = first_tensor.shape[0] if len(first_tensor.shape) > 0 else 1
+        fake_shard_id = ShardId(task_id=task_id, key="__dummy_input__")
+        fake_tensor_metadata = TensorMetadata(
+            shape=(batch_size,),
+            dtype=str(first_tensor.dtype),
+            device=str(first_tensor.device),
+        )
+        fake_shard = ShardMetadata(
+            node_id=node_id,
+            node_addr=node_addr,
+            shard_id=fake_shard_id,
+            tensor_metadata=fake_tensor_metadata,
+        )
+
+        input_batch_metadata = BatchMetadata(
+            batch_id=str(uuid.uuid4()),
+            shards=[fake_shard],
+        )
+
+    return _create_matched_batch_metadata(
+        data_to_store,
+        input_batch_metadata,
+        node_id,
+        node_addr,
+    )
+
+
+# ==================== Batch Data Storage Endpoints ====================
+@app.route("/data/<shard_id>", methods=["PUT"])
+def store_batch_data(shard_id: str):
+    """Store batch data shard."""
+    global _batch_storage, _batch_storage_lock, _batch_storage_stats
+
+    try:
+        # Convert string shard_id to ShardId
+        shard_id_obj = ShardId.from_string(shard_id)
+
+        data_bytes = request.get_data()
+        # Deserialize from orjson, then deserialize_value to restore tensors
+        serialized_data = orjson.loads(data_bytes)
+        data = deserialize_value(serialized_data)
+
+        cpu_data = _ensure_tensor_on_cpu(data)
+
+        with _batch_storage_lock:
+            _batch_storage[shard_id_obj] = cpu_data
+            _batch_storage_stats[shard_id_obj] = len(data_bytes)
+
+        logger.debug(
+            f"Stored batch shard {shard_id_obj} (size={len(data_bytes)} bytes)"
+        )
+        # Echo back the original shard_id string to avoid adding default keys
+        return jsonify({"status": "ok", "shard_id": shard_id})
+
+    except Exception as e:
+        # Log the raw string: shard_id_obj may be unbound if parsing failed
+        logger.error(f"Error storing batch shard {shard_id}: {e}")
+        return jsonify({"status": "error", "message": str(e)}), 500
+
+
+@app.route("/data/<shard_id>", methods=["GET"])
+def retrieve_batch_data(shard_id: str):
+    """Retrieve batch data shard."""
+    global _batch_storage, _batch_storage_lock
+
+    # Convert string shard_id to ShardId
+    shard_id_obj = ShardId.from_string(shard_id)
+
+    logger.debug(f"Received data get request for
shard {shard_id_obj}") + try: + with _batch_storage_lock: + if shard_id_obj not in _batch_storage: + return ( + jsonify( + { + "status": "error", + "message": f"Shard {shard_id_obj} not found", + } + ), + 404, + ) + + data = _batch_storage[shard_id_obj] + + serialized_data = serialize_value(data) + data_bytes = orjson.dumps(serialized_data) + + logger.info( + f"Retrieved batch shard {shard_id_obj} (size={len(data_bytes)} bytes)" + ) + return Response(data_bytes, mimetype="application/octet-stream") + + except Exception as e: + logger.error(f"Error retrieving batch shard {shard_id_obj}: {e}") + return jsonify({"status": "error", "message": str(e)}), 500 + + +@app.route("/data/clear", methods=["DELETE"]) +def clear_batch_data(): + """Clear specified batch data shards. + + Expected JSON payload: + { + "shard_ids": ["id1", "id2", ...] + } + """ + global _batch_storage, _batch_storage_lock, _batch_storage_stats + + try: + data = request.get_json(silent=True) or {} + shard_ids = data.get("shard_ids", []) + if not isinstance(shard_ids, list): return ( - jsonify({"error": f"Engine method '{method_name}' failed: {str(e)}"}), - 500, + jsonify({"status": "error", "message": "'shard_ids' must be a list"}), + 400, ) + shard_id_objs = [ + ShardId.from_string(sid) for sid in shard_ids if isinstance(sid, str) + ] + + if not shard_id_objs: + return jsonify({"status": "ok", "cleared_count": 0}) + + with _batch_storage_lock: + cleared_count = 0 + for shard_id_obj in shard_id_objs: + if shard_id_obj in _batch_storage: + del _batch_storage[shard_id_obj] + _batch_storage_stats.pop(shard_id_obj, None) + cleared_count += 1 + + logger.info( + f"Cleared {cleared_count} batch shards: {[str(sid) for sid in shard_id_objs]}" + ) + return jsonify({"status": "ok", "cleared_count": cleared_count}) except Exception as e: - logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + logger.error(f"Error clearing batch data: {e}") + return jsonify({"status": "error", "message": str(e)}), 500 + + +@app.route("/data/stats", methods=["GET"]) +def batch_data_stats(): + """Get batch data storage statistics.""" + global _batch_storage, _batch_storage_lock, _batch_storage_stats + + try: + with _batch_storage_lock: + total_shards = len(_batch_storage) + total_size = sum(_batch_storage_stats.values()) + + return jsonify( + { + "status": "ok", + "total_shards": total_shards, + "total_size_bytes": total_size, + } + ) + except Exception as e: + logger.error(f"Error getting batch data stats: {e}") + return jsonify({"status": "error", "message": str(e)}), 500 + + +# ==================== Cleanup ==================== def cleanup_engine(): @@ -297,6 +895,33 @@ def cleanup_engine(): _engine = None +def cleanup_batch_storage(): + """Clean up batch storage on shutdown.""" + global _batch_storage, _batch_storage_lock, _batch_storage_stats + with _batch_storage_lock: + _batch_storage.clear() + _batch_storage_stats.clear() + logger.info("Batch storage cleared") + + +def cleanup_nccl_worker(): + """Clean up NCCL worker thread.""" + global _nccl_worker_thread, _nccl_work_queue + + with _nccl_worker_lock: + if _nccl_work_queue is not None: + # Send shutdown signal + _nccl_work_queue.put(None) + _nccl_work_queue = None + + if _nccl_worker_thread is not None: + _nccl_worker_thread.join(timeout=5.0) + if _nccl_worker_thread.is_alive(): + logger.warning("NCCL worker thread did not shut down gracefully") + _nccl_worker_thread = None + logger.info("NCCL worker thread 
cleaned up") + + def main(): """Main entry point for the sync RPC server.""" parser = argparse.ArgumentParser( @@ -317,21 +942,28 @@ def main(): args, _ = parser.parse_known_args() # Configure Werkzeug logging - import logging as stdlib_logging - werkzeug_logger = stdlib_logging.getLogger("werkzeug") werkzeug_logger.setLevel(getattr(stdlib_logging, args.werkzeug_log_level)) + # Set global server address variables + global _server_host, _server_port + _server_host = args.host + if _server_host == "0.0.0.0": + _server_host = socket.gethostbyname(socket.gethostname()) + _server_port = args.port + logger.info(f"Starting sync RPC server on {args.host}:{args.port}") logger.info(f"Werkzeug log level: {args.werkzeug_log_level}") - # Run Flask app with single-threaded synchronous mode - # threaded=False ensures NCCL compatibility + # Run Flask app with multi-threaded mode + # /data/ endpoints are processed in request threads (concurrent) + # /call and other non-/data/ endpoints are routed to NCCL worker thread + # This ensures NCCL compatibility while allowing /data/ requests to be processed concurrently try: app.run( host=args.host, port=args.port, - threaded=False, # Single-threaded synchronous execution + threaded=True, # Multi-threaded for concurrent /data/ request handling processes=1, # Single process debug=False, use_reloader=False, @@ -339,7 +971,9 @@ def main(): except KeyboardInterrupt: logger.info("Shutting down sync RPC server") finally: + cleanup_nccl_worker() cleanup_engine() + cleanup_batch_storage() if __name__ == "__main__": diff --git a/areal/scheduler/rpc/serialization.py b/areal/scheduler/rpc/serialization.py index 939f0091f..39354e025 100644 --- a/areal/scheduler/rpc/serialization.py +++ b/areal/scheduler/rpc/serialization.py @@ -26,6 +26,7 @@ import torch from pydantic import BaseModel, Field +from areal.controller.batch import DistributedBatchMemory from areal.utils import logging TOKENIZER_ARCHIVE_INLINE_THRESHOLD = 512 * 1024 @@ -404,6 +405,17 @@ def serialize_value(value: Any) -> Any: if isinstance(value, np.ndarray): return SerializedNDArray.from_array(value).model_dump() + # Handle DistributedBatchMemory (check before dataclass) + if isinstance(value, DistributedBatchMemory): + # Use __getstate__ to get serializable state + state = value.__getstate__() + # Recursively serialize the state + serialized_state = serialize_value(state) + return { + "type": "distributed_batch_memory", + "state": serialized_state, + } + # Handle dataclass instances (check before dict, as dataclasses can be dict-like) # Note: is_dataclass returns True for both classes and instances, so check it's not a type if is_dataclass(value) and not isinstance(value, type): @@ -468,8 +480,17 @@ def deserialize_value(value: Any) -> Any: if value is None: return None - # Handle dict - check if it's a SerializedDataclass or SerializedTensor + # Handle dict - check if it's a SerializedDataclass, SerializedTensor, or DistributedBatchMemory if isinstance(value, dict): + # Check for DistributedBatchMemory marker + if value.get("type") == "distributed_batch_memory": + # Deserialize the state + state = deserialize_value(value.get("state", {})) + # Create instance and restore state + instance = DistributedBatchMemory.__new__(DistributedBatchMemory) + instance.__setstate__(state) + return instance + # Check for SerializedDataclass marker (check before tensor) if value.get("type") == "dataclass": try: diff --git a/areal/tests/test_batch.py b/areal/tests/test_batch.py index 561dc5782..f1253089c 100644 --- 
a/areal/tests/test_batch.py +++ b/areal/tests/test_batch.py @@ -4,6 +4,12 @@ import torch from areal.controller.batch import DistributedBatchMemory +from areal.controller.batch_metadata import ( + BatchMetadata, + ShardId, + ShardMetadata, + TensorMetadata, +) from areal.utils.batch_utils import ( convert_dict_to_list, convert_list_to_dict, @@ -145,7 +151,7 @@ def test_union(self): batch1 = DistributedBatchMemory.from_dict(data1) batch2 = DistributedBatchMemory.from_dict(data2) - merged = batch1.union(batch2) + merged = batch1.union_(batch2) assert len(merged) == 2 assert torch.equal(merged.dataset["input_ids"], torch.tensor([[1, 2], [3, 4]])) @@ -248,7 +254,7 @@ def test_str_repr(self, test_data): memory = DistributedBatchMemory.from_dict(test_data) str_repr = str(memory) assert "DistributedBatchMemory" in str_repr - assert "total_size=2" in str_repr + assert "size=2" in str_repr def test_empty_dataset(self): """Test handling of empty datasets.""" @@ -432,7 +438,7 @@ def test_union_with_prompt_task_type(self, prompt_task_data): batch1 = DistributedBatchMemory.from_list(data1) batch2 = DistributedBatchMemory.from_list(data2) - merged = batch1.union(batch2) + merged = batch1.union_(batch2) assert len(merged) == 12 assert merged.dataset["prompt"] == [item["prompt"] for item in prompt_task_data] @@ -535,7 +541,7 @@ def test_str_repr_with_prompt_task_type(self, prompt_task_data): str_repr = str(memory) assert "DistributedBatchMemory" in str_repr - assert "total_size=12" in str_repr + assert "size=12" in str_repr assert "prompt" in str_repr assert "task_type" in str_repr @@ -756,7 +762,7 @@ def test_union_scalar(self): batch1 = DistributedBatchMemory.from_dict(data1) batch2 = DistributedBatchMemory.from_dict(data2) - merged = batch1.union(batch2) + merged = batch1.union_(batch2) assert len(merged) == 4 assert merged.dataset["labels"] == [1, 2, 3, 4] assert merged.dataset["scores"] == [0.1, 0.2, 0.3, 0.4] @@ -777,7 +783,7 @@ def test_union_mixed(self): batch1 = DistributedBatchMemory.from_dict(data1) batch2 = DistributedBatchMemory.from_dict(data2) - merged = batch1.union(batch2) + merged = batch1.union_(batch2) assert len(merged) == 4 assert merged.dataset["labels"] == [1, 0, 0, 1] @@ -905,7 +911,7 @@ def test_str_repr(self, test_data): memory = DistributedBatchMemory.from_dict(test_data) str_repr = str(memory) assert "DistributedBatchMemory" in str_repr - assert "total_size=2" in str_repr + assert "size=2" in str_repr def test_empty_dataset(self): """Test handling of empty datasets.""" @@ -1012,7 +1018,7 @@ def test_2d_list_operations(self): batch1_union = DistributedBatchMemory.from_dict(data1) batch2_union = DistributedBatchMemory.from_dict(data2) - merged = batch1_union.union(batch2_union) + merged = batch1_union.union_(batch2_union) assert len(merged) == 4 assert merged["sequences"] == [[1, 2], [3, 4], [5, 6], [7, 8]] assert merged["labels"] == [1, 0, 1, 0] @@ -1235,7 +1241,7 @@ def test_union_with_rollout_res(self, rollout_res_data): batch1 = DistributedBatchMemory.from_dict(data1) batch2 = DistributedBatchMemory.from_dict(data2) - merged = batch1.union(batch2) + merged = batch1.union_(batch2) assert len(merged) == 512 for key in rollout_res_data.keys(): @@ -1295,7 +1301,7 @@ def test_str_repr_with_rollout_res(self, rollout_res_data): str_repr = str(memory) assert "DistributedBatchMemory" in str_repr - assert "total_size=512" in str_repr + assert "size=512" in str_repr assert "attention_mask" in str_repr assert "input_ids" in str_repr @@ -1377,3 +1383,735 @@ def 
test_single_scalar_value(self): assert len(memory) == 1 assert memory["single_label"] == 42 assert memory["single_score"] == 0.99 + + +# ============================================================================= +# TestDistributedBatchMemoryMetadata +# ============================================================================= + + +class TestDistributedBatchMemoryMetadata: + """Test metadata-based functionality.""" + + def test_from_metadata(self): + """Test creating batch from metadata.""" + metadata = BatchMetadata( + batch_id="test-batch", + shards=[ + ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(10, 5), + dtype="torch.int64", + ), + ), + ], + ) + batch = DistributedBatchMemory.from_metadata(metadata) + + assert batch.dataset is None + assert batch.metadata == metadata + + def test_concat_with_metadata(self): + """Test concatenating batches with metadata.""" + metadata1 = BatchMetadata( + batch_id="batch-1", + shards=[ + ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(5, 10), + dtype="torch.int64", + ), + ), + ], + ) + metadata2 = BatchMetadata( + batch_id="batch-2", + shards=[ + ShardMetadata( + node_id="node-1", + node_addr="localhost:8766", + shard_id=ShardId(task_id="shard-1", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(3, 10), + dtype="torch.int64", + ), + ), + ], + ) + + batch1 = DistributedBatchMemory.from_metadata(metadata1) + batch2 = DistributedBatchMemory.from_metadata(metadata2) + result = DistributedBatchMemory.concat([batch1, batch2]) + + assert result.metadata is not None + assert len(result) == 8 # 5 + 3 + assert len(result.metadata.shards) == 2 + + def test_serialization_metadata(self): + """Test serialization and deserialization with metadata.""" + metadata = BatchMetadata( + batch_id="test", + shards=[], + ) + batch = DistributedBatchMemory.from_metadata(metadata) + + # Serialize + serialized = pickle.dumps(batch) + # Deserialize + deserialized = pickle.loads(serialized) + + assert deserialized.metadata.batch_id == "test" + assert deserialized.metadata is not None + assert deserialized.dataset is None + + def test_chunk_metadata(self): + """Test chunking with metadata mode.""" + metadata = BatchMetadata( + batch_id="test-batch", + shards=[ + ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(50, 128), dtype="torch.int64" + ), + ), + ShardMetadata( + node_id="node-1", + node_addr="localhost:8766", + shard_id=ShardId(task_id="shard-1", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(50, 128), dtype="torch.int64" + ), + ), + ], + ) + batch = DistributedBatchMemory.from_metadata(metadata) + + # Chunk into 2 groups + chunks = batch.chunk(2) + + assert len(chunks) == 2 + assert all(chunk.metadata is not None for chunk in chunks) + assert all(chunk.dataset is None for chunk in chunks) + # Total size should be preserved + total_size = sum(len(chunk) for chunk in chunks) + assert total_size == 100 + + def test_union_metadata(self): + """Test union with metadata mode.""" + metadata1 = BatchMetadata( + batch_id="batch-1", + shards=[ + ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(30, 
10), + dtype="torch.int64", + ), + ), + ], + ) + metadata2 = BatchMetadata( + batch_id="batch-2", + shards=[ + ShardMetadata( + node_id="node-1", + node_addr="localhost:8766", + shard_id=ShardId(task_id="shard-1", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(20, 10), + dtype="torch.int64", + ), + ), + ], + ) + + batch1 = DistributedBatchMemory.from_metadata(metadata1) + batch2 = DistributedBatchMemory.from_metadata(metadata2) + result = batch1.union_(batch2) + + assert result.metadata is not None + assert len(result) == 50 # 30 + 20 + assert len(result.metadata.shards) == 2 + assert result.dataset is None + + def test_get_total_size_metadata(self): + """Test _get_total_size with metadata mode.""" + metadata = BatchMetadata( + batch_id="test", + shards=[ + ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(123, 10), + dtype="torch.int64", + ), + ), + ], + ) + batch = DistributedBatchMemory.from_metadata(metadata) + + assert len(batch) == 123 + assert batch._get_total_size() == 123 + + +# ============================================================================= +# TestBatchMetadata +# ============================================================================= + + +class TestBatchMetadata: + """Test metadata structures.""" + + def test_tensor_metadata(self): + """Test TensorMetadata creation.""" + meta = TensorMetadata( + shape=(32, 128), + dtype="torch.float32", + device="cuda:0", + ) + assert meta.shape == (32, 128) + assert meta.dtype == "torch.float32" + assert meta.device == "cuda:0" + + def test_shard_metadata(self): + """Test ShardMetadata creation.""" + meta = ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(32, 128), + dtype="torch.int64", + ), + ) + assert meta.node_id == "node-0" + # Batch size can be inferred from first field's shape[0] + assert meta.tensor_metadata is not None + assert meta.tensor_metadata.shape[0] == 32 + assert meta.shard_id.key == "input_ids" + + def test_batch_metadata_node_addrs(self): + """Test getting all node addresses from batch metadata.""" + metadata = BatchMetadata( + batch_id="test", + shards=[ + ShardMetadata( + node_id="node-0", + node_addr="192.168.1.10:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(32, 10), + dtype="torch.int64", + ), + ), + ShardMetadata( + node_id="node-1", + node_addr="192.168.1.11:8765", + shard_id=ShardId(task_id="shard-1", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(32, 10), + dtype="torch.int64", + ), + ), + ], + ) + addrs = metadata.get_all_node_addrs() + assert len(addrs) == 2 + assert "192.168.1.10:8765" in addrs + assert "192.168.1.11:8765" in addrs + + +# ============================================================================= +# TestRPCDistributedBatchReturn +# ============================================================================= + + +class TestRPCDistributedBatchReturn: + """Test RPC automatic distributed batch return functionality.""" + + def test_handle_tensor_return(self): + """Test handling tensor return from engine method.""" + from areal.scheduler.rpc.rpc_server import _handle_distributed_batch_return + + # Test tensor return + tensor_result = torch.randn(10, 5) + batch = _handle_distributed_batch_return( + tensor_result, + result_key="logits", + ) + + # Should return 
DistributedBatchMemory with metadata
+        assert isinstance(batch, DistributedBatchMemory)
+        assert batch.metadata is not None
+        assert len(batch) == 10
+        assert len(batch.metadata.shards) == 1
+        assert batch.metadata.shards[0].tensor_metadata is not None
+        assert batch.metadata.shards[0].shard_id.key == "logits"
+
+    def test_handle_dict_return(self):
+        """Test handling dict return from engine method."""
+        from areal.scheduler.rpc.rpc_server import _handle_distributed_batch_return
+
+        # Test dict return (all values must be tensors to trigger batch handling)
+        dict_result = {
+            "logits": torch.randn(8, 10, 50),
+            "values": torch.randn(8, 10),  # second tensor field
+        }
+        batch = _handle_distributed_batch_return(
+            dict_result,
+            result_key=None,
+        )
+
+        # Should return DistributedBatchMemory with metadata
+        assert isinstance(batch, DistributedBatchMemory)
+        assert batch.metadata is not None
+        assert len(batch) == 8
+        assert batch.metadata.shards[0].tensor_metadata is not None
+        assert batch.metadata.shards[0].shard_id.key == "logits"
+        # Note: With single tensor_metadata per shard, we can only check one field
+        # The shard_id.key should match one of the tensor keys
+        assert batch.metadata.shards[0].shard_id.key in ["logits", "values"]
+
+    def test_handle_non_tensor_return(self):
+        """Test that non-tensor returns are passed through."""
+        from areal.scheduler.rpc.rpc_server import _handle_distributed_batch_return
+
+        # Test non-tensor returns
+        int_result = 42
+        result = _handle_distributed_batch_return(int_result, None)
+        assert result == 42
+
+        str_result = "hello"
+        result = _handle_distributed_batch_return(str_result, None)
+        assert result == "hello"
+
+        dict_result = {"loss": 0.5, "accuracy": 0.9}  # no tensors
+        result = _handle_distributed_batch_return(dict_result, None)
+        assert result == dict_result
+
+    def test_handle_distributed_batch_with_input_metadata(self):
+        """Test _handle_distributed_batch_return with input_batch_metadata.
+ + Input shards: + - shard_id: A, key: input_ids -> tensor + - shard_id: A, key: attention_mask -> tensor + - shard_id: B, key: input_ids -> tensor + - shard_id: B, key: attention_mask -> tensor + - shard_id: C, key: input_ids -> tensor + - shard_id: C, key: attention_mask -> tensor + + Expected output shards: + - shard_id: A, key: new_key1 -> tensor + - shard_id: A, key: new_key2 -> tensor + - shard_id: B, key: new_key1 -> tensor + - shard_id: B, key: new_key2 -> tensor + - shard_id: C, key: new_key1 -> tensor + - shard_id: C, key: new_key2 -> tensor + """ + from areal.scheduler.rpc.rpc_server import _handle_distributed_batch_return + + # Create input batch metadata with shards grouped by task_id + input_shards = [ + ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="A", key="input_ids"), + tensor_metadata=TensorMetadata(shape=(10, 128), dtype="torch.int64"), + ), + ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="A", key="attention_mask"), + tensor_metadata=TensorMetadata(shape=(10, 128), dtype="torch.int64"), + ), + ShardMetadata( + node_id="node-1", + node_addr="localhost:8766", + shard_id=ShardId(task_id="B", key="input_ids"), + tensor_metadata=TensorMetadata(shape=(20, 128), dtype="torch.int64"), + ), + ShardMetadata( + node_id="node-1", + node_addr="localhost:8766", + shard_id=ShardId(task_id="B", key="attention_mask"), + tensor_metadata=TensorMetadata(shape=(20, 128), dtype="torch.int64"), + ), + ShardMetadata( + node_id="node-2", + node_addr="localhost:8767", + shard_id=ShardId(task_id="C", key="input_ids"), + tensor_metadata=TensorMetadata(shape=(15, 128), dtype="torch.int64"), + ), + ShardMetadata( + node_id="node-2", + node_addr="localhost:8767", + shard_id=ShardId(task_id="C", key="attention_mask"), + tensor_metadata=TensorMetadata(shape=(15, 128), dtype="torch.int64"), + ), + ] + input_batch_metadata = BatchMetadata( + batch_id="input-batch", + shards=input_shards, + ) + + # Create result with new keys + # Total batch size: 10 + 20 + 15 = 45 + result = { + "new_key1": torch.randn(45, 256), # batch_size=45 + "new_key2": torch.randn(45, 128), # batch_size=45 + } + + batch = _handle_distributed_batch_return( + result, + result_key=None, + input_batch_metadata=input_batch_metadata, + ) + + # Verify output structure + assert isinstance(batch, DistributedBatchMemory) + assert batch.metadata is not None + assert len(batch.metadata.shards) == 6 # 3 task_ids * 2 keys + + # Group output shards by task_id + output_by_task = {} + for shard in batch.metadata.shards: + task_id = shard.shard_id.task_id + if task_id not in output_by_task: + output_by_task[task_id] = [] + output_by_task[task_id].append(shard) + + # Verify each task_id has 2 shards (one for each result key) + assert len(output_by_task["A"]) == 2 + assert len(output_by_task["B"]) == 2 + assert len(output_by_task["C"]) == 2 + + # Verify shard keys match result keys + for task_id, shards in output_by_task.items(): + shard_keys = {s.shard_id.key for s in shards} + assert shard_keys == {"new_key1", "new_key2"} + + # Verify tensor shapes + # Task A: batch_size=10 + a_new_key1 = next( + s for s in output_by_task["A"] if s.shard_id.key == "new_key1" + ) + assert a_new_key1.tensor_metadata.shape == (10, 256) + a_new_key2 = next( + s for s in output_by_task["A"] if s.shard_id.key == "new_key2" + ) + assert a_new_key2.tensor_metadata.shape == (10, 128) + + # Task B: batch_size=20 + b_new_key1 = next( + s for s in output_by_task["B"] if s.shard_id.key 
== "new_key1" + ) + assert b_new_key1.tensor_metadata.shape == (20, 256) + b_new_key2 = next( + s for s in output_by_task["B"] if s.shard_id.key == "new_key2" + ) + assert b_new_key2.tensor_metadata.shape == (20, 128) + + # Task C: batch_size=15 + c_new_key1 = next( + s for s in output_by_task["C"] if s.shard_id.key == "new_key1" + ) + assert c_new_key1.tensor_metadata.shape == (15, 256) + c_new_key2 = next( + s for s in output_by_task["C"] if s.shard_id.key == "new_key2" + ) + assert c_new_key2.tensor_metadata.shape == (15, 128) + + def test_handle_distributed_batch_without_input_metadata(self): + """Test _handle_distributed_batch_return without input_batch_metadata. + + Expected behavior: + - Generate shards from result keys + - shard_id.task_id from input task_id or generate uuid + """ + from areal.scheduler.rpc.rpc_server import _handle_distributed_batch_return + + # Test with provided task_id + result = { + "logits": torch.randn(10, 50), + "values": torch.randn(10, 1), + } + task_id = "test-task-123" + + batch = _handle_distributed_batch_return( + result, + result_key=None, + task_id=task_id, + input_batch_metadata=None, + ) + + assert isinstance(batch, DistributedBatchMemory) + assert batch.metadata is not None + assert len(batch.metadata.shards) == 2 # One for each key + + # Verify shard_ids + shard_keys = {s.shard_id.key for s in batch.metadata.shards} + assert shard_keys == {"logits", "values"} + + # Verify all shards have the same task_id + task_ids = {s.shard_id.task_id for s in batch.metadata.shards} + assert task_ids == {task_id} + + # Verify tensor shapes + logits_shard = next( + s for s in batch.metadata.shards if s.shard_id.key == "logits" + ) + assert logits_shard.tensor_metadata.shape == (10, 50) + values_shard = next( + s for s in batch.metadata.shards if s.shard_id.key == "values" + ) + assert values_shard.tensor_metadata.shape == (10, 1) + + # Test without task_id (should generate uuid) + batch2 = _handle_distributed_batch_return( + result, + result_key=None, + task_id=None, + input_batch_metadata=None, + ) + + assert isinstance(batch2, DistributedBatchMemory) + assert batch2.metadata is not None + assert len(batch2.metadata.shards) == 2 + + # Verify all shards have the same generated task_id + task_ids2 = {s.shard_id.task_id for s in batch2.metadata.shards} + assert len(task_ids2) == 1 # All shards should have same task_id + assert task_ids2 != {task_id} # Should be different from previous task_id + + +# ============================================================================= +# TestDistributedBatchMemoryExtended +# ============================================================================= + + +class TestDistributedBatchMemoryExtended: + """Extended tests for DistributedBatchMemory covering all methods.""" + + def test_get_client(self): + """Test getting or creating the shared client.""" + client1 = DistributedBatchMemory.get_client() + client2 = DistributedBatchMemory.get_client() + + assert client1 is client2 # Should be the same instance + assert client1 is not None + + def test_len_metadata(self): + """Test __len__ with metadata.""" + metadata = BatchMetadata( + batch_id="test", + shards=[ + ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(42, 10), + dtype="torch.int64", + ), + ), + ], + ) + batch = DistributedBatchMemory.from_metadata(metadata) + + assert len(batch) == 42 + + def test_str_metadata(self): + """Test __str__ with metadata.""" + 
metadata = BatchMetadata( + batch_id="test", + shards=[ + ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(10, 10), + dtype="torch.int64", + ), + ), + ], + ) + batch = DistributedBatchMemory.from_metadata(metadata) + + s = str(batch) + assert "DistributedBatchMemory" in s + assert "test" in s + + def test_get_total_size_tensor(self): + """Test _get_total_size with tensor.""" + data = {"input_ids": torch.tensor([[1, 2], [3, 4], [5, 6]])} + batch = DistributedBatchMemory.from_dict(data) + + assert batch._get_total_size() == 3 + + def test_get_total_size_list(self): + """Test _get_total_size with list.""" + data = {"labels": [0, 1, 2, 3]} + batch = DistributedBatchMemory.from_dict(data) + + assert batch._get_total_size() == 4 + + def test_get_total_size_scalar(self): + """Test _get_total_size with scalar.""" + data = {"value": 42} + batch = DistributedBatchMemory.from_dict(data) + + assert batch._get_total_size() == 1 + + def test_get_total_size_empty(self): + """Test _get_total_size with empty dataset.""" + batch = DistributedBatchMemory.from_dict({}) + assert batch._get_total_size() == 0 + + def test_chunk_metadata_empty(self): + """Test chunking metadata batch with no metadata raises error.""" + batch = DistributedBatchMemory.__new__(DistributedBatchMemory) + batch.dataset = None + batch.metadata = None + + with pytest.raises(Exception): # FrameworkError + batch.chunk(2) + + def test_chunk_by_ffd_metadata_fallback(self): + """Test chunk_by_ffd falls back to chunk in metadata mode.""" + metadata = BatchMetadata( + batch_id="test", + shards=[ + ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(10, 10), + dtype="torch.int64", + ), + ), + ], + ) + batch = DistributedBatchMemory.from_metadata(metadata) + + chunks = batch.chunk_by_ffd(group_size=2, dp_size=2) + assert len(chunks) == 2 + assert all(chunk.metadata is not None for chunk in chunks) + + def test_union_mixed_mode_error(self): + """Test union raises error for mixed modes.""" + batch1 = DistributedBatchMemory.from_dict({"input_ids": torch.tensor([[1, 2]])}) + metadata = BatchMetadata( + batch_id="test", + shards=[ + ShardMetadata( + node_id="node-0", + node_addr="localhost:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(1, 2), + dtype="torch.int64", + ), + ), + ], + ) + batch2 = DistributedBatchMemory.from_metadata(metadata) + + with pytest.raises(Exception): # FrameworkError + batch1.union_(batch2) + + def test_concat_empty_list_error(self): + """Test concat with empty list raises error.""" + with pytest.raises(AssertionError): + DistributedBatchMemory.concat([]) + + def test_concat_different_keys_error(self): + """Test concat with batches having different keys raises error.""" + batch1 = DistributedBatchMemory.from_dict({"input_ids": torch.tensor([[1, 2]])}) + batch2 = DistributedBatchMemory.from_dict({"labels": torch.tensor([0])}) + + with pytest.raises(Exception): # FrameworkError + DistributedBatchMemory.concat([batch1, batch2]) + + def test_concat_mixed_modes_error(self): + """Test concat with mixed modes (one metadata, one local) raises error.""" + batch1 = DistributedBatchMemory.from_dict({"input_ids": torch.tensor([[1, 2]])}) + metadata = BatchMetadata( + batch_id="test", + shards=[ + ShardMetadata( + node_id="node-0", + 
node_addr="localhost:8765", + shard_id=ShardId(task_id="shard-0", key="input_ids"), + tensor_metadata=TensorMetadata( + shape=(1, 2), + dtype="torch.int64", + ), + ), + ], + ) + batch2 = DistributedBatchMemory.from_metadata(metadata) + + # concat should only work with all metadata or all local + with pytest.raises( + FrameworkError, match="Cannot concatenate batches with mixed statuses" + ): + DistributedBatchMemory.concat([batch1, batch2]) + + def test_chunk_preserves_order(self): + """Test that chunking preserves sample order.""" + data = { + "input_ids": torch.tensor([[i, i + 1] for i in range(8)]), + } + batch = DistributedBatchMemory.from_dict(data) + + chunks = batch.chunk(2) + assert len(chunks) == 2 + assert len(chunks[0]) == 4 + assert len(chunks[1]) == 4 + + # Verify order is preserved + chunk0_data = chunks[0].get_data() + chunk1_data = chunks[1].get_data() + assert chunk0_data["input_ids"][0, 0] == 0 + assert chunk1_data["input_ids"][0, 0] == 4 + + def test_union_preserves_all_keys(self): + """Test that union preserves all keys from both batches.""" + batch1 = DistributedBatchMemory.from_dict( + { + "input_ids": torch.tensor([[1, 2]]), + "key1": torch.tensor([0]), + } + ) + batch2 = DistributedBatchMemory.from_dict( + { + "input_ids": torch.tensor([[3, 4]]), + "key2": torch.tensor([1]), + } + ) + + result = batch1.union_(batch2) + assert "input_ids" in result.dataset + assert "key1" in result.dataset + assert "key2" in result.dataset diff --git a/examples/single-controller/gsm8k_grpo.py b/examples/single-controller/gsm8k_grpo.py index e5517ae8a..79bc9458c 100644 --- a/examples/single-controller/gsm8k_grpo.py +++ b/examples/single-controller/gsm8k_grpo.py @@ -164,24 +164,32 @@ def main(args): if config.actor.recompute_logprob or config.actor.use_decoupled_loss: with stats_tracker.record_timing("recompute_logp"): - logp = actor.compute_logp(batch) - batch["prox_logp"] = logp + prox_logp = actor.compute_logp( + batch, + return_distributed_batch=True, + result_key="prox_logp", + ) + batch["prox_logp"] = prox_logp log_gpu_stats("recompute logp") if ref is not None: with stats_tracker.record_timing("ref_logp"): - batch["ref_logp"] = ref.compute_logp(batch) + ref_logp = ref.compute_logp( + batch, + return_distributed_batch=True, + result_key="ref_logp", + ) + batch["ref_logp"] = ref_logp log_gpu_stats("ref logp") with stats_tracker.record_timing("compute_advantage"): - batch = actor.compute_advantages(batch) + batch = actor.compute_advantages(batch, return_distributed_batch=True) log_gpu_stats("compute advantages") with stats_tracker.record_timing("train_step"): actor.ppo_update(batch) actor.step_lr_scheduler() log_gpu_stats("ppo update") - # pause inference for updating weights, save, and evaluation rollout.pause() @@ -205,6 +213,9 @@ def main(args): tokenizer=tokenizer, ) + with stats_tracker.record_timing("clear_batches"): + actor.clear_batches(batch) + # Upload statistics to the logger (e.g., wandb) stats_logger.commit(epoch, step, global_step, actor.export_stats()) diff --git a/examples/single-controller/gsm8k_sft.py b/examples/single-controller/gsm8k_sft.py index bd8dca63c..d16163869 100644 --- a/examples/single-controller/gsm8k_sft.py +++ b/examples/single-controller/gsm8k_sft.py @@ -138,6 +138,10 @@ def evaluate_fn(): evaluator.evaluate(evaluate_fn, epoch, step, global_step) stats_logger.commit(epoch, step, global_step, engine.export_stats()) + + with stats_tracker.record_timing("clear_batches"): + engine.clear_batches(global_step) + global_step += 1 finally: