Fix duplicated session_id when pipeline is used by multiple threads #2134
Merged

Commits (12; the diff below shows changes from 2 of them):
- `06d1392` irexyc: add session_ids arg for multithread use of pipeline.stream_infer
- `2b74d46` irexyc: Revert "disable peer access code (#2082)"
- `ba2fe36` irexyc: Revert "Revert "disable peer access code (#2082)""
- `76c9fb9` irexyc: update
- `bba878f` lzhangzz: add peer allocator
- `d65f198` lzhangzz: fix lint
- `7a3cf70` lzhangzz: check cuda error
- `0aa0d49` irexyc: fix comments
- `36ca39f` lzhangzz: fix wrong allocator
- `a9e2a62` irexyc: Merge remote-tracking branch 'origin/main' into mt-batch
- `8e8a622` irexyc: Merge remote-tracking branch 'zl/peer-allocator' into mt-batch
- `793d38c` irexyc: Merge remote-tracking branch 'origin/main' into mt-batch
```diff
@@ -215,6 +215,7 @@ def __init__(self,
         self.gens_set = set()
         for i in range(self.instance_num):
             self.gens_set.add(self.engine.create_instance())
+        self._session_ids = count(0)

     def _build_turbomind(
             self,
```
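The counter added here is presumably `itertools.count` (the import change is not visible in these hunks). A minimal sketch, assuming CPython, of why drawing ids with `next()` on a shared `count` object avoids duplicates across threads: `count.__next__` is implemented in C and completes atomically under the GIL, so concurrent callers never receive the same value. On interpreters without that guarantee, a `threading.Lock` would make it explicit.

```python
# Minimal sketch (not lmdeploy code): a shared itertools.count as a
# process-wide session-id source, exercised from several threads.
# Assumption: CPython, where next() on a count object is atomic under
# the GIL; elsewhere, guard the next() call with a threading.Lock.
from itertools import count
from threading import Thread

session_ids = count(0)
drawn = []  # list.append is GIL-protected in CPython


def draw(n):
    for _ in range(n):
        drawn.append(next(session_ids))


threads = [Thread(target=draw, args=(1000, )) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every id is unique: no thread ever observed a duplicate.
assert len(drawn) == len(set(drawn)) == 8000
```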
```diff
@@ -353,7 +354,6 @@ async def get_generator(self, stop: bool, session_id: int):
     def batch_infer(
             self,
             prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
-            session_ids: Union[List[int], int] = None,
             gen_config: Optional[Union[GenerationConfig,
                                        List[GenerationConfig],
                                        EngineGenerationConfig,
```
```diff
@@ -368,8 +368,6 @@ def batch_infer(
             prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
                 prompts. It accepts: string prompt, a list of string prompts,
                 a chat history in OpenAI format or a list of chat history.
-            session_ids (List[int] | int): a batch of session_ids. If not
-                provided, it will be [0, number of prompts]
             gen_config (GenerationConfig | None): a instance of or a list of
                 GenerationConfig. Default to None.
             do_preprocess (bool): whether pre-process the messages. Default to
```
```diff
@@ -392,14 +390,10 @@ def batch_infer(
         assert len(prompts) == len(gen_config),\
             'input gen_confg length differs from the length of prompts'  # noqa
         prompt_num = len(prompts)
-        if session_ids is None:
-            session_ids = range(prompt_num)
-        elif isinstance(session_ids, int):
-            session_ids = [session_ids]
-        assert len(prompts) == len(session_ids), \
-            'input session_ids length differs from the length of prompts'
+        session_ids = [next(self._session_ids) for _ in range(prompt_num)]
         outputs = [
-            Response('', 0, 0, session_ids[i]) for i in range(prompt_num)
+            Response('', 0, 0, session_ids[i], index_id=i)
+            for i in range(prompt_num)
         ]
         generators = []
         if use_tqdm:
```
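Before this change, two threads calling `batch_infer` without explicit `session_ids` would both default to `range(len(prompts))` and collide on the engine's sessions; now every call draws fresh ids from the shared counter. A hedged usage sketch of the scenario this fixes (the model name is a placeholder; `pipeline` and `batch_infer` are the public lmdeploy entry points shown in this diff):

```python
# Illustrative only: concurrent batch_infer calls from two threads.
# Previously both calls could default to session ids [0, 1] and clash;
# with this patch each call draws globally unique ids internally.
from concurrent.futures import ThreadPoolExecutor

from lmdeploy import pipeline

pipe = pipeline('internlm/internlm2-chat-7b')  # placeholder model name


def run(prompts):
    # No session_ids argument anymore; allocation happens inside.
    return pipe.batch_infer(prompts)


with ThreadPoolExecutor(max_workers=2) as ex:
    futures = [ex.submit(run, ['hi', 'hello']) for _ in range(2)]
    results = [f.result() for f in futures]
```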
```diff
@@ -443,7 +437,6 @@ async def gather():
     def stream_infer(
             self,
             prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
-            session_ids: Union[List[int], int] = None,
             gen_config: Optional[Union[GenerationConfig,
                                        List[GenerationConfig],
                                        EngineGenerationConfig,
```
```diff
@@ -457,8 +450,6 @@ def stream_infer(
             prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
                 prompts. It accepts: string prompt, a list of string prompts,
                 a chat history in OpenAI format or a list of chat history.
-            session_ids (List[int] | int): a batch of session_ids. If not
-                provided, it will be [0, number of prompts]
             gen_config (GenerationConfig | None): a instance of or a list of
                 GenerationConfig. Default to None.
             do_preprocess (bool): whether pre-process the messages. Default to
```
```diff
@@ -479,12 +470,7 @@ def stream_infer(
             gen_config = [gen_config] * len(prompts)
         assert len(prompts) == len(gen_config),\
             'input gen_confg length differs from the length of prompts'  # noqa
-        if session_ids is None:
-            session_ids = range(len(prompts))
-        elif isinstance(session_ids, int):
-            session_ids = [session_ids]
-        assert len(prompts) == len(session_ids), \
-            'input session_ids length differs from the length of prompts'  # noqa
+        session_ids = [next(self._session_ids) for _ in range(len(prompts))]
         outputs = Queue()
         generators = []
         for i, prompt in enumerate(prompts):
```
```diff
@@ -502,15 +488,18 @@ def stream_infer(
         async def _inner_call(i, generator):
             async for out in generator:
                 outputs.put(
-                    Response(out.response, out.generate_token_len,
-                             out.input_token_len, i, out.finish_reason,
-                             out.token_ids, out.logprobs))
+                    Response(out.response,
+                             out.generate_token_len,
+                             out.input_token_len,
+                             session_ids[i],
+                             out.finish_reason,
+                             out.token_ids,
+                             out.logprobs,
+                             index_id=i))

         async def gather():
-            await asyncio.gather(*[
-                _inner_call(session_ids[i], generators[i])
-                for i in range(len(prompts))
-            ])
+            await asyncio.gather(
+                *[_inner_call(i, generators[i]) for i in range(len(prompts))])
             outputs.put(None)

         loop = _get_event_loop()
```
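Because `session_id` now carries the globally unique counter value rather than the prompt's position, consumers of `stream_infer` should group streamed chunks by the new `index_id` field. A sketch of such a consumer loop, reusing `pipe` from the earlier sketch; it assumes the first `Response` field (filled from `out.response` above) is exposed as `.text`, which is a guess from the positional call in the diff:

```python
# Illustrative consumer: stream_infer interleaves chunks from all
# prompts on one queue, so index_id (the prompt's position), not
# session_id (now a global counter value), maps a chunk to its prompt.
prompts = ['tell me a joke', 'write a haiku']
texts = [''] * len(prompts)

for resp in pipe.stream_infer(prompts):
    texts[resp.index_id] += resp.text  # field name assumed to be .text
    if resp.finish_reason is not None:
        print(f'prompt {resp.index_id} finished: {resp.finish_reason}')
```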