diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index fba18f197074..24b1c9a93126 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -76,11 +76,6 @@ def test_models( model_executor: str, enable_prompt_embeds: bool, ) -> None: - - if enable_prompt_embeds and envs.is_set( - "VLLM_USE_V1") and envs.VLLM_USE_V1: - pytest.skip("enable_prompt_embeds is not supported in v1.") - if not envs.VLLM_USE_V1: if async_scheduling: pytest.skip("async_scheduling only supported in v1.") @@ -164,11 +159,6 @@ def test_models_distributed( extra_env: dict[str, str], enable_prompt_embeds: bool, ) -> None: - - if enable_prompt_embeds and envs.is_set( - "VLLM_USE_V1") and envs.VLLM_USE_V1: - pytest.skip("enable_prompt_embeds is not supported in v1.") - if test_suite != TARGET_TEST_SUITE: pytest.skip(f"Skip test for {test_suite}") diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index 3d56291bc793..0e3fc82f0c03 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -36,7 +36,6 @@ def default_server_args() -> list[str]: "--enforce-eager", # Prompt Embeds server args "--enable-prompt-embeds", - "--no-enable-chunked-prefill", ] diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index a5aa1e3f4974..c14e71cbdb96 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -125,12 +125,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, # in parts of the operators pytest.skip(f"Skipping '{model}' model test with AITER kernel.") - # Note: can be removed when - # https://github.com/vllm-project/vllm/pull/24278 finished - if current_platform.is_cpu() and use_prompt_embeds: - pytest.skip("Skipping use_prompt_embeds=True with " - "V1-only CPU backend.") - with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fb5beab77b27..63282c425350 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1513,12 +1513,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No text embedding inputs so far. - if self.enable_prompt_embeds: - _raise_or_fallback(feature_name="--enable-prompt-embeds", - recommend_to_remove=False) - return False - # No Mamba or Encoder-Decoder so far. if not model_config.is_v1_compatible: _raise_or_fallback(feature_name=model_config.architectures, @@ -1651,6 +1645,13 @@ def _set_default_args_v0(self, model_config: ModelConfig) -> None: "models in V0 and has been disabled.") self.enable_prefix_caching = False + if self.enable_prompt_embeds: + logger.warning( + "--enable-prompt-embeds and --enable-prefix-caching " + "are not supported together in V0. Prefix caching has " + "been disabled.") + self.enable_prefix_caching = False + # Set max_num_seqs to 256 for VLLM_V0. 
         if self.max_num_seqs is None:
             self.max_num_seqs = 256
 
@@ -1664,6 +1665,17 @@ def _set_default_args_v1(self, usage_context: UsageContext,
         # For pooling tasks the default is False
         if model_config.runner_type != "pooling":
             self.enable_chunked_prefill = True
+
+            # TODO: When prefix caching supports prompt embeds inputs, this
+            # check can be removed.
+            if (self.enable_prompt_embeds
+                    and self.enable_prefix_caching is not False):
+                logger.warning(
+                    "--enable-prompt-embeds and --enable-prefix-caching "
+                    "are not supported together in V1. Prefix caching has "
+                    "been disabled.")
+                self.enable_prefix_caching = False
+
             if self.enable_prefix_caching is None:
                 self.enable_prefix_caching = True
         else:
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 7ad8e73d89d5..6b54511a66f3 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -973,7 +973,6 @@ class CompletionRequest(OpenAIBaseModel):
     # https://platform.openai.com/docs/api-reference/completions/create
     model: Optional[str] = None
     prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
-    prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
     best_of: Optional[int] = None
     echo: Optional[bool] = False
     frequency_penalty: Optional[float] = 0.0
@@ -1009,6 +1008,7 @@ class CompletionRequest(OpenAIBaseModel):
 
     # --8<-- [end:completion-sampling-params]
     # --8<-- [start:completion-extra-params]
+    prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
     add_special_tokens: bool = Field(
         default=True,
         description=(
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index f13381ecd9ff..d4013a69e99f 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -3443,3 +3443,30 @@ def decorate_logs(process_name: Optional[str] = None) -> None:
     pid = os.getpid()
     _add_prefix(sys.stdout, process_name, pid)
     _add_prefix(sys.stderr, process_name, pid)
+
+
+def length_from_prompt_token_ids_or_embeds(
+    prompt_token_ids: Optional[list[int]],
+    prompt_embeds: Optional[torch.Tensor],
+) -> int:
+    """Calculate the request length (in number of tokens) given either
+    prompt_token_ids or prompt_embeds.
+    """
+    prompt_token_len = None if prompt_token_ids is None else len(
+        prompt_token_ids)
+    prompt_embeds_len = \
+        None if prompt_embeds is None else len(prompt_embeds)
+
+    if prompt_token_len is None:
+        if prompt_embeds_len is None:
+            raise ValueError(
+                "Neither prompt_token_ids nor prompt_embeds were defined.")
+        return prompt_embeds_len
+    else:
+        if (prompt_embeds_len is not None
+                and prompt_embeds_len != prompt_token_len):
+            raise ValueError(
+                "Prompt token ids and prompt embeds had different lengths:"
+                f" prompt_token_ids={prompt_token_len}"
+                f" prompt_embeds={prompt_embeds_len}")
+        return prompt_token_len
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 3ec5b91bf286..209fc2a4404f 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -11,6 +11,7 @@
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
+    import torch
 
     from vllm.distributed.kv_transfer.kv_connector.v1.base import (
         KVConnectorMetadata)
@@ -26,13 +27,14 @@ class NewRequestData:
 
     req_id: str
-    prompt_token_ids: list[int]
+    prompt_token_ids: Optional[list[int]]
     mm_features: list[MultiModalFeatureSpec]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
     block_ids: tuple[list[int], ...]
     num_computed_tokens: int
     lora_request: Optional[LoRARequest]
+    prompt_embeds: Optional[torch.Tensor] = None
 
     @classmethod
     def from_request(
@@ -49,9 +51,12 @@ def from_request(
             block_ids=block_ids,
             num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
+            prompt_embeds=request.prompt_embeds,
         )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
+        prompt_embeds_shape = (self.prompt_embeds.shape
+                               if self.prompt_embeds is not None else None)
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
                 f"prompt_token_ids={self.prompt_token_ids},"
@@ -59,19 +64,26 @@ def __repr__(self):
                 f"sampling_params={self.sampling_params},"
                 f"block_ids={self.block_ids},"
                 f"num_computed_tokens={self.num_computed_tokens},"
-                f"lora_request={self.lora_request}"
+                f"lora_request={self.lora_request},"
+                f"prompt_embeds_shape={prompt_embeds_shape}"
                 ")")
 
     # Version of __repr__ with the prompt data obfuscated
-    def anon_repr(self):
+    def anon_repr(self) -> str:
+        prompt_token_ids_len = len(
+            self.prompt_token_ids
+        ) if self.prompt_token_ids is not None else None
+        prompt_embeds_shape = (self.prompt_embeds.shape
+                               if self.prompt_embeds is not None else None)
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
-                f"prompt_token_ids_len={len(self.prompt_token_ids)},"
+                f"prompt_token_ids_len={prompt_token_ids_len},"
                 f"mm_features={self.mm_features},"
                 f"sampling_params={self.sampling_params},"
                 f"block_ids={self.block_ids},"
                 f"num_computed_tokens={self.num_computed_tokens},"
-                f"lora_request={self.lora_request}"
+                f"lora_request={self.lora_request},"
+                f"prompt_embeds_shape={prompt_embeds_shape}"
                 ")")
 
 
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index dec4abec519b..345f5a464c2c 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -47,7 +47,7 @@ class EngineCoreRequest(
         gc=False):  # type: ignore[call-arg]
 
     request_id: str
-    prompt_token_ids: list[int]
+    prompt_token_ids: Optional[list[int]]
     mm_features: Optional[list[MultiModalFeatureSpec]]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
@@ -56,6 +56,7 @@ class EngineCoreRequest(
     lora_request: Optional[LoRARequest]
     cache_salt: Optional[str]
     data_parallel_rank: Optional[int]
+    prompt_embeds: Optional[torch.Tensor] = None
 
     # Index of the client, used to ensure outputs are sent back to the same
     # client for this request when scaling out the front-end.
diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py
index cf4b06db843b..8aa36d6a439c 100644
--- a/vllm/v1/engine/detokenizer.py
+++ b/vllm/v1/engine/detokenizer.py
@@ -13,6 +13,7 @@
 from vllm.logger import init_logger
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
+from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.v1.engine import EngineCoreRequest
 
 logger = init_logger(__name__)
@@ -179,11 +180,12 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast,
         self.tokenizer: Tokenizer = tokenizer._tokenizer
 
         # Find a safe place to start.
-        prompt_suffix = request.prompt_token_ids
+        prompt_token_ids = request.prompt_token_ids or []
+        prompt_suffix = prompt_token_ids
         prompt_len = len(prompt_suffix)
         if prompt_len > 4:
             for i in range(4, min(prompt_len + 1, 24)):
-                suffix = request.prompt_token_ids[-i:]
+                suffix = prompt_token_ids[-i:]
                 if '�' not in self.tokenizer.decode(suffix):
                     prompt_suffix = suffix
                     break
@@ -260,16 +262,25 @@ def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
         params = request.sampling_params
         assert params is not None
 
+        self.prompt_len = length_from_prompt_token_ids_or_embeds(
+            request.prompt_token_ids, request.prompt_embeds)
+
         # Metadata for incremental detokenization.
-        self.tokens, self.prefix_offset, self.read_offset = (
-            convert_prompt_ids_to_tokens(
-                tokenizer=tokenizer,
-                prompt_ids=request.prompt_token_ids,
-                skip_special_tokens=params.skip_special_tokens,
-            ))
+        if request.prompt_token_ids is not None:
+            self.tokens, self.prefix_offset, self.read_offset = (
+                convert_prompt_ids_to_tokens(
+                    tokenizer=tokenizer,
+                    prompt_ids=request.prompt_token_ids,
+                    skip_special_tokens=params.skip_special_tokens,
+                ))
+        else:
+            # Prompt embedding requests cannot be detokenized, in general.
+            self.tokens = [""] * self.prompt_len
+            self.prefix_offset = 0
+            self.read_offset = 0
 
-        self.token_ids.extend(request.prompt_token_ids)
-        self.prompt_len = len(request.prompt_token_ids)
+        self.token_ids.extend(request.prompt_token_ids
+                              or [0] * self.prompt_len)
 
         self.skip_special_tokens = params.skip_special_tokens
         self.spaces_between_special_tokens = (
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 5dad63988daa..c17dc3e204ec 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -14,6 +14,7 @@
 from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
                           extract_trace_context)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor
@@ -86,7 +87,8 @@ def __init__(
         lora_name: Optional[str],
         output_kind: RequestOutputKind,
         prompt: Optional[str],
-        prompt_token_ids: list[int],
+        prompt_token_ids: Optional[list[int]],
+        prompt_embeds: Optional[torch.Tensor],
         logprobs_processor: Optional[LogprobsProcessor],
         detokenizer: Optional[IncrementalDetokenizer],
         max_tokens_param: Optional[int],
@@ -104,7 +106,9 @@ def __init__(
         self.output_kind = output_kind
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
-        self.prompt_len = len(prompt_token_ids)
+        self.prompt_embeds = prompt_embeds
+        self.prompt_len = length_from_prompt_token_ids_or_embeds(
+            self.prompt_token_ids, self.prompt_embeds)
         self.logprobs_processor = logprobs_processor
         self.detokenizer = detokenizer
         self.max_tokens_param = max_tokens_param
@@ -165,6 +169,7 @@ def from_new_request(
             output_kind=output_kind,
             prompt=prompt,
             prompt_token_ids=request.prompt_token_ids,
+            prompt_embeds=request.prompt_embeds,
             logprobs_processor=logprobs_processor,
             detokenizer=detokenizer,
             max_tokens_param=max_tokens_param,
@@ -223,6 +228,8 @@ def _new_request_output(
         first_output = outputs[0]
         if isinstance(first_output, PoolingOutput):
             assert len(outputs) == 1
+            # Prompt embeddings are currently not supported by pooling requests.
+ assert self.prompt_token_ids is not None return PoolingRequestOutput( request_id=request_id, outputs=first_output, @@ -236,10 +243,15 @@ def _new_request_output( else: prompt_logprobs = self.logprobs_processor.prompt_logprobs + # If prompt embeds were used, put placeholder prompt token ids + prompt_token_ids = self.prompt_token_ids + if prompt_token_ids is None and self.prompt_embeds is not None: + prompt_token_ids = [0] * len(self.prompt_embeds) + return RequestOutput( request_id=request_id, prompt=self.prompt, - prompt_token_ids=self.prompt_token_ids, + prompt_token_ids=prompt_token_ids, prompt_logprobs=prompt_logprobs, outputs=cast(list[CompletionOutput], outputs), finished=finished, @@ -469,6 +481,8 @@ def do_tracing(self, engine_core_output: EngineCoreOutput, arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9) trace_context = extract_trace_context(engine_core_output.trace_headers) + prompt_length = length_from_prompt_token_ids_or_embeds( + req_state.prompt_token_ids, req_state.prompt_embeds) with (self.tracer.start_as_current_span( "llm_request", kind=SpanKind.SERVER, @@ -488,7 +502,7 @@ def do_tracing(self, engine_core_output: EngineCoreOutput, span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time) span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, - len(req_state.prompt_token_ids)) + prompt_length) span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, metrics.num_generation_tokens) span.set_attribute( @@ -544,7 +558,8 @@ def _update_stats_from_finished(self, req_state: RequestState, assert req_state.stats is not None iteration_stats.update_from_finished_request( finish_reason=finish_reason, - num_prompt_tokens=len(req_state.prompt_token_ids), + num_prompt_tokens=length_from_prompt_token_ids_or_embeds( + req_state.prompt_token_ids, req_state.prompt_embeds), max_tokens_param=req_state.max_tokens_param, req_stats=req_state.stats) self.lora_states.finish_request(req_state) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 71f539583a1b..507e2cd3223f 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -19,6 +19,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) @@ -390,6 +391,16 @@ def process_inputs( self._validate_model_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + # Mypy does not always properly infer the types of some elements of + # discriminated unions of TypedDicts, because of how it handles + # inheritance of TypedDict. If we explicitly extract the items we want + # we can avoid type errors from using `dict.get` later in the method. + prompt_str: Optional[str] = None if decoder_inputs[ + "type"] == "embeds" else decoder_inputs.get("prompt") + prompt_token_ids = decoder_inputs[ + "prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None + prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[ + "type"] == "embeds" else None sampling_params = None pooling_params = None @@ -398,9 +409,10 @@ def process_inputs( sampling_params = params.clone() # If unset max tokens, then generate up to the max_model_len. 
if sampling_params.max_tokens is None: - sampling_params.max_tokens = ( - self.model_config.max_model_len - - len(decoder_inputs["prompt_token_ids"])) + seq_len = length_from_prompt_token_ids_or_embeds( + prompt_token_ids, prompt_embeds) + sampling_params.max_tokens = \ + self.model_config.max_model_len - seq_len sampling_params.update_from_generation_config( self.generation_config_fields, eos_token_id) if self.tokenizer is not None: @@ -430,9 +442,10 @@ def process_inputs( identifier=decoder_mm_hashes[modality][idx], mm_position=decoder_mm_positions[modality][idx])) - return decoder_inputs.get("prompt"), EngineCoreRequest( + return prompt_str, EngineCoreRequest( request_id=request_id, - prompt_token_ids=decoder_inputs["prompt_token_ids"], + prompt_token_ids=prompt_token_ids, + prompt_embeds=prompt_embeds, mm_features=mm_features, sampling_params=sampling_params, pooling_params=pooling_params, @@ -461,10 +474,17 @@ def _validate_model_input( ): model_config = self.model_config - prompt_ids = prompt_inputs["prompt_token_ids"] + prompt_ids = None if prompt_inputs[ + "type"] == "embeds" else prompt_inputs["prompt_token_ids"] + prompt_embeds = prompt_inputs["prompt_embeds"] if prompt_inputs[ + "type"] == "embeds" else None + prompt_len = length_from_prompt_token_ids_or_embeds( + prompt_ids, prompt_embeds) if not prompt_ids: if prompt_type == "encoder" and model_config.is_multimodal_model: pass # Mllama may have empty encoder inputs for text-only data + elif prompt_inputs["type"] == "embeds": + pass # Prompt embeds should not have prompt_ids. else: raise ValueError(f"The {prompt_type} prompt cannot be empty") @@ -472,7 +492,7 @@ def _validate_model_input( tokenizer = None else: tokenizer = self.tokenizer - max_input_id = max(prompt_ids, default=0) + max_input_id = max(prompt_ids or [], default=0) # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while # self.model_config.get_vocab_size() is the model’s vocab size. @@ -490,7 +510,7 @@ def _validate_model_input( f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len - if len(prompt_ids) > max_prompt_len: + if prompt_len > max_prompt_len: if prompt_type == "encoder" and model_config.is_multimodal_model: mm_registry = self.input_preprocessor.mm_registry mm_processor = mm_registry.create_processor( @@ -514,7 +534,7 @@ def _validate_model_input( "number of text tokens.") raise ValueError( - f"The {prompt_type} prompt (length {len(prompt_ids)}) is " + f"The {prompt_type} prompt (length {prompt_len}) is " f"longer than the maximum model length of {max_prompt_len}. 
" f"{suggestion}") diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 145af788d237..ff10fa00c1cf 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -7,9 +7,12 @@ from functools import partial from typing import TYPE_CHECKING, Any, Callable, Optional, Union +import torch + from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams +from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreRequest, FinishReason) from vllm.v1.structured_output.request import StructuredOutputRequest @@ -25,12 +28,13 @@ class Request: def __init__( self, request_id: str, - prompt_token_ids: list[int], + prompt_token_ids: Optional[list[int]], sampling_params: Optional[SamplingParams], pooling_params: Optional[PoolingParams], eos_token_id: Optional[int], client_index: int = 0, arrival_time: Optional[float] = None, + prompt_embeds: Optional[torch.Tensor] = None, mm_features: Optional[list[MultiModalFeatureSpec]] = None, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, @@ -79,9 +83,13 @@ def __init__( "sampling_params and pooling_params can't both be unset") self.prompt_token_ids = prompt_token_ids - self.num_prompt_tokens = len(self.prompt_token_ids) + self.prompt_embeds = prompt_embeds + self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + prompt_token_ids, prompt_embeds) self._output_token_ids: list[int] = [] - self._all_token_ids: list[int] = self.prompt_token_ids.copy() + self._all_token_ids: list[int] = self.prompt_token_ids.copy( + ) if self.prompt_token_ids is not None else [0 + ] * self.num_prompt_tokens self.num_output_placeholders = 0 # Used in async scheduling. 
self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 @@ -123,6 +131,7 @@ def from_engine_core_request( request_id=request.request_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, + prompt_embeds=request.prompt_embeds, mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index df944873bcaf..10cad5b53071 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -243,7 +243,7 @@ def new_req_logits_processor( def _new_state( self, params: SamplingParams, - prompt_ids: list[int], + prompt_ids: Optional[list[int]], output_ids: list[int], ) -> Optional[partial[torch.Tensor]]: """Return state representation for new request diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 60f9c0bdb631..fc655d993cb4 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -187,7 +187,8 @@ def is_argmax_invariant(self) -> bool: @staticmethod def add_request( - params: SamplingParams, _: list[int], output_tok_ids: list[int] + params: SamplingParams, _: Optional[list[int]], + output_tok_ids: list[int] ) -> Optional[tuple[int, Sequence[int], set[int]]]: min_tokens = params.min_tokens if not min_tokens or len(output_tok_ids) >= min_tokens: @@ -234,7 +235,8 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: def process_dict_updates( req_entries: dict[int, T], batch_update: Optional[BatchUpdate], - new_state: Callable[[SamplingParams, list[int], list[int]], Optional[T]] + new_state: Callable[[SamplingParams, Optional[list[int]], list[int]], + Optional[T]] ) -> bool: """Utility function to update dict state for sparse LogitsProcessors.""" diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py index 04027359909a..a84afc2f347a 100644 --- a/vllm/v1/sample/logits_processor/interface.py +++ b/vllm/v1/sample/logits_processor/interface.py @@ -26,7 +26,7 @@ class MoveDirectionality(Enum): # (index, params, prompt_tok_ids, output_tok_ids) tuples for new # requests added to the batch. -AddedRequest = tuple[int, SamplingParams, list[int], list[int]] +AddedRequest = tuple[int, SamplingParams, Optional[list[int]], list[int]] # (index 1, index 2, directionality) tuples representing # one-way moves or two-way swaps of requests in batch diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index c8375d6f1551..50c1470c67ed 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -174,7 +174,7 @@ def _encode_tensor( ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None # view the tensor as a contiguous 1D array of bytes - arr = obj.flatten().contiguous().view(torch.uint8).numpy() + arr = obj.flatten().contiguous().cpu().view(torch.uint8).numpy() if obj.nbytes < self.size_threshold: # Smaller tensors are encoded inline, just like ndarrays. 
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 6717622efb80..79a392337574 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -13,7 +13,7 @@ from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import swap_dict_values +from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, @@ -29,7 +29,7 @@ class CachedRequestState: req_id: str - prompt_token_ids: list[int] + prompt_token_ids: Optional[list[int]] mm_features: list[MultiModalFeatureSpec] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] @@ -43,9 +43,11 @@ class CachedRequestState: mrope_position_delta: Optional[int] = None lora_request: Optional[LoRARequest] = None + prompt_embeds: Optional[torch.Tensor] = None def __post_init__(self): - self.num_prompt_tokens = len(self.prompt_token_ids) + self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + self.prompt_token_ids, self.prompt_embeds) @property def num_tokens(self) -> int: @@ -63,6 +65,10 @@ def mm_inputs(self) -> list[MultiModalKwargsItems]: def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: + if self.prompt_token_ids is None: + raise ValueError( + f"Tried to access token index {idx}, but that token was " + "provided via prompt_embeds, and its ID is unknown.") return self.prompt_token_ids[idx] elif idx - self.num_prompt_tokens < len(self.output_token_ids): return self.output_token_ids[idx - self.num_prompt_tokens] @@ -109,6 +115,14 @@ def __init__( pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.is_token_ids = torch.zeros((max_num_reqs, max_model_len), + device="cpu", + dtype=bool, + pin_memory=False) + # Store prompt embeddings per request to avoid OOM from large upfront + # allocation if max_model_len is big. + # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size) + self.req_prompt_embeds: dict[int, torch.Tensor] = {} self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) @@ -310,15 +324,23 @@ def add_request( self.req_id_to_index[req_id] = req_index # Copy the prompt token ids and output token ids. - num_prompt_tokens = len(request.prompt_token_ids) + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + request.prompt_token_ids, request.prompt_embeds) self.num_prompt_tokens[req_index] = num_prompt_tokens - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) + if request.prompt_token_ids is not None: + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + self.is_token_ids[req_index, :num_prompt_tokens] = True + else: + self.is_token_ids[req_index, :num_prompt_tokens] = False + if request.prompt_embeds is not None: + self.req_prompt_embeds[req_index] = request.prompt_embeds self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids - # Number of token ids in token_ids_cpu. 
+ self.is_token_ids[req_index, start_idx:end_idx] = True + # Number of token ids in prompt (token_ids_cpu or prompt_embeds). # NOTE(woosuk): This may include spec decode tokens. self.num_tokens[req_index] = request.num_tokens # Number of tokens without spec decode tokens. @@ -503,6 +525,20 @@ def swap_states(self, i1: int, i2: int) -> None: self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp + self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...] + + # Swap prompt embeddings if they exist + embeds_i1 = self.req_prompt_embeds.get(i1) + embeds_i2 = self.req_prompt_embeds.get(i2) + if embeds_i1 is not None: + self.req_prompt_embeds[i2] = embeds_i1 + else: + self.req_prompt_embeds.pop(i2, None) + if embeds_i2 is not None: + self.req_prompt_embeds[i1] = embeds_i2 + else: + self.req_prompt_embeds.pop(i1, None) + self.block_table.swap_row(i1, i2) self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \ @@ -592,6 +628,11 @@ def condense(self) -> None: num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ last_req_index, :num_tokens] + self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[ + last_req_index, :num_tokens] + if last_req_index in self.req_prompt_embeds: + self.req_prompt_embeds[ + empty_index] = self.req_prompt_embeds.pop(last_req_index) self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ last_req_index] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4873b586724e..93e94f2fd20c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -55,7 +55,9 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, check_use_alibi, get_dtype_size, - is_pin_memory_available, round_up, supports_dynamo) + is_pin_memory_available, + length_from_prompt_token_ids_or_embeds, round_up, + supports_dynamo) from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( @@ -196,6 +198,7 @@ def __init__( cache_config.cache_dtype] self.is_pooling_model = (model_config.runner_type == 'pooling') + self.enable_prompt_embeds = model_config.enable_prompt_embeds self.is_multimodal_raw_input_only_model = ( model_config.is_multimodal_raw_input_only_model) @@ -341,6 +344,8 @@ def __init__( self.hidden_size, dtype=self.dtype, numpy=False) + self.is_token_ids = self._make_buffer(self.max_num_tokens, + dtype=torch.bool) self.discard_request_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int64) self.num_discarded_requests = 0 @@ -573,6 +578,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, + prompt_embeds=new_req_data.prompt_embeds, mm_features=new_req_data.mm_features, sampling_params=sampling_params, pooling_params=pooling_params, @@ -806,6 +812,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, if self.input_batch.prev_sampled_token_ids is None: # Normal scheduling case self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) return # Async scheduling case, where some decode requests 
from the previous @@ -831,6 +839,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, # If not all requests are decodes from the last iteration, # We need to copy the input_ids_cpu to the GPU first. self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) if num_commmon_tokens == 0: # No requests in common with the previous iteration # So input_ids_cpu will have all the input ids. @@ -844,6 +854,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], non_blocking=True) + self.is_token_ids.gpu[:num_commmon_tokens] = True return # Upload the index tensors asynchronously # so the scatter can be non-blocking. @@ -934,14 +945,60 @@ def _prepare_inputs( # where M is the max_model_len. token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices_tensor = torch.from_numpy(token_indices) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, - torch.from_numpy(token_indices), + token_indices_tensor, out=self.input_ids.cpu[:total_num_scheduled_tokens]) + is_token_ids = self.input_batch.is_token_ids.flatten() + torch.index_select( + is_token_ids, + 0, + token_indices_tensor, + out=self.is_token_ids.cpu[:total_num_scheduled_tokens]) + + # Because we did not pre-allocate a massive prompt_embeds CPU tensor on + # the InputBatch, we need to fill in the prompt embeds into the expected + # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor. + if self.input_batch.req_prompt_embeds: + output_idx = 0 + for req_idx in range(num_reqs): + num_sched = num_scheduled_tokens[req_idx] + + # Skip if this request doesn't have embeddings + if req_idx not in self.input_batch.req_prompt_embeds: + output_idx += num_sched + continue + + # Skip if no tokens scheduled + if num_sched <= 0: + output_idx += num_sched + continue + + req_embeds = self.input_batch.req_prompt_embeds[req_idx] + start_pos = self.input_batch.num_computed_tokens_cpu[req_idx] + + # Skip if trying to read beyond available embeddings + if start_pos >= req_embeds.shape[0]: + output_idx += num_sched + continue + + # Copy available embeddings + end_pos = start_pos + num_sched + actual_end = min(end_pos, req_embeds.shape[0]) + actual_num_sched = actual_end - start_pos + + if actual_num_sched > 0: + self.inputs_embeds.cpu[output_idx:output_idx + + actual_num_sched].copy_( + req_embeds[start_pos:actual_end] + ) + + output_idx += num_sched self.input_batch.block_table.compute_slot_mapping( req_indices, positions_np) @@ -1266,7 +1323,8 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): self.input_batch.num_computed_tokens_cpu[index] num_scheduled_tokens = \ scheduler_output.num_scheduled_tokens[req_id] - num_prompt_tokens = len(req.prompt_token_ids) + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + req.prompt_token_ids, req.prompt_embeds) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: prompt_part_len = max(0, @@ -1832,6 +1890,32 @@ def _preprocess( **self._init_model_kwargs(num_scheduled_tokens), **self._extract_mm_kwargs(scheduler_output), } + elif (self.enable_prompt_embeds and get_pp_group().is_first_rank): + # Get the input embeddings for the tokens that are not input embeds, + # then put 
them into the appropriate positions. + # TODO(qthequartermasterman): Since even when prompt embeds are + # enabled, (a) not all requests will use prompt embeds, and (b) + # after the initial prompt is processed, the rest of the generated + # tokens will be token ids, it is not desirable to have the + # embedding layer outside of the CUDA graph all the time. The v0 + # engine avoids this by "double compiling" the CUDA graph, once + # with input_ids and again with inputs_embeds, for all num_tokens. + # If a batch only has token ids, then including the embedding layer + # in the CUDA graph will be more performant (like in the else case + # below). + token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens] \ + .nonzero(as_tuple=False) \ + .squeeze(1) + # Some tokens ids may need to become embeds + if token_ids_idx.numel() > 0: + token_ids = self.input_ids.gpu[token_ids_idx] + tokens_to_embeds = self.model.get_input_embeddings( + input_ids=token_ids) + self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds + + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + model_kwargs = self._init_model_kwargs(num_input_tokens) + input_ids = None else: # For text-only models, we use token ids as input. # While it is possible to use embeddings as input just like the @@ -2010,6 +2094,7 @@ def _bookkeeping_sync( self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx @@ -2557,6 +2642,10 @@ def _get_prompt_logprobs_dict( # Get metadata for this request. request = self.requests[req_id] + if request.prompt_token_ids is None: + # Prompt logprobs is incompatible with prompt embeddings + continue + num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( self.device, non_blocking=True) @@ -2909,6 +2998,10 @@ def _dummy_run( **model_kwargs, **self._dummy_mm_kwargs(num_reqs), } + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens] + model_kwargs = self._init_model_kwargs(num_tokens) else: input_ids = self.input_ids.gpu[:num_tokens] inputs_embeds = None diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index dfa54d0ad83b..4cd0ac352de0 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -9,7 +9,7 @@ from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingType -from vllm.utils import swap_dict_values +from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.worker.block_table import MultiGroupBlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState @@ -213,7 +213,9 @@ def add_request( self.req_id_to_index[req_id] = req_index # Copy the prompt token ids and output token ids. 
- num_prompt_tokens = len(request.prompt_token_ids) + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + request.prompt_token_ids, request.prompt_embeds) + # TODO: copy prompt_embeds self.num_prompt_tokens[req_index] = num_prompt_tokens self.token_ids_cpu[ req_index, :num_prompt_tokens] = request.prompt_token_ids diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 43f12912707f..01a8e5c3f0db 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -387,6 +387,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, + prompt_embeds=new_req_data.prompt_embeds, mm_features=new_req_data.mm_features, sampling_params=sampling_params, pooling_params=None,
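
Below is a small, self-contained sketch (not part of the diff) illustrating the contract of the new `length_from_prompt_token_ids_or_embeds` helper added to `vllm/utils/__init__.py`, which the V1 engine now uses wherever a request may carry prompt token ids, prompt embeddings, or both; the tensor shape and hidden size used here are arbitrary.

```python
# Illustrative usage of the helper added in this diff. Assumes vLLM is
# installed; the hidden size below is an arbitrary example value.
import torch

from vllm.utils import length_from_prompt_token_ids_or_embeds

hidden_size = 4096  # hypothetical hidden size
prompt_embeds = torch.randn(5, hidden_size)

# Token-id-only and embeds-only requests report the same notion of length.
assert length_from_prompt_token_ids_or_embeds([1, 2, 3], None) == 3
assert length_from_prompt_token_ids_or_embeds(None, prompt_embeds) == 5

# When both are provided they must agree on length; a mismatch raises
# ValueError, as does passing neither.
assert length_from_prompt_token_ids_or_embeds([0] * 5, prompt_embeds) == 5
```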