
Conversation

@irexyc
Collaborator

@irexyc irexyc commented Jul 25, 2024

Motivation

When using pipeline.batch_infer / pipeline.stream_infer from multiple threads, the default session_ids all start from zero, which makes batch inference impossible.

from threading import Thread
from lmdeploy import pipeline, GenerationConfig
pipe = pipeline('/mnt/140/InternLM/internlm2-chat-1_8b', log_level='INFO')


def work(ss):
  # every call relies on the default session_id, so all threads collide on id 0
  gen_config = GenerationConfig(ignore_eos=True, max_new_tokens=512)
  for i in range(10):
    for x in pipe.stream_infer('hello', gen_config=gen_config):
        pass


threads = []
for i in range(5):
  t = Thread(target=work, args=(i * 10, ))
  t.start()
  threads.append(t)

for t in threads:
  t.join()
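
The failure mode can be seen in isolation: outputs are routed by session_id, so two threads that both use the default id 0 cannot be told apart. A self-contained illustration (not LMDeploy code):

# Not LMDeploy code: outputs are routed by session_id, so two producers that
# both use the default id 0 yield results the consumer cannot distinguish.
import queue

outputs = queue.Queue()

def engine_side(session_id, text):
    # the engine tags every output with the session id it was given
    outputs.put((session_id, text))

engine_side(0, 'reply meant for thread A')
engine_side(0, 'reply meant for thread B')  # same id: indistinguishable downstream

while not outputs.empty():
    sid, text = outputs.get()
    print(sid, text)  # both lines report session id 0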

def batch_infer(
        self,
        prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
        session_ids: Union[List[int], int] = None,
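
A rough usage sketch of the proposed argument (illustrative only; the id scheme below is an assumption, not part of the PR):

# Illustrative only: how the proposed session_ids argument might be used from
# multiple threads; the offset-based id scheme is an assumption.
from threading import Thread
from lmdeploy import pipeline

pipe = pipeline('/mnt/140/InternLM/internlm2-chat-1_8b')

def work(offset):
    prompts = ['hello', 'hi there']
    ids = [offset + i for i in range(len(prompts))]  # disjoint ids across threads
    pipe.batch_infer(prompts, session_ids=ids)

threads = [Thread(target=work, args=(i * 10,)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()
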
Collaborator

Introducing session_ids will make the API hard to understand.
Is there any way we can handle it internally?

Collaborator Author

stream_infer supports batch inputs, so we need to match the outputs to the inputs, and currently we distinguish them by session_id. If we don't introduce session_ids, we would need to change the output as below, where i is the index of the corresponding input in the batch.

                i, out = outputs.get(timeout=0.001)
                if out is None:
                    break
                yield i, out
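
For reference, a rough sketch (illustrative only, not part of this PR) of how a caller could regroup such (i, out) pairs per prompt:

# Illustrative only: regroup (index, output) pairs from a batched stream so that
# each prompt's chunks can be reassembled on the caller's side.
from collections import defaultdict

def collect(stream, num_prompts):
    per_prompt = defaultdict(list)
    for i, out in stream:  # stream yields (prompt index, output chunk)
        per_prompt[i].append(out)
    return [per_prompt[i] for i in range(num_prompts)]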

@lvhan028
Collaborator

@AllentDan Do you remember our motivation to support batch prompts inference in streaming mode?

@AllentDan
Collaborator

> @AllentDan Do you remember our motivation to support batch prompts inference in streaming mode?

Required by XTuner and community. #636

@AllentDan
Collaborator

AllentDan commented Jul 29, 2024

In multithreaded scenarios, users tend to input only one prompt per thread. Shall we provide an infer function for that case? The infer function would be a synchronous counterpart of the async generate. It would not offer batch inference, and users could control the session id themselves.

@lvhan028
Collaborator

> @AllentDan Do you remember our motivation to support batch prompts inference in streaming mode?
>
> Required by XTuner and community. #636

It didn't make sense to me.
If users request streaming output for multiple prompts, can't we recommend they use async_stream_infer and pass the prompts one by one?
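
For example, a rough sketch of that per-prompt pattern; the exact signature of async_stream_infer (single prompt, keyword session_id) is an assumption here:

# Sketch of the suggested per-prompt streaming pattern; the signature of
# async_stream_infer used below is an assumption, not the documented API.
import asyncio

async def stream_one(pipe, prompt, session_id):
    async for out in pipe.async_stream_infer(prompt, session_id=session_id):
        print(session_id, out)

async def stream_all(pipe, prompts):
    await asyncio.gather(
        *(stream_one(pipe, p, session_id=i) for i, p in enumerate(prompts)))

# asyncio.run(stream_all(pipe, ['hello', 'hi there']))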

@AllentDan
Collaborator

Maybe the user is not familiar with coroutine programming. I would recommend they use the generate function directly if possible.

@lvhan028
Collaborator

That's not my point.
My point is that streaming responses for a batch of prompts is probably inappropriate.
Is it supported by vLLM?

@lvhan028
Collaborator

The root issue is that we don't want users to be bothered by session_ids in multithreaded scenarios.

@lvhan028
Collaborator

lvhan028 commented Aug 1, 2024

To-do items after internal discussion:

  • do not expose session_ids in the pipeline API; instead maintain an internally incremented global session_id (a rough sketch follows this list)
  • add an index field to the response structure, indicating the position of the corresponding prompt in the prompt list
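
A rough sketch of how these two items could fit together (only the itertools.count counter mirrors the diff quoted below; every other name here is an assumption):

# Sketch only: internally incremented session ids plus a per-prompt index in the
# response; Response/Pipeline are placeholder names, not the real classes.
from dataclasses import dataclass
from itertools import count


@dataclass
class Response:
    text: str
    index: int  # position of the corresponding prompt in the input list


class Pipeline:
    def __init__(self):
        # process-wide, monotonically increasing session ids, never exposed to users
        self._session_ids = count(0)

    def batch_infer(self, prompts):
        responses = []
        for index, prompt in enumerate(prompts):
            session_id = next(self._session_ids)  # unique id for this request
            responses.append(Response(text=f'<reply to {prompt!r}>', index=index))
        return responses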

self.gens_set = set()
for i in range(self.instance_num):
    self.gens_set.add(self.engine.create_instance())
self._session_ids = count(0)
Collaborator

"self._session_id" is better. After all, it is only one int

Collaborator

@AllentDan AllentDan left a comment

LGTM

@lvhan028
Collaborator

lvhan028 commented Aug 8, 2024

Please merge main so that the tests can be run.

@lvhan028 lvhan028 changed the title "add session_ids arg for multithread use of pipeline.stream_infer" to "Fix duplicated session_id when pipeline is used by multithreads" Aug 8, 2024
@lvhan028 lvhan028 merged commit c685f77 into InternLM:main Aug 8, 2024