@@ -351,15 +351,16 @@ def generate(
 
     def chat(
         self,
-        messages: List[ChatCompletionMessageParam],
+        conversations: Union[List[ChatCompletionMessageParam],
+                             List[List[ChatCompletionMessageParam]]],
         sampling_params: Optional[Union[SamplingParams,
                                         List[SamplingParams]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = True,
         tools: Optional[List[Dict[str, Any]]] = None,
-    ) -> List[RequestOutput]:
+    ) -> Union[List[List[RequestOutput]], List[RequestOutput]]:
         """
         Generate responses for a chat conversation.
 
@@ -371,8 +372,9 @@ def chat(
             to the OpenAI API.
 
         Args:
-            messages: A single conversation represented as a list of messages.
-                Each message is a dictionary with 'role' and 'content' keys.
+            conversations: A list of conversations or a single conversation.
+                Each conversation is represented as a list of messages, and
+                each message is a dictionary with 'role' and 'content' keys.
             sampling_params: The sampling parameters for text generation.
                 If None, we use the default sampling parameters. When it
                 is a single value, it is applied to every prompt. When it
@@ -386,49 +388,66 @@ def chat(
                 to each message.
 
         Returns:
-            A list of ``RequestOutput`` objects containing the generated
-            responses in the same order as the input messages.
+            A list of lists of ``RequestOutput`` objects (or a single list
+            when a single conversation is given) containing the generated
+            responses in the same order as the input conversations.
         """
+        list_of_conversations: List[List[ChatCompletionMessageParam]]
 
-        tokenizer = self.get_tokenizer()
-        model_config = self.llm_engine.get_model_config()
-
-        conversation, mm_data = parse_chat_messages(messages, model_config,
-                                                    tokenizer)
-
-        prompt: Union[str, List[int]]
-        if isinstance(tokenizer, MistralTokenizer):
-            prompt = apply_mistral_chat_template(
-                tokenizer,
-                messages=messages,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+        # Handle both a batch of conversations and a single conversation.
+        if is_list_of(conversations, list):
+            # conversations is List[List[...]]
+            list_of_conversations = conversations
         else:
-            prompt = apply_hf_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+            # conversations is List[...]
+            list_of_conversations = [conversations]
+
+        outputs = []
+
+        for messages in list_of_conversations:
+            tokenizer = self.get_tokenizer()
+            model_config = self.llm_engine.get_model_config()
+
+            conversation, mm_data = parse_chat_messages(
+                messages, model_config, tokenizer)
+
+            prompt: Union[str, List[int]]
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=messages,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+            else:
+                prompt = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+
+            inputs: PromptInputs
+            if is_list_of(prompt, int):
+                inputs = TokensPrompt(prompt_token_ids=prompt)
+            else:
+                inputs = TextPrompt(prompt=prompt)
 
-        inputs: PromptInputs
-        if is_list_of(prompt, int):
-            inputs = TokensPrompt(prompt_token_ids=prompt)
-        else:
-            inputs = TextPrompt(prompt=prompt)
+            if mm_data is not None:
+                inputs["multi_modal_data"] = mm_data
 
-        if mm_data is not None:
-            inputs["multi_modal_data"] = mm_data
+            out = self.generate(
+                inputs,
+                sampling_params=sampling_params,
+                use_tqdm=use_tqdm,
+                lora_request=lora_request,
+            )
+            outputs.append(out)
 
-        return self.generate(
-            inputs,
-            sampling_params=sampling_params,
-            use_tqdm=use_tqdm,
-            lora_request=lora_request,
-        )
+        # Return a flat list when a single conversation was passed in.
+        return outputs if is_list_of(conversations, list) else outputs[0]
 
     @overload  # LEGACY: single (prompt + optional token ids)
     def encode(
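The dispatch above relies on `is_list_of` from `vllm.utils` to tell a batch of conversations apart from a single conversation. As a reference for reviewers, here is a minimal sketch of the semantics the diff assumes; the real helper supports additional options (such as how many elements it checks), so treat this simplified signature as an approximation, not vLLM's exact implementation:

from typing import Any

def is_list_of(value: Any, typ: type) -> bool:
    # Approximation: True only for a list whose elements are all instances
    # of `typ`. chat() uses it twice: with typ=list to detect a batch
    # (List[List[ChatCompletionMessageParam]]) rather than a single
    # conversation, and with typ=int to detect token-ID prompts produced
    # by a chat template.
    return isinstance(value, list) and all(isinstance(v, typ) for v in value)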
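And a minimal usage sketch of the two call shapes this change enables; the model name and prompts are illustrative (substitute any chat-tuned model you have access to), not taken from the diff:

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-0.5B-Instruct")  # illustrative chat model
params = SamplingParams(temperature=0.8, max_tokens=64)

# Single conversation -> List[RequestOutput]
single = [{"role": "user", "content": "What is the capital of France?"}]
for output in llm.chat(single, sampling_params=params):
    print(output.outputs[0].text)

# Batch of conversations -> List[List[RequestOutput]]
batch = [
    [{"role": "user", "content": "Write a haiku about the sea."}],
    [{"role": "user", "content": "Summarize what a KV cache does."}],
]
for conversation_outputs in llm.chat(batch, sampling_params=params):
    print(conversation_outputs[0].outputs[0].text)

The return shape mirrors the input shape: a batch input yields one list of `RequestOutput` objects per conversation, while a single conversation yields a plain list, matching the legacy behavior.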