|
9 | 9 | from collections import OrderedDict, defaultdict |
10 | 10 | from functools import lru_cache, partial |
11 | 11 | from platform import uname |
12 | | -from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List, |
13 | | - Optional, Tuple, TypeVar, Union) |
| 12 | +from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, |
| 13 | + Hashable, List, Optional, Tuple, TypeVar, Union) |
14 | 14 |
|
15 | 15 | import psutil |
16 | 16 | import torch |
@@ -181,6 +181,42 @@ def _async_wrapper(*args, **kwargs) -> asyncio.Future: |
181 | 181 | return _async_wrapper |
182 | 182 |
|
183 | 183 |
|
def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Merge multiple asynchronous iterators into a single iterator.

    Every input iterator is drained concurrently by its own task, so this
    function must be called while an event loop is running. Iterators may
    finish at different times; the merged iterator ends once all of them
    are exhausted.

    Args:
        *iterators: The asynchronous iterators to merge.

    Returns:
        An asynchronous iterator yielding tuples ``(i, item)`` where ``i``
        is the positional index of the iterator that produced ``item``.

    Raises:
        Exception: Any exception raised by one of the input iterators is
            re-raised from the merged iterator. The remaining producer
            tasks are cancelled when that happens.
    """
    # Unique marker enqueued by a producer when it stops, for any reason.
    # Counting these markers lets the consumer know when to stop without
    # ever blocking on a queue that no producer will feed again.
    done = object()
    queue: asyncio.Queue = asyncio.Queue()

    async def producer(i: int, iterator: AsyncIterator[T]):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            # Forward the failure so the consumer can re-raise it.
            await queue.put(e)
        finally:
            # The queue is unbounded, so put_nowait cannot raise QueueFull;
            # a non-awaiting put is also safe while being cancelled.
            queue.put_nowait(done)

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        remaining = len(_tasks)
        try:
            while remaining:
                item = await queue.get()
                if item is done:
                    remaining -= 1
                elif isinstance(item, Exception):
                    raise item
                else:
                    yield item
        finally:
            # Runs on normal exhaustion, on error, and when the caller
            # stops early: make sure no producer keeps running in the
            # background. Cancelling an already finished task is a no-op.
            for task in _tasks:
                task.cancel()
            await asyncio.gather(*_tasks, return_exceptions=True)

    return consumer()
| 218 | + |
| 219 | + |
184 | 220 | def get_ip() -> str: |
185 | 221 | host_ip = os.environ.get("HOST_IP") |
186 | 222 | if host_ip: |
|
0 commit comments