from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceStatus)
+                           ParallelSampleSequenceGroup, Sequence,
+                           SequenceGroup, SequenceGroupBase,
+                           SequenceGroupMetadata, SequenceGroupOutput,
+                           SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
@@ -474,6 +476,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                ),
            ))

+        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
+
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

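The engine now keeps a map from child sequence-group ids to the shared parent object for an `n > 1` request, so per-child results can later be reassembled into a single output. A rough, self-contained sketch of that bookkeeping pattern (class and id naming invented for illustration, not vLLM's actual `SequenceGroupBase`):

```python
from typing import Dict, List, Optional

class ParentGroup:
    """Collects the outputs of n child requests under one parent id."""

    def __init__(self, request_id: str, n: int) -> None:
        self.request_id = request_id
        self.n = n
        self.texts: List[str] = []

    def add(self, text: str) -> Optional[List[str]]:
        # Yield the merged result only once all n children have finished.
        self.texts.append(text)
        return self.texts if len(self.texts) == self.n else None

seq_id_to_seq_group: Dict[str, ParentGroup] = {}

parent = ParentGroup("req-0", n=2)
for i in range(parent.n):
    # Every child id resolves to the same parent object.
    seq_id_to_seq_group[f"req-0_sample_{i}"] = parent

assert seq_id_to_seq_group["req-0_sample_0"].add("A") is None
assert seq_id_to_seq_group["req-0_sample_1"].add("B") == ["A", "B"]
```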
@@ -642,7 +646,10 @@ def _add_processed_request(
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
-    ) -> None:
+    ) -> SequenceGroup:
+        """Add a processed request to the engine's request pool,
+        returning the created sequence group.
+        """
        self._validate_model_inputs(processed_inputs)
        # Create the sequences.
        block_size = self.cache_config.block_size
@@ -696,6 +703,8 @@ def _add_processed_request(
            min_cost_scheduler = self.scheduler[costs.index(min(costs))]
            min_cost_scheduler.add_seq_group(seq_group)

+        return seq_group
+
    def stop_remote_worker_execution_loop(self) -> None:
        self.model_executor.stop_remote_worker_execution_loop()

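Returning the freshly created group gives the caller a direct handle on what was just scheduled instead of forcing a lookup. A runnable stand-in showing how this return value and the new registry compose on the fan-out path (`Engine`, `Group`, and the child-id scheme are invented for the example, not vLLM internals):

```python
from typing import Dict, List, Optional

class Group:
    def __init__(self, request_id: str) -> None:
        self.request_id = request_id

class Engine:
    def __init__(self) -> None:
        self.seq_id_to_seq_group: Dict[str, Group] = {}
        self.scheduled: List[Group] = []

    def _add_processed_request(self, request_id: str) -> Group:
        group = Group(request_id)
        self.scheduled.append(group)
        return group  # the new return value: a handle for the caller

    def add_request(self, request_id: str, n: int = 1) -> Optional[Group]:
        if n > 1:
            parent = Group(request_id)
            for i in range(n):
                child_id = f"{request_id}_sample_{i}"
                self.seq_id_to_seq_group[child_id] = parent
                self._add_processed_request(child_id)
            return None  # fanned out: no single group to hand back
        return self._add_processed_request(request_id)

engine = Engine()
assert engine.add_request("a") is not None
assert engine.add_request("b", n=3) is None
assert len(engine.scheduled) == 4
```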
@@ -711,7 +720,7 @@ def add_request(
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
        ...

    @overload
@@ -725,7 +734,7 @@ def add_request(
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
        ...

    @deprecate_kwargs(
@@ -744,7 +753,7 @@ def add_request(
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
@@ -788,6 +797,22 @@ def add_request(
788797 >>> # continue the request processing
789798 >>> ...
790799 """
800+
801+ if isinstance (params , SamplingParams ) and params .n > 1 :
802+ ParallelSampleSequenceGroup .add_request (
803+ request_id ,
804+ self ,
805+ params ,
806+ prompt = prompt ,
807+ arrival_time = arrival_time ,
808+ lora_request = lora_request ,
809+ trace_headers = trace_headers ,
810+ prompt_adapter_request = prompt_adapter_request ,
811+ priority = priority ,
812+ inputs = inputs ,
813+ )
814+ return None
815+
791816 if inputs is not None :
792817 prompt = inputs
793818 assert prompt is not None and params is not None
@@ -818,7 +843,7 @@ def add_request(
        processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
            "mm_processor_kwargs")

-        self._add_processed_request(
+        return self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
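With these pieces in place, the visible API change is `add_request`'s return value. A minimal usage sketch, assuming a vLLM build that contains this change and a small local model (the model name is just an example):

```python
from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# n == 1 takes the ordinary path and returns the created SequenceGroup.
group = engine.add_request("single", "Hello,", SamplingParams(n=1))
assert group is not None

# n > 1 is fanned out through ParallelSampleSequenceGroup; there is no
# single group to hand back, so the call returns None.
fanned = engine.add_request("parallel", "Hello,", SamplingParams(n=2))
assert fanned is None
```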
@@ -1135,7 +1160,9 @@ def _process_model_outputs(self,
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

@@ -1175,7 +1202,9 @@ def _process_model_outputs(self,
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

@@ -1194,7 +1223,10 @@ def _process_model_outputs(self,
                continue

            request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs,
+            )
            if request_output:
                ctx.request_outputs.append(request_output)

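Passing `self.seq_id_to_seq_group` into `RequestOutputFactory.create` is what lets output assembly fold child results back under the parent request id. Continuing the usage sketch above (same assumptions), the merged result surfaces through the normal step loop:

```python
# Drain the engine; the intent of the change is that the n=2 request
# surfaces as one RequestOutput carrying both completions under the
# parent id "parallel", rather than as two separate child outputs.
while engine.has_unfinished_requests():
    for out in engine.step():
        if out.finished:
            print(out.request_id, [c.text for c in out.outputs])
```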