832 changes: 734 additions & 98 deletions areal/controller/batch.py

Large diffs are not rendered by default.

168 changes: 168 additions & 0 deletions areal/controller/batch_client.py
@@ -0,0 +1,168 @@
"""HTTP client for distributed batch memory retrieval."""

import asyncio
import io
import pickle
from typing import Any

import aiohttp

from areal.controller.batch_metadata import BatchMetadata, ShardMetadata
from areal.utils import logging

logger = logging.getLogger("BatchClient")

# Default connection limit for batch data fetching
DEFAULT_CONNECTION_LIMIT = 100


class BatchDataClient:
"""HTTP client for fetching distributed batch data."""

def __init__(
self, timeout: float = 300.0, connection_limit: int = DEFAULT_CONNECTION_LIMIT
):
"""Initialize the batch data client.

Parameters
----------
timeout : float
Request timeout in seconds
connection_limit : int
Maximum number of concurrent connections
"""
self.timeout = aiohttp.ClientTimeout(total=timeout)
self.connection_limit = connection_limit

async def fetch_shard(
self, session: aiohttp.ClientSession, shard: ShardMetadata
) -> dict[str, Any]:
"""Fetch a logical shard (sub-range) from a physical shard."""
url = f"http://{shard.node_addr}/data/{shard.shard_id}"
params = {}
if shard.offset > 0:
params["offset"] = shard.offset
if shard.batch_size > 0:
params["batch_size"] = shard.batch_size

try:
async with session.get(
url, params=params, timeout=self.timeout
) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(
f"Failed to fetch shard {shard.shard_id} from {shard.node_addr}: "
f"HTTP {response.status} - {error_text}"
)

data_bytes = await response.read()
buffer = io.BytesIO(data_bytes)
data = pickle.load(buffer)

logger.debug(
f"Fetched logical shard {shard.shard_id} from {shard.node_addr} "
f"(offset={shard.offset}, batch_size={shard.batch_size}, "
f"{len(data_bytes)} bytes)"
)
return data

except asyncio.TimeoutError as e:
raise RuntimeError(
f"Timeout fetching shard {shard.shard_id} from {shard.node_addr}"
) from e
except Exception as e:
raise RuntimeError(
f"Error fetching shard {shard.shard_id} from {shard.node_addr}: {e}"
) from e

async def fetch_shards(self, metadata: BatchMetadata) -> list[dict[str, Any]]:
"""Fetch all shards for a batch and return raw shard data."""
if not metadata.shards:
return []

connector = aiohttp.TCPConnector(limit=self.connection_limit)
async with aiohttp.ClientSession(
timeout=self.timeout, connector=connector
) as session:
logger.info(
f"Fetching {len(metadata.shards)} shards for batch {metadata.batch_id}"
)
tasks = [self.fetch_shard(session, shard) for shard in metadata.shards]
shard_data_list = await asyncio.gather(*tasks)
return shard_data_list

async def store_shard(
self,
session: aiohttp.ClientSession,
node_addr: str,
shard_id: str,
global_step: int,
data: dict[str, Any],
) -> None:
"""Store a shard on a node."""
url = f"http://{node_addr}/data/{shard_id}?global_step={global_step}"

# Serialize data
buffer = io.BytesIO()
pickle.dump(data, buffer)
data_bytes = buffer.getvalue()

try:
async with session.put(
url, data=data_bytes, timeout=self.timeout
) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(
f"Failed to store shard {shard_id} to {node_addr}: "
f"HTTP {response.status} - {error_text}"
)

logger.debug(
f"Stored shard {shard_id} to {node_addr} ({len(data_bytes)} bytes)"
)

except asyncio.TimeoutError as e:
raise RuntimeError(
f"Timeout storing shard {shard_id} to {node_addr}"
) from e
except Exception as e:
raise RuntimeError(
f"Error storing shard {shard_id} to {node_addr}: {e}"
) from e

async def clear_batches(self, node_addrs: set[str], global_step: int) -> None:
"""Clear old data on multiple nodes."""
connector = aiohttp.TCPConnector(limit=self.connection_limit)
async with aiohttp.ClientSession(
timeout=self.timeout, connector=connector
) as session:
tasks = [
self._clear_node(session, node_addr, global_step)
for node_addr in node_addrs
]
await asyncio.gather(*tasks, return_exceptions=True)

async def _clear_node(
self, session: aiohttp.ClientSession, node_addr: str, global_step: int
) -> None:
"""Clear old data on a single node."""
url = f"http://{node_addr}/data/clear?global_step={global_step}"

try:
async with session.delete(url, timeout=self.timeout) as response:
if response.status != 200:
error_text = await response.text()
logger.warning(
f"Failed to clear data on {node_addr}: "
f"HTTP {response.status} - {error_text}"
)
else:
result = await response.json()
logger.debug(
f"Cleared {result.get('cleared_count', 0)} shards on {node_addr}"
)

except Exception as e:
logger.warning(f"Error clearing data on {node_addr}: {e}")
59 changes: 59 additions & 0 deletions areal/controller/batch_metadata.py
@@ -0,0 +1,59 @@
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class TensorMetadata:
    """Metadata for a tensor field."""

    shape: tuple[int, ...]
    dtype: str
    device: str = "cpu"

    def __repr__(self) -> str:
        return f"TensorMetadata(shape={self.shape}, dtype={self.dtype}, device={self.device})"


@dataclass
class ShardMetadata:
    """Metadata for a single (sub-)shard stored on one node.

    A logical batch can be composed of multiple shards, and a single physical
    shard can be split into multiple logical sub-shards via offset and batch_size.
    """

    node_id: str
    node_addr: str
    shard_id: str
    batch_size: int
    offset: int = 0
    fields: dict[str, TensorMetadata] = field(default_factory=dict)

    def __repr__(self) -> str:
        return (
            f"ShardMetadata(node_id={self.node_id}, node_addr={self.node_addr}, "
            f"shard_id={self.shard_id}, offset={self.offset}, "
            f"batch_size={self.batch_size}, fields={list(self.fields.keys())})"
        )


@dataclass
class BatchMetadata:
    """Metadata for a distributed batch sharded across multiple nodes."""

    batch_id: str
    global_step: int
    total_batch_size: int
    shards: list[ShardMetadata] = field(default_factory=list)

    def __repr__(self) -> str:
        return (
            f"BatchMetadata(batch_id={self.batch_id}, global_step={self.global_step}, "
            f"total_batch_size={self.total_batch_size}, num_shards={len(self.shards)}, "
            f"shards={self.shards})"
        )

    def get_all_node_addrs(self) -> set[str]:
        """Get all unique node addresses in this batch."""
        return {shard.node_addr for shard in self.shards}
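
A small sketch of the sub-shard convention described in the ShardMetadata docstring (node address and shard ID are hypothetical): one physical shard of eight rows is addressed as two logical sub-shards via offset and batch_size.

from areal.controller.batch_metadata import BatchMetadata, ShardMetadata

# One physical shard ("shard-0", 8 rows) exposed as two logical sub-shards.
first_half = ShardMetadata(node_id="node-0", node_addr="10.0.0.1:8000",
                           shard_id="shard-0", batch_size=4, offset=0)
second_half = ShardMetadata(node_id="node-0", node_addr="10.0.0.1:8000",
                            shard_id="shard-0", batch_size=4, offset=4)

meta = BatchMetadata(batch_id="batch-0", global_step=0, total_batch_size=8,
                     shards=[first_half, second_half])
# Both sub-shards live on the same node, so there is one address to fetch from.
assert meta.get_all_node_addrs() == {"10.0.0.1:8000"}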
41 changes: 26 additions & 15 deletions areal/controller/rollout_controller.py
@@ -20,7 +20,7 @@
 from areal.core.staleness_manager import StalenessManager
 from areal.core.workflow_executor import BatchTaskDispatcher
 from areal.utils import logging, perf_tracer
-from areal.utils.data import concat_padded_tensors, cycle_dataloader
+from areal.utils.data import cycle_dataloader
 from areal.utils.dynamic_import import import_from_string
 from areal.utils.perf_tracer import trace_perf
@@ -40,7 +40,7 @@ class _RemoteRolloutTaskInput:
 
 @dataclass
 class _RemoteRolloutResult:
-    trajectory: dict[str, Any]
+    trajectory: dict[str, Any] | DistributedBatchMemory
     task_id: int | None = None
 
 
@@ -294,6 +294,7 @@ async def _submit_then_wait() -> _RemoteRolloutResult | None:
                 timeout=0.1,  # A short time to prevent blocking other requests
                 raise_timeout=False,
                 http_timeout=self.config.request_timeout,
+                return_distributed_batch=True,
             )
 
             # TimeoutError will be caught below
@@ -361,12 +362,13 @@ def submit(
 
     def wait(
        self, count: int, timeout: float | None = None, raise_timeout: bool = True
-    ) -> list[dict[str, Any] | None]:
+    ) -> list[dict[str, Any] | DistributedBatchMemory | None]:
         # Delegate to dispatcher and extract trajectories
         results = self.dispatcher.wait_results(count, timeout, raise_timeout)
         # Log and trace
         if self.config.enable_rollout_tracing:
             logger.info("Rollout results are ready!")
+
         return [r.trajectory if r is not None else None for r in results]
 
     @trace_perf("rollout_controller.rollout_batch", category="scheduler")
@@ -376,7 +378,7 @@ def rollout_batch(
         workflow: RolloutWorkflow | type[RolloutWorkflow] | str,
         workflow_kwargs: dict[str, Any] | None = None,
         should_accept_fn: str | None = None,
-    ) -> dict[str, Any]:
+    ) -> DistributedBatchMemory:
         perf_tracer.instant(
             "rollout_controller.rollout_batch",
             category="scheduler",
@@ -390,12 +392,11 @@ def rollout_batch(
             should_accept_fn=should_accept_fn,
         )
         results = self.wait(count=len(data))
-        # Concatenate into batch tensor format
-        batch = concat_padded_tensors([r for r in results if r is not None])
+        batches = [b for b in results if isinstance(b, DistributedBatchMemory)]
+        if not batches:
+            return DistributedBatchMemory.from_dict({})
 
-        # NOTE: DistributedBatchMemory.from_dict does nothing for now
-        # Just for sync with internal code
-        return DistributedBatchMemory.from_dict(batch)
+        return DistributedBatchMemory.concat(batches)
 
     @trace_perf("rollout_controller.prepare_batch", category="scheduler")
     def prepare_batch(
@@ -404,7 +405,7 @@ def prepare_batch(
         workflow: RolloutWorkflow | type[RolloutWorkflow] | str,
         workflow_kwargs: dict[str, Any] | None = None,
         should_accept_fn: str | None = None,
-    ):
+    ) -> DistributedBatchMemory:
         """Prepare a batch with controlled staleness.
 
         Continuously submits from dataloader and waits for results, ensuring at least
@@ -437,13 +438,15 @@ def task_input_generator():
             self.data_generator, batch_size=dataloader.batch_size
         )
 
-        # Extract trajectories and concatenate
+        # Extract trajectories
         trajectories = [r.trajectory if r is not None else None for r in results]
-        batch = concat_padded_tensors([t for t in trajectories if t is not None])
 
-        # NOTE: DistributedBatchMemory.from_dict does nothing for now
-        # Just for sync with internal code
-        return DistributedBatchMemory.from_dict(batch)
+        # Filter out None and only keep DistributedBatchMemory instances
+        batches = [t for t in trajectories if isinstance(t, DistributedBatchMemory)]
+        if not batches:
+            return DistributedBatchMemory.from_dict({})
+
+        return DistributedBatchMemory.concat(batches)
 
     async def agenerate(self, req: ModelRequest) -> ModelResponse:
         """Asynchronously generate a response for the given request.
@@ -540,3 +543,11 @@ def dispatcher(
     def runner(self):
         """For backward compatibility. The runner is now owned by the dispatcher."""
         return self.dispatcher.runner
+
+    # ==================== DISTRIBUTED BATCH RPC WRAPPERS ====================
+    def clear_batches(self, global_step: int):
+        """Clear all data with step less than global_step"""
+        server_addrs = {
+            f"{worker.ip}:{worker.worker_ports[0]}" for worker in self.workers
+        }
+        DistributedBatchMemory.clear(global_step, server_addrs)
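
A sketch of how the new clear_batches wrapper might fit into a training loop. The names rollout_controller, dataloader, workflow, train_step, and num_steps are assumed to exist in the caller's scope; per the docstring above, clear_batches drops every shard stored with a step strictly below the given global_step, so calling it once per iteration keeps worker-side batch memory from growing without bound.

for global_step in range(num_steps):
    # prepare_batch now returns a DistributedBatchMemory instead of a plain dict.
    batch = rollout_controller.prepare_batch(dataloader, workflow=workflow)
    train_step(batch)
    # Shards from earlier steps are no longer needed; free them on all workers.
    rollout_controller.clear_batches(global_step)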