6 changes: 4 additions & 2 deletions lmdeploy/messages.py
@@ -253,15 +253,16 @@ class Response:
generate_token_len (int): the response token length.
input_token_len (int): the input prompt token length. Note that it may
contain the chat template part.
session_id (int): the id for running the session. Basically, it refers
to the position index of the input request batch.
session_id (int): the id for running the session.
finish_reason ('stop' | 'length' | None): the reason the model stopped
generating tokens. This will be 'stop' if the model hit a natural
stop point or a provided stop sequence, 'length' if the maximum
number of tokens specified in the request was reached.
token_ids: (List[int]): the output token ids.
logprobs: (List[Dict[int, float]]): the top logprobs for each output
position.
index_id (int): the position index of the input prompt in the request
batch.
"""
text: str
generate_token_len: int
@@ -270,6 +271,7 @@ class Response:
finish_reason: Optional[Literal['stop', 'length']] = None
token_ids: List[int] = field(default_factory=list)
logprobs: List[Dict[int, float]] = None
index_id: int = 0


@dataclass
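For reference, a minimal sketch of how the updated Response separates the engine-assigned session_id from the new index_id (the prompt's position in the submitted batch). All field values below are illustrative, assuming lmdeploy is importable:

from lmdeploy.messages import Response

# session_id comes from the engine's internal counter; index_id is the
# position of the prompt in the submitted batch
out = Response(text='hello', generate_token_len=2, input_token_len=5,
               session_id=37, index_id=0)
print(out.index_id, out.session_id, out.finish_reason)  # 0 37 None
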
41 changes: 15 additions & 26 deletions lmdeploy/serve/async_engine.py
@@ -215,6 +215,7 @@ def __init__(self,
self.gens_set = set()
for i in range(self.instance_num):
self.gens_set.add(self.engine.create_instance())
self._session_ids = count(0)
Collaborator review comment: "self._session_id" is better. After all, it is only one int.

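Naming aside, the counter gives every request a unique, monotonically increasing session id. A minimal sketch of the allocation pattern, assuming count is imported from itertools (the class and method names below are illustrative only, not part of this diff):

from itertools import count

class _Engine:
    def __init__(self):
        # one shared counter; next() yields 0, 1, 2, ... across all calls
        self._session_ids = count(0)

    def _allocate(self, n):
        # hand out n fresh session ids for a batch of n prompts
        return [next(self._session_ids) for _ in range(n)]

engine = _Engine()
print(engine._allocate(3))  # [0, 1, 2]
print(engine._allocate(2))  # [3, 4]
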
def _build_turbomind(
self,
@@ -353,7 +354,6 @@ async def get_generator(self, stop: bool, session_id: int):
def batch_infer(
self,
prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
session_ids: Union[List[int], int] = None,
gen_config: Optional[Union[GenerationConfig,
List[GenerationConfig],
EngineGenerationConfig,
@@ -368,8 +368,6 @@ def batch_infer(
prompts (List[str] | str | List[Dict] | List[List[Dict]]): a batch of
prompts. It accepts a string prompt, a list of string prompts,
a chat history in OpenAI format, or a list of chat histories.
session_ids (List[int] | int): a batch of session_ids. If not
provided, it will be [0, number of prompts]
gen_config (GenerationConfig | None): an instance of or a list of
GenerationConfig. Defaults to None.
do_preprocess (bool): whether to pre-process the messages. Defaults to
@@ -392,14 +390,10 @@ def batch_infer(
assert len(prompts) == len(gen_config),\
'input gen_config length differs from the length of prompts' # noqa
prompt_num = len(prompts)
if session_ids is None:
session_ids = range(prompt_num)
elif isinstance(session_ids, int):
session_ids = [session_ids]
assert len(prompts) == len(session_ids), \
'input session_ids length differs from the length of prompts'
session_ids = [next(self._session_ids) for _ in range(prompt_num)]
outputs = [
Response('', 0, 0, session_ids[i]) for i in range(prompt_num)
Response('', 0, 0, session_ids[i], index_id=i)
for i in range(prompt_num)
]
generators = []
if use_tqdm:
@@ -443,7 +437,6 @@ async def gather():
def stream_infer(
self,
prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
session_ids: Union[List[int], int] = None,
gen_config: Optional[Union[GenerationConfig,
List[GenerationConfig],
EngineGenerationConfig,
@@ -457,8 +450,6 @@ def stream_infer(
prompts (List[str] | str | List[Dict] | List[List[Dict]]): a batch of
prompts. It accepts a string prompt, a list of string prompts,
a chat history in OpenAI format, or a list of chat histories.
session_ids (List[int] | int): a batch of session_ids. If not
provided, it will be [0, number of prompts]
gen_config (GenerationConfig | None): an instance of or a list of
GenerationConfig. Defaults to None.
do_preprocess (bool): whether to pre-process the messages. Defaults to
@@ -479,12 +470,7 @@ def stream_infer(
gen_config = [gen_config] * len(prompts)
assert len(prompts) == len(gen_config),\
'input gen_config length differs from the length of prompts' # noqa
if session_ids is None:
session_ids = range(len(prompts))
elif isinstance(session_ids, int):
session_ids = [session_ids]
assert len(prompts) == len(session_ids), \
'input session_ids length differs from the length of prompts' # noqa
session_ids = [next(self._session_ids) for _ in range(len(prompts))]
outputs = Queue()
generators = []
for i, prompt in enumerate(prompts):
@@ -502,15 +488,18 @@ def stream_infer(
async def _inner_call(i, generator):
async for out in generator:
outputs.put(
Response(out.response, out.generate_token_len,
out.input_token_len, i, out.finish_reason,
out.token_ids, out.logprobs))
Response(out.response,
out.generate_token_len,
out.input_token_len,
session_ids[i],
out.finish_reason,
out.token_ids,
out.logprobs,
index_id=i))

async def gather():
await asyncio.gather(*[
_inner_call(session_ids[i], generators[i])
for i in range(len(prompts))
])
await asyncio.gather(
*[_inner_call(i, generators[i]) for i in range(len(prompts))])
outputs.put(None)

loop = _get_event_loop()
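With this change, callers no longer pass session_ids: the engine draws them from its internal counter, and the new index_id field maps each output back to its prompt. A hypothetical usage sketch follows; the pipeline entry point, the model name, and the assumption that each streamed Response carries incremental text are illustrations, not part of this diff:

from lmdeploy import pipeline  # assumed top-level entry point

pipe = pipeline('internlm/internlm2-chat-7b')  # placeholder model name
prompts = ['Hi, please introduce yourself', 'Shanghai is']

# batch_infer: outputs come back in input order; index_id == position in
# `prompts`, session_id is whatever the engine's counter handed out
for out in pipe.batch_infer(prompts):
    print(out.index_id, out.session_id, out.text[:40])

# stream_infer: chunks from different prompts interleave, so route each
# chunk by its index_id
texts = [''] * len(prompts)
for chunk in pipe.stream_infer(prompts):
    texts[chunk.index_id] += chunk.text
print(texts)
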
2 changes: 2 additions & 0 deletions src/turbomind/utils/allocator.h
@@ -167,6 +167,7 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
check_cuda_error(cudaGetDeviceCount(&device_count));
cudaMemPool_t mempool;
check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id));
#if TM_ENABLE_CUSTOM_ALL_REDUCE
cudaMemAccessDesc desc = {};
int peer_access_available = 0;
for (int i = 0; i < device_count; i++) {
@@ -184,6 +185,7 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
desc.flags = cudaMemAccessFlagsProtReadWrite;
check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1));
}
#endif
// set memory pool threshold to avoid shrinking the pool
uint64_t setVal = UINT64_MAX;
check_cuda_error(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &setVal));