
Commit c685f77

irexyc and lzhangzz authored
Fix duplicated session_id when pipeline is used by multithreads (#2134)
* add session_ids arg for multithread use of pipeline.stream_infer
* Revert "disable peer access code (#2082)"

  This reverts commit 263e8cf.
* Revert "Revert "disable peer access code (#2082)""

  This reverts commit 2b74d46.
* update
* add peer allocator
* fix lint
* check cuda error
* fix comments
* fix wrong allocator

---------

Co-authored-by: Li Zhang <[email protected]>
1 parent 687f242 commit c685f77
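To make the failure mode concrete, here is a minimal sketch of the usage this commit fixes (the model name is a placeholder, not part of the commit). Before this change, every `stream_infer` call numbered its sessions by batch position `0..n-1`, so two threads driving the same pipeline at once could collide on the same `session_id`:

```python
# Sketch only; model path is a placeholder. With this commit, session
# ids are drawn from a shared counter, so concurrent calls never clash.
from concurrent.futures import ThreadPoolExecutor

from lmdeploy import pipeline

pipe = pipeline('internlm/internlm2-chat-7b')  # placeholder model

def run(prompts):
    texts = {}
    # stream_infer yields incremental Response chunks; the new `index`
    # field maps each chunk back to its prompt's batch position.
    for out in pipe.stream_infer(prompts):
        texts[out.index] = texts.get(out.index, '') + out.text
    return [texts[i] for i in range(len(prompts))]

with ThreadPoolExecutor(max_workers=2) as ex:
    a = ex.submit(run, ['Hi there', 'What is your name?'])
    b = ex.submit(run, ['Tell me a joke', 'Count to three'])
    print(a.result(), b.result())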

File tree

2 files changed (+21, -8 lines)

lmdeploy/messages.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -249,15 +249,16 @@ class Response:
         generate_token_len (int): the response token length.
         input_token_len (int): the input prompt token length. Note that it may
             contains chat template part.
-        session_id (int): the id for running the session. Basically, it refers
-            to the position index of the input request batch.
+        session_id (int): the id for running the session.
         finish_reason ('stop' | 'length' | None): the reason the model stopped
             generating tokens. This will be 'stop' if the model hit a natural
             stop point or a provided stop sequence, 'length' if the maximum
             number of tokens specified in the request was reached.
         token_ids: (List[int]): the output token ids.
         logprobs: (List[Dict[int, float]]): the top logprobs for each output
             position.
+        index (int): it refers to the position index of the input request
+            batch
     """
     text: str
     generate_token_len: int
@@ -266,6 +267,7 @@ class Response:
     finish_reason: Optional[Literal['stop', 'length']] = None
     token_ids: List[int] = field(default_factory=list)
     logprobs: List[Dict[int, float]] = None
+    index: int = 0
 
 
 @dataclass
```
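With this change, `session_id` no longer doubles as the batch position, so callers that used it to map results back to their prompts should key on the new `index` field instead. A short sketch of the intended use, reusing the hypothetical `pipe` from the sketch above:

```python
# Sketch: map batch_infer() results back to their prompts via `index`.
# `session_id` is now drawn from a shared counter and is no longer
# guaranteed to start at 0 for each call.
prompts = ['What is 2 + 2?', 'Name a prime number.']
outputs = pipe.batch_infer(prompts)

by_position = {out.index: out for out in outputs}
for i, prompt in enumerate(prompts):
    out = by_position[i]
    print(f'[session {out.session_id}] {prompt!r} -> {out.text!r}')
```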

lmdeploy/serve/async_engine.py

Lines changed: 17 additions & 6 deletions
```diff
@@ -180,6 +180,7 @@ def __init__(self,
         self.gens_set = set()
         for i in range(self.instance_num):
             self.gens_set.add(self.engine.create_instance())
+        self._session_id = count(0)
 
     def _build_turbomind(
         self,
@@ -328,15 +329,19 @@ def batch_infer(
         assert len(prompts) == len(gen_config),\
             'input gen_confg length differs from the length of prompts'  # noqa
         prompt_num = len(prompts)
-        outputs = [Response('', 0, 0, i) for i in range(prompt_num)]
+        session_ids = [next(self._session_id) for _ in range(prompt_num)]
+        outputs = [
+            Response('', 0, 0, session_ids[i], index=i)
+            for i in range(prompt_num)
+        ]
         generators = []
         if use_tqdm:
             import tqdm
             pbar = tqdm.tqdm(total=len(prompts))
         for i, prompt in enumerate(prompts):
             generators.append(
                 self.generate(prompt,
-                              i,
+                              session_ids[i],
                               gen_config=gen_config[i],
                               stream_response=True,
                               sequence_start=True,
@@ -404,12 +409,13 @@ def stream_infer(
             gen_config = [gen_config] * len(prompts)
         assert len(prompts) == len(gen_config),\
             'input gen_confg length differs from the length of prompts'  # noqa
+        session_ids = [next(self._session_id) for _ in range(len(prompts))]
         outputs = Queue()
         generators = []
         for i, prompt in enumerate(prompts):
             generators.append(
                 self.generate(prompt,
-                              i,
+                              session_ids[i],
                               gen_config=gen_config[i],
                               stream_response=True,
                               sequence_start=True,
@@ -421,9 +427,14 @@
         async def _inner_call(i, generator):
             async for out in generator:
                 outputs.put(
-                    Response(out.response, out.generate_token_len,
-                             out.input_token_len, i, out.finish_reason,
-                             out.token_ids, out.logprobs))
+                    Response(out.response,
+                             out.generate_token_len,
+                             out.input_token_len,
+                             session_ids[i],
+                             out.finish_reason,
+                             out.token_ids,
+                             out.logprobs,
+                             index=i))
 
         async def gather():
             await asyncio.gather(
```
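The fix hinges on a single `count(0)` shared by every `batch_infer`/`stream_infer` call (the hunk adds no import, so it presumably relies on `itertools.count` already being available in the module). In CPython, advancing a `count` object happens in C under the GIL without executing Python bytecode, so concurrent `next()` calls cannot hand out the same value. A standalone sketch of the allocation pattern:

```python
# Standalone sketch of the pattern added by this commit: one shared
# itertools.count allocator instead of per-call 0..n-1 numbering.
from concurrent.futures import ThreadPoolExecutor
from itertools import count

_session_id = count(0)

def allocate(n):
    # Mirrors `[next(self._session_id) for _ in range(len(prompts))]`.
    return [next(_session_id) for _ in range(n)]

with ThreadPoolExecutor(max_workers=4) as ex:
    futures = [ex.submit(allocate, 8) for _ in range(4)]
    ids = [i for f in futures for i in f.result()]

assert len(ids) == len(set(ids)), 'no duplicated session ids'
print(sorted(ids))  # 0..31, each exactly once
```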
