lmdeploy/serve/async_engine.py (27 additions, 5 deletions)
@@ -353,6 +353,7 @@ async def get_generator(self, stop: bool, session_id: int):
     def batch_infer(
             self,
             prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
+            session_ids: Union[List[int], int] = None,

Collaborator:
Introducing session_ids will make the API hard to understand.
Is there any way we can handle it internally?


Collaborator (Author):
stream_infer supports batched inputs, so we need to match each output back to its input, and currently we distinguish them by session_id. If we don't introduce session_ids, we would have to change the output as shown below, where i is the index of the i-th prompt in the batch:

                i, out = outputs.get(timeout=0.001)
                if out is None:
                    break
                yield i, out
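
To make the matching concrete, here is a minimal consumer-side sketch, assuming a `pipe` object that exposes the stream_infer API from this diff and yields Response objects carrying `text` and `session_id`; the prompt strings and id values below are purely illustrative:

    # Sketch: route streamed chunks back to their prompts via session_id.
    prompts = ['Hi, how are you?', 'Write a haiku about GPUs.']
    session_ids = [100, 101]  # one id per prompt, hypothetical values

    texts = {sid: '' for sid in session_ids}
    for out in pipe.stream_infer(prompts, session_ids=session_ids):
        # Chunks from the interleaved generators arrive in arbitrary order;
        # out.session_id identifies which prompt a chunk belongs to.
        texts[out.session_id] += out.text

    for sid, prompt in zip(session_ids, prompts):
        print(f'[session {sid}] {prompt!r} -> {texts[sid]!r}')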

             gen_config: Optional[Union[GenerationConfig,
                                        List[GenerationConfig],
                                        EngineGenerationConfig,
@@ -367,6 +368,8 @@ def batch_infer(
             prompts (List[str] | str | List[Dict] | List[List[Dict]]): a batch of
                 prompts. It accepts: a string prompt, a list of string prompts,
                 a chat history in OpenAI format or a list of chat histories.
+            session_ids (List[int] | int): a batch of session ids. If not
+                provided, it defaults to range(len(prompts)).
             gen_config (GenerationConfig | None): an instance of or a list of
                 GenerationConfig. Default to None.
             do_preprocess (bool): whether pre-process the messages. Default to
@@ -389,15 +392,23 @@ def batch_infer(
         assert len(prompts) == len(gen_config),\
             'input gen_config length differs from the length of prompts'  # noqa
         prompt_num = len(prompts)
-        outputs = [Response('', 0, 0, i) for i in range(prompt_num)]
+        if session_ids is None:
+            session_ids = range(prompt_num)
+        elif isinstance(session_ids, int):
+            session_ids = [session_ids]
+        assert len(prompts) == len(session_ids), \
+            'input session_ids length differs from the length of prompts'
+        outputs = [
+            Response('', 0, 0, session_ids[i]) for i in range(prompt_num)
+        ]
         generators = []
         if use_tqdm:
             import tqdm
             pbar = tqdm.tqdm(total=len(prompts))
         for i, prompt in enumerate(prompts):
             generators.append(
                 self.generate(prompt,
-                              i,
+                              session_ids[i],
                               gen_config=gen_config[i],
                               stream_response=True,
                               sequence_start=True,
Expand Down Expand Up @@ -432,6 +443,7 @@ async def gather():
     def stream_infer(
             self,
             prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
+            session_ids: Union[List[int], int] = None,
             gen_config: Optional[Union[GenerationConfig,
                                        List[GenerationConfig],
                                        EngineGenerationConfig,
@@ -445,6 +457,8 @@ def stream_infer(
             prompts (List[str] | str | List[Dict] | List[List[Dict]]): a batch of
                 prompts. It accepts: a string prompt, a list of string prompts,
                 a chat history in OpenAI format or a list of chat histories.
+            session_ids (List[int] | int): a batch of session ids. If not
+                provided, it defaults to range(len(prompts)).
             gen_config (GenerationConfig | None): an instance of or a list of
                 GenerationConfig. Default to None.
             do_preprocess (bool): whether pre-process the messages. Default to
@@ -465,12 +479,18 @@ def stream_infer(
             gen_config = [gen_config] * len(prompts)
         assert len(prompts) == len(gen_config),\
             'input gen_config length differs from the length of prompts'  # noqa
+        if session_ids is None:
+            session_ids = range(len(prompts))
+        elif isinstance(session_ids, int):
+            session_ids = [session_ids]
+        assert len(prompts) == len(session_ids), \
+            'input session_ids length differs from the length of prompts'  # noqa
         outputs = Queue()
         generators = []
         for i, prompt in enumerate(prompts):
             generators.append(
                 self.generate(prompt,
-                              i,
+                              session_ids[i],
                               gen_config=gen_config[i],
                               stream_response=True,
                               sequence_start=True,
@@ -487,8 +507,10 @@ async def _inner_call(i, generator):
                                     out.token_ids, out.logprobs))
 
         async def gather():
-            await asyncio.gather(
-                *[_inner_call(i, generators[i]) for i in range(len(prompts))])
+            await asyncio.gather(*[
+                _inner_call(session_ids[i], generators[i])
+                for i in range(len(prompts))
+            ])
             outputs.put(None)
 
         loop = _get_event_loop()
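
For reference, a minimal usage sketch of the changed batch_infer API, assuming a pipeline built with lmdeploy.pipeline; the model name, prompts, and id values are placeholders:

    from lmdeploy import GenerationConfig, pipeline

    pipe = pipeline('internlm/internlm2-chat-7b')  # placeholder model

    prompts = ['Hello!', 'Explain KV cache in one sentence.']
    gen_config = GenerationConfig(max_new_tokens=64)

    # Explicit ids: each returned Response carries the session_id of its
    # prompt, so callers can correlate inputs and outputs directly.
    responses = pipe.batch_infer(prompts,
                                 session_ids=[7, 8],
                                 gen_config=gen_config)
    for r in responses:
        print(r.session_id, r.text)

    # Omitting session_ids preserves the old behavior: ids default to
    # range(len(prompts)).
    responses = pipe.batch_infer(prompts, gen_config=gen_config)

When two concurrent calls share the same engine, distinct session_ids keep their sessions from colliding, which is presumably the motivation for exposing the parameter.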