-
Notifications
You must be signed in to change notification settings - Fork 245
feat: implement distributed batch #687
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
daihaowz
wants to merge
31
commits into
main
Choose a base branch
from
fh/dist-batch
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+2,460
−235
Open
Changes from 19 commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
19b07e2
feat: implement distributed batch
9980938
clean code
38200f4
clean code
5b1044e
bugfix
5c21580
fix
4c1dd09
update
015acc1
update
ea6cc03
update ut
cff658b
update ut
3a56eb1
update ut
c71982b
update ut
90c2fa9
update ut
a6554fb
fix
8205d01
Refactor batch memory modes and rename RPC arguments
rchardx 2d7a857
Refine batch string representation and docstrings
rchardx bb5aa1f
Refactor batch memory access and assignment logic
rchardx 3d5dccf
Enhance DistributedBatchMemory assignment and optimize HTTP client co…
rchardx a60773a
add global_step to DistributedBatchMemory.__str__
9b18dee
Merge branch 'main' into fh/dist-batch
daihaowz 50b0929
Merge branch 'main' into fh/dist-batch
daihaowz bb420a8
Merge branch 'main' into fh/dist-batch
daihaowz a98483a
Update serialize method for DistributedBatchMemory
02b1c87
rename BatchMode to BatchStatus
99a59a8
rename union to union_ && fix ut
ef32ec7
opt distributed batch: rm global_step field
607a981
ok
1689fa7
.
7b4a53b
.
82ccc1f
fix hang
28cbd35
opt
fc17d15
Merge branch 'main' of https://github.com/inclusionAI/AReaL into fh/d…
garrett4wade File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,168 @@ | ||
| """HTTP client for distributed batch memory retrieval.""" | ||
|
|
||
| import asyncio | ||
| import io | ||
| import pickle | ||
| from typing import Any | ||
|
|
||
| import aiohttp | ||
|
|
||
| from areal.controller.batch_metadata import BatchMetadata, ShardMetadata | ||
| from areal.utils import logging | ||
|
|
||
# Module-level logger for this client.
logger = logging.getLogger("BatchClient")

# Default connection limit for batch data fetching
DEFAULT_CONNECTION_LIMIT = 100
|
|
||
|
|
||
class BatchDataClient:
    """HTTP client for storing, fetching, and clearing distributed batch data.

    Payloads are exchanged as pickled blobs over plain HTTP.

    NOTE(security): ``pickle`` can execute arbitrary code on deserialization,
    so this client must only ever communicate with trusted in-cluster nodes.
    """

    def __init__(
        self, timeout: float = 300.0, connection_limit: int = DEFAULT_CONNECTION_LIMIT
    ):
        """Initialize the batch data client.

        Parameters
        ----------
        timeout : float
            Request timeout in seconds
        connection_limit : int
            Maximum number of concurrent connections
        """
        self.timeout = aiohttp.ClientTimeout(total=timeout)
        self.connection_limit = connection_limit

    async def fetch_shard(
        self, session: aiohttp.ClientSession, shard: ShardMetadata
    ) -> dict[str, Any]:
        """Fetch a logical shard (sub-range) from a physical shard.

        Parameters
        ----------
        session : aiohttp.ClientSession
            Open session to issue the GET request on.
        shard : ShardMetadata
            Identifies the node address, shard id, and (offset, batch_size)
            sub-range to fetch.

        Returns
        -------
        dict[str, Any]
            The unpickled shard payload.

        Raises
        ------
        RuntimeError
            On a non-200 response, a timeout, or any transport error.
        """
        url = f"http://{shard.node_addr}/data/{shard.shard_id}"
        params = {}
        # Only forward non-default range parameters; zero presumably means
        # "use the whole physical shard" on the server side — confirm against
        # the server handler.
        if shard.offset > 0:
            params["offset"] = shard.offset
        if shard.batch_size > 0:
            params["batch_size"] = shard.batch_size

        try:
            async with session.get(
                url, params=params, timeout=self.timeout
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(
                        f"Failed to fetch shard {shard.shard_id} from {shard.node_addr}: "
                        f"HTTP {response.status} - {error_text}"
                    )

                data_bytes = await response.read()
                # pickle.loads replaces the original's redundant BytesIO
                # indirection.
                data = pickle.loads(data_bytes)

                logger.debug(
                    f"Fetched logical shard {shard.shard_id} from {shard.node_addr} "
                    f"(offset={shard.offset}, batch_size={shard.batch_size}, "
                    f"{len(data_bytes)} bytes)"
                )
                return data

        except asyncio.TimeoutError as e:
            raise RuntimeError(
                f"Timeout fetching shard {shard.shard_id} from {shard.node_addr}"
            ) from e
        except RuntimeError:
            # Re-raise the HTTP-status error above unchanged instead of
            # double-wrapping it in the generic handler below.
            raise
        except Exception as e:
            raise RuntimeError(
                f"Error fetching shard {shard.shard_id} from {shard.node_addr}: {e}"
            ) from e

    async def fetch_shards(self, metadata: BatchMetadata) -> list[dict[str, Any]]:
        """Fetch all shards for a batch and return raw shard data.

        Shards are fetched concurrently; the first failure propagates.
        Results come back in the same order as ``metadata.shards``.
        """
        if not metadata.shards:
            return []

        connector = aiohttp.TCPConnector(limit=self.connection_limit)
        async with aiohttp.ClientSession(
            timeout=self.timeout, connector=connector
        ) as session:
            logger.info(
                f"Fetching {len(metadata.shards)} shards for batch {metadata.batch_id}"
            )
            tasks = [self.fetch_shard(session, shard) for shard in metadata.shards]
            shard_data_list = await asyncio.gather(*tasks)
            return shard_data_list

    async def store_shard(
        self,
        session: aiohttp.ClientSession,
        node_addr: str,
        shard_id: str,
        global_step: int,
        data: dict[str, Any],
    ) -> None:
        """Store a shard on a node.

        Parameters
        ----------
        session : aiohttp.ClientSession
            Open session to issue the PUT request on.
        node_addr : str
            ``host:port`` of the target node.
        shard_id : str
            Identifier under which the shard is stored.
        global_step : int
            Step tag attached to the shard (used later by ``clear_batches``).
        data : dict[str, Any]
            Shard payload; pickled before upload.

        Raises
        ------
        RuntimeError
            On a non-200 response, a timeout, or any transport error.
        """
        url = f"http://{node_addr}/data/{shard_id}?global_step={global_step}"

        # pickle.dumps replaces the original's redundant BytesIO indirection.
        data_bytes = pickle.dumps(data)

        try:
            async with session.put(
                url, data=data_bytes, timeout=self.timeout
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(
                        f"Failed to store shard {shard_id} to {node_addr}: "
                        f"HTTP {response.status} - {error_text}"
                    )

                logger.debug(
                    f"Stored shard {shard_id} to {node_addr} ({len(data_bytes)} bytes)"
                )

        except asyncio.TimeoutError as e:
            raise RuntimeError(
                f"Timeout storing shard {shard_id} to {node_addr}"
            ) from e
        except RuntimeError:
            # Re-raise the HTTP-status error above unchanged instead of
            # double-wrapping it in the generic handler below.
            raise
        except Exception as e:
            raise RuntimeError(
                f"Error storing shard {shard_id} to {node_addr}: {e}"
            ) from e

    async def clear_batches(self, node_addrs: set[str], global_step: int) -> None:
        """Clear old data on multiple nodes.

        Best-effort: per-node failures are logged inside ``_clear_node`` and
        never abort the other nodes (``return_exceptions=True``).
        """
        if not node_addrs:
            # Nothing to clear; avoid opening a session for no work.
            return
        connector = aiohttp.TCPConnector(limit=self.connection_limit)
        async with aiohttp.ClientSession(
            timeout=self.timeout, connector=connector
        ) as session:
            tasks = [
                self._clear_node(session, node_addr, global_step)
                for node_addr in node_addrs
            ]
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _clear_node(
        self, session: aiohttp.ClientSession, node_addr: str, global_step: int
    ) -> None:
        """Clear old data on a single node; failures are logged, not raised."""
        url = f"http://{node_addr}/data/clear?global_step={global_step}"

        try:
            async with session.delete(url, timeout=self.timeout) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.warning(
                        f"Failed to clear data on {node_addr}: "
                        f"HTTP {response.status} - {error_text}"
                    )
                else:
                    result = await response.json()
                    logger.debug(
                        f"Cleared {result.get('cleared_count', 0)} shards on {node_addr}"
                    )

        except Exception as e:
            logger.warning(f"Error clearing data on {node_addr}: {e}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass, field | ||
|
|
||
|
|
||
@dataclass
class TensorMetadata:
    """Metadata describing a single tensor field."""

    # Tensor dimensions, e.g. (batch, seq_len, hidden).
    shape: tuple[int, ...]
    # Dtype name as a plain string, e.g. "float32".
    dtype: str
    # Device identifier; defaults to CPU.
    device: str = "cpu"

    def __repr__(self) -> str:
        described = f"shape={self.shape}, dtype={self.dtype}, device={self.device}"
        return f"TensorMetadata({described})"
|
|
||
|
|
||
@dataclass
class ShardMetadata:
    """Metadata for one (sub-)shard hosted on a single node.

    A logical batch may span several shards, and one physical shard can be
    carved into multiple logical sub-shards through ``offset`` and
    ``batch_size``.
    """

    # Identifier of the hosting node.
    node_id: str
    # ``host:port`` address of the hosting node.
    node_addr: str
    # Identifier of the physical shard on that node.
    shard_id: str
    # Number of rows in this logical sub-shard.
    batch_size: int
    # Row offset into the physical shard where this sub-shard begins.
    offset: int = 0
    # Per-field tensor metadata, keyed by field name.
    fields: dict[str, TensorMetadata] = field(default_factory=dict)

    def __repr__(self) -> str:
        described = ", ".join(
            [
                f"node_id={self.node_id}",
                f"node_addr={self.node_addr}",
                f"shard_id={self.shard_id}",
                f"offset={self.offset}",
                f"batch_size={self.batch_size}",
                f"fields={list(self.fields.keys())}",
            ]
        )
        return f"ShardMetadata({described})"
|
|
||
|
|
||
@dataclass
class BatchMetadata:
    """Metadata for a distributed batch whose rows are sharded across nodes."""

    # Unique identifier of the batch.
    batch_id: str
    # Training step this batch belongs to.
    global_step: int
    # Total number of rows across all shards.
    total_batch_size: int
    # Per-shard metadata entries.
    shards: list[ShardMetadata] = field(default_factory=list)

    def __repr__(self) -> str:
        attrs = ", ".join(
            (
                f"batch_id={self.batch_id}",
                f"global_step={self.global_step}",
                f"total_batch_size={self.total_batch_size}",
                f"num_shards={len(self.shards)}",
                f"shards={self.shards}",
            )
        )
        return f"BatchMetadata({attrs})"

    def get_all_node_addrs(self) -> set[str]:
        """Get all unique node addresses in this batch."""
        addrs: set[str] = set()
        for shard in self.shards:
            addrs.add(shard.node_addr)
        return addrs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.