@@ -44,8 +44,10 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceStatus)
+                           ParallelSampleSequenceGroup, Sequence,
+                           SequenceGroup, SequenceGroupBase,
+                           SequenceGroupMetadata, SequenceGroupOutput,
+                           SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -474,6 +476,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                 ),
             ))
 
+        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
 
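
The new `seq_id_to_seq_group` dict maps the request id of each parallel-sampling child request to the parent `SequenceGroupBase` that owns it, so the output path can route finished children back to their parent for aggregation. A minimal sketch of the registration side, assuming a hypothetical `register_children` helper and a `f"{i}_{parent_id}"` child-id scheme (the real bookkeeping lives in `ParallelSampleSequenceGroup.add_request` and may differ):

```python
from typing import Dict


def register_children(seq_id_to_seq_group: Dict[str, "SequenceGroupBase"],
                      parent: "SequenceGroupBase", n: int) -> None:
    # Each of the n samples gets its own child request; mapping the child
    # id back to the parent lets _process_model_outputs hand finished
    # children to the parent for aggregation.
    for i in range(n):
        child_id = f"{i}_{parent.request_id}"  # assumed id scheme
        seq_id_to_seq_group[child_id] = parent
```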
@@ -648,7 +652,10 @@ def _add_processed_request(
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> SequenceGroup:
+        """Add a processed request to the engine's request pool.
+        Return the created sequence group.
+        """
         self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
@@ -701,6 +708,8 @@ def _add_processed_request(
         min_cost_scheduler = self.scheduler[costs.index(min(costs))]
         min_cost_scheduler.add_seq_group(seq_group)
 
+        return seq_group
+
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
 
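
`_add_processed_request` now hands the freshly created `SequenceGroup` back to its caller instead of returning `None`; the parallel-sampling path presumably keeps these handles so the parent knows which children are still outstanding. A sketch of that caller-side pattern, where `submit_child` and the `to_be_finished` attribute are assumed names, not necessarily vLLM's:

```python
import time


def submit_child(engine, parent, child_id: str, processed_inputs,
                 child_params) -> None:
    # The returned handle is what makes the new return type useful: the
    # parent can later check seq_group.is_finished() for each child.
    seq_group = engine._add_processed_request(
        request_id=child_id,
        processed_inputs=processed_inputs,
        params=child_params,
        arrival_time=time.time(),
        lora_request=None,
        prompt_adapter_request=None,
    )
    parent.to_be_finished[child_id] = seq_group  # assumed attribute
```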
@@ -754,7 +763,7 @@ def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...
 
     @overload
@@ -768,7 +777,7 @@ def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...
 
     @deprecate_kwargs(
@@ -787,7 +796,7 @@ def add_request(
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         """Add a request to the engine's request pool.
 
         The request is added to the request pool and will be processed by the
@@ -831,6 +840,22 @@ def add_request(
         >>> # continue the request processing
         >>> ...
         """
+
+        if isinstance(params, SamplingParams) and params.n > 1:
+            ParallelSampleSequenceGroup.add_request(
+                request_id,
+                self,
+                params,
+                prompt=prompt,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
+                inputs=inputs,
+            )
+            return None
+
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
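
The early `return None` is the crux of the refactor: `n > 1` requests never reach the normal processing path below. Presumably `ParallelSampleSequenceGroup.add_request` re-enters `engine.add_request` once per sample with a single-sample copy of the params, along these lines (a sketch under that assumption, not the actual implementation):

```python
from copy import deepcopy


def fan_out(engine, request_id: str, params, **kwargs) -> None:
    # Split one n-sample request into n ordinary n=1 child requests.
    for i in range(params.n):
        child_params = deepcopy(params)
        child_params.n = 1
        engine.add_request(f"{i}_{request_id}",  # assumed child-id scheme
                           params=child_params,
                           **kwargs)
```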
@@ -865,7 +890,7 @@ def add_request(
             params=params,
             lora_request=lora_request)
 
-        self._add_processed_request(
+        return self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,
             params=processed_params,
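
From the caller's perspective, `add_request` now either returns the scheduled `SequenceGroup` (single-sample path, threaded through from `_add_processed_request` above) or `None` (parallel-sampling path). A small usage sketch against the public engine API, with `engine` assumed to be an already-constructed instance:

```python
from typing import Optional

from vllm import LLMEngine, SamplingParams
from vllm.sequence import SequenceGroup


def demo(engine: LLMEngine) -> Optional[SequenceGroup]:
    # n == 1: the created group comes back directly.
    group = engine.add_request("req-0", "Hello", SamplingParams())
    # n > 1: the request is fanned out internally, so None is returned.
    assert engine.add_request("req-1", "Hello", SamplingParams(n=4)) is None
    return group
```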
@@ -1182,7 +1207,9 @@ def _process_model_outputs(self,
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
@@ -1222,7 +1249,9 @@ def _process_model_outputs(self,
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
@@ -1241,7 +1270,10 @@ def _process_model_outputs(self,
                 continue
 
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs,
+            )
             if request_output:
                 ctx.request_outputs.append(request_output)
 
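
All three call sites now pass the engine's `seq_id_to_seq_group` map into `RequestOutputFactory.create`, which lets the factory treat children of a parallel-sampling request specially. A hypothetical sketch of the dispatch this enables (the real logic lives in `vllm.outputs` and `SequenceGroupBase`; the `maybe_assemble_group` helper name is an assumption):

```python
from typing import Dict, Optional


def create(seq_group, seq_id_to_seq_group: Dict[str, "SequenceGroupBase"],
           use_cache: bool = False) -> Optional["RequestOutput"]:
    parent = seq_id_to_seq_group.get(seq_group.request_id)
    if parent is not None:
        # Child of an n > 1 request: let the parent decide whether all
        # children are done and a combined output can be emitted yet.
        return parent.maybe_assemble_group(seq_group)  # assumed helper
    # Ordinary request: build the RequestOutput directly (omitted here).
    ...
```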