6 changes: 4 additions & 2 deletions lmdeploy/messages.py
@@ -249,15 +249,16 @@ class Response:
generate_token_len (int): the response token length.
input_token_len (int): the input prompt token length. Note that it may
contains chat template part.
-        session_id (int): the id for running the session. Basically, it refers
-            to the position index of the input request batch.
+        session_id (int): the id for running the session.
finish_reason ('stop' | 'length' | None): the reason the model stopped
generating tokens. This will be 'stop' if the model hit a natural
stop point or a provided stop sequence, 'length' if the maximum
number of tokens specified in the request was reached.
token_ids: (List[int]): the output token ids.
logprobs: (List[Dict[int, float]]): the top logprobs for each output
position.
+        index (int): it refers to the position index of the input request
+            batch
"""
text: str
generate_token_len: int
@@ -266,6 +267,7 @@ class Response:
finish_reason: Optional[Literal['stop', 'length']] = None
token_ids: List[int] = field(default_factory=list)
logprobs: List[Dict[int, float]] = None
+    index: int = 0


@dataclass
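The messages.py change above splits the old dual-purpose session_id: session_id now only identifies the engine session, while the new index field records the request's position in the input batch. A minimal usage sketch (not part of this diff; the model path and prompts are assumptions):

from lmdeploy import pipeline

# Model path is an assumption; any model supported by lmdeploy works here.
pipe = pipeline('internlm/internlm2-chat-7b')
prompts = ['Hi, please introduce yourself', 'Shanghai is']

# Each Response records its position in the input batch in `index`,
# while `session_id` is drawn from the engine's running counter and keeps
# growing across repeated batch_infer / stream_infer calls.
outputs = pipe.batch_infer(prompts)
for out in outputs:
    print(out.index, out.session_id, out.text[:40])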
23 changes: 17 additions & 6 deletions lmdeploy/serve/async_engine.py
Expand Up @@ -180,6 +180,7 @@ def __init__(self,
self.gens_set = set()
for i in range(self.instance_num):
self.gens_set.add(self.engine.create_instance())
+        self._session_id = count(0)

def _build_turbomind(
self,
@@ -328,15 +329,19 @@ def batch_infer(
assert len(prompts) == len(gen_config),\
'input gen_confg length differs from the length of prompts' # noqa
prompt_num = len(prompts)
-        outputs = [Response('', 0, 0, i) for i in range(prompt_num)]
+        session_ids = [next(self._session_id) for _ in range(prompt_num)]
+        outputs = [
+            Response('', 0, 0, session_ids[i], index=i)
+            for i in range(prompt_num)
+        ]
generators = []
if use_tqdm:
import tqdm
pbar = tqdm.tqdm(total=len(prompts))
for i, prompt in enumerate(prompts):
generators.append(
self.generate(prompt,
-                              i,
+                              session_ids[i],
gen_config=gen_config[i],
stream_response=True,
sequence_start=True,
@@ -404,12 +409,13 @@ def stream_infer(
gen_config = [gen_config] * len(prompts)
assert len(prompts) == len(gen_config),\
'input gen_confg length differs from the length of prompts' # noqa
+        session_ids = [next(self._session_id) for _ in range(len(prompts))]
outputs = Queue()
generators = []
for i, prompt in enumerate(prompts):
generators.append(
self.generate(prompt,
-                              i,
+                              session_ids[i],
gen_config=gen_config[i],
stream_response=True,
sequence_start=True,
@@ -421,9 +427,14 @@
async def _inner_call(i, generator):
async for out in generator:
outputs.put(
-                    Response(out.response, out.generate_token_len,
-                             out.input_token_len, i, out.finish_reason,
-                             out.token_ids, out.logprobs))
+                    Response(out.response,
+                             out.generate_token_len,
+                             out.input_token_len,
+                             session_ids[i],
+                             out.finish_reason,
+                             out.token_ids,
+                             out.logprobs,
+                             index=i))

async def gather():
await asyncio.gather(
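A companion sketch for stream_infer after this change: streamed Response chunks from different prompts arrive interleaved, so callers should group them by index rather than by session_id, which is now a process-wide running counter. The prompts, model path, and the incremental-accumulation pattern are illustrative assumptions, not part of the diff.

from collections import defaultdict

from lmdeploy import pipeline

pipe = pipeline('internlm/internlm2-chat-7b')  # model path is an assumption
texts = defaultdict(str)
for out in pipe.stream_infer(['Write a haiku about GPUs', 'Shanghai is']):
    # `index` identifies the prompt this chunk belongs to; `session_id`
    # no longer matches the batch position after this change.
    texts[out.index] += out.text
print(texts[0])
print(texts[1])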