diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 52a71ae569..a8d22bad10 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -249,8 +249,7 @@ class Response:
         generate_token_len (int): the response token length.
         input_token_len (int): the input prompt token length. Note that it may
             contains chat template part.
-        session_id (int): the id for running the session. Basically, it refers
-            to the position index of the input request batch.
+        session_id (int): the id for running the session.
         finish_reason ('stop' | 'length' | None): the reason the model stopped
             generating tokens. This will be 'stop' if the model hit a natural
             stop point or a provided stop sequence, 'length' if the maximum
@@ -258,6 +257,8 @@ class Response:
         token_ids: (List[int]): the output token ids.
         logprobs: (List[Dict[int, float]]): the top logprobs for each output
             position.
+        index (int): it refers to the position index of the input request
+            batch
     """
     text: str
     generate_token_len: int
@@ -266,6 +267,7 @@ class Response:
     finish_reason: Optional[Literal['stop', 'length']] = None
     token_ids: List[int] = field(default_factory=list)
     logprobs: List[Dict[int, float]] = None
+    index: int = 0


 @dataclass
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index 039e136e43..f9c7d969dc 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -180,6 +180,7 @@ def __init__(self,
         self.gens_set = set()
         for i in range(self.instance_num):
             self.gens_set.add(self.engine.create_instance())
+        self._session_id = count(0)

     def _build_turbomind(
         self,
@@ -328,7 +329,11 @@ def batch_infer(
         assert len(prompts) == len(gen_config),\
             'input gen_confg length differs from the length of prompts' # noqa
         prompt_num = len(prompts)
-        outputs = [Response('', 0, 0, i) for i in range(prompt_num)]
+        session_ids = [next(self._session_id) for _ in range(prompt_num)]
+        outputs = [
+            Response('', 0, 0, session_ids[i], index=i)
+            for i in range(prompt_num)
+        ]
         generators = []
         if use_tqdm:
             import tqdm
@@ -336,7 +341,7 @@ def batch_infer(
         for i, prompt in enumerate(prompts):
             generators.append(
                 self.generate(prompt,
-                              i,
+                              session_ids[i],
                               gen_config=gen_config[i],
                               stream_response=True,
                               sequence_start=True,
@@ -404,12 +409,13 @@ def stream_infer(
             gen_config = [gen_config] * len(prompts)
         assert len(prompts) == len(gen_config),\
             'input gen_confg length differs from the length of prompts' # noqa
+        session_ids = [next(self._session_id) for _ in range(len(prompts))]
         outputs = Queue()
         generators = []
         for i, prompt in enumerate(prompts):
             generators.append(
                 self.generate(prompt,
-                              i,
+                              session_ids[i],
                               gen_config=gen_config[i],
                               stream_response=True,
                               sequence_start=True,
@@ -421,9 +427,14 @@ def stream_infer(
         async def _inner_call(i, generator):
             async for out in generator:
                 outputs.put(
-                    Response(out.response, out.generate_token_len,
-                             out.input_token_len, i, out.finish_reason,
-                             out.token_ids, out.logprobs))
+                    Response(out.response,
+                             out.generate_token_len,
+                             out.input_token_len,
+                             session_ids[i],
+                             out.finish_reason,
+                             out.token_ids,
+                             out.logprobs,
+                             index=i))

         async def gather():
             await asyncio.gather(
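To summarize the behavioral change for callers: `session_id` is now drawn from an engine-wide `itertools.count` rather than from the position in the input batch, and the batch position is reported through the new `Response.index` field. Below is a minimal usage sketch (not part of this patch) showing how a caller would tell the two apart; the `pipeline` entry point is the one documented by lmdeploy, while the model path and prompts are placeholders.

```python
from lmdeploy import pipeline

# Placeholder model path; any model supported by lmdeploy works here.
pipe = pipeline('internlm/internlm2-chat-7b')

prompts = ['Hi, what is your name?', 'Write a haiku about GPUs.']
responses = pipe.batch_infer(prompts)

for resp in responses:
    # resp.index is the position of the prompt in the input batch;
    # resp.session_id is a globally increasing id and no longer
    # encodes the batch position.
    print(resp.index, resp.session_id, resp.text)
```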