Skip to content

Commit 7fd3949

Browse files
[Frontend][Core] Move merge_async_iterators to utils (#4026)
1 parent 1096717 commit 7fd3949

File tree

2 files changed

+39
-39
lines changed

2 files changed

+39
-39
lines changed

vllm/entrypoints/openai/serving_completion.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import time
32
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
43
Optional, Tuple)
@@ -17,7 +16,7 @@
1716
from vllm.model_executor.guided_decoding import (
1817
get_guided_decoding_logits_processor)
1918
from vllm.outputs import RequestOutput
20-
from vllm.utils import random_uuid
19+
from vllm.utils import merge_async_iterators, random_uuid
2120

2221
logger = init_logger(__name__)
2322

@@ -50,41 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
5049
return prompt_is_tokens, prompts
5150

5251

53-
def merge_async_iterators(*iterators):
    """Combine several asynchronous iterators into a single one.

    The merged iterator yields ``(index, item)`` tuples, where ``index``
    identifies which of the input iterators produced ``item``.  Inputs
    that finish before the others are handled gracefully, and an
    exception raised by any input is re-raised from the merged iterator.
    """
    item_queue = asyncio.Queue()

    # One flag per input; a producer flips its flag when it is done
    # (normally or via exception) so the consumer knows when to stop.
    done_flags = [False] * len(iterators)

    async def _pump(idx, source):
        try:
            async for element in source:
                await item_queue.put((idx, element))
        except Exception as exc:
            # Forward the exception through the queue so the consumer
            # re-raises it in its own context.
            await item_queue.put(exc)
        done_flags[idx] = True

    pump_tasks = [
        asyncio.create_task(_pump(idx, source))
        for idx, source in enumerate(iterators)
    ]

    async def _drain():
        # Keep draining until every producer finished AND the queue is
        # empty (equivalent to: not all finished or queue non-empty).
        while not (all(done_flags) and item_queue.empty()):
            entry = await item_queue.get()
            if isinstance(entry, Exception):
                raise entry
            yield entry
        await asyncio.gather(*pump_tasks)

    return _drain()
86-
87-
8852
class OpenAIServingCompletion(OpenAIServing):
8953

9054
def __init__(self,

vllm/utils.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from collections import OrderedDict, defaultdict
1010
from functools import lru_cache, partial
1111
from platform import uname
12-
from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List,
13-
Optional, Tuple, TypeVar, Union)
12+
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
13+
Hashable, List, Optional, Tuple, TypeVar, Union)
1414

1515
import psutil
1616
import torch
@@ -181,6 +181,42 @@ def _async_wrapper(*args, **kwargs) -> asyncio.Future:
181181
return _async_wrapper
182182

183183

184+
def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before
    others. When it yields, it yields a tuple (i, item) where i is the
    index of the iterator that yields the item.
    """
    # Producers push (index, item) pairs; an exception is forwarded on
    # the same queue so the consumer can re-raise it in its own context.
    queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()

    finished = [False] * len(iterators)

    async def producer(i: int, iterator: AsyncIterator[T]):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        # Mark this producer done whether it completed or failed, so the
        # consumer's termination condition eventually holds.
        finished[i] = True

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        try:
            while not all(finished) or not queue.empty():
                item = await queue.get()
                if isinstance(item, Exception):
                    raise item
                yield item
        except BaseException:
            # If we exit early (a forwarded exception was re-raised, or
            # the caller closed/cancelled this generator), cancel the
            # remaining producer tasks so they are not left pending.
            for task in _tasks:
                task.cancel()
            raise
        await asyncio.gather(*_tasks)

    return consumer()
218+
219+
184220
def get_ip() -> str:
185221
host_ip = os.environ.get("HOST_IP")
186222
if host_ip:

0 commit comments

Comments
 (0)