From f70647b17f9409fee75a7c3ce095fed7a3a1a964 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 2 Oct 2024 04:10:08 -0400 Subject: [PATCH 01/18] Add mm_processor_kwargs to LLMInputs Signed-off-by: Alex-Brooks --- vllm/inputs/data.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index dfbcf9526487..e6f671177ab7 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,5 +1,5 @@ -from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple, - Union) +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, + Optional, Tuple, Union) from typing_extensions import NotRequired, TypedDict, TypeVar @@ -121,6 +121,14 @@ class LLMInputs(TypedDict): if the model supports it. """ + mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + class EncoderDecoderLLMInputs(LLMInputs): """ From 40a6cf0c8ee7a3f8dc436b2fe5eab55106276e3a Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 2 Oct 2024 04:10:51 -0400 Subject: [PATCH 02/18] Add (unexposed) processor kwarg resolution for inp proc Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 590ff54aea56..9b69c89ca765 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -293,8 +293,8 @@ def process_input(self, model_config: "ModelConfig", model_cls, _ = get_model_architecture(model_config) processor = self._get_model_input_processor(model_cls) - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - processor, overrides=model_config.mm_processor_kwargs) + mm_processor_kwargs = self._resolve_processor_kwargs( + inputs, processor, model_config) return processor(InputContext(model_config), inputs, **mm_processor_kwargs) @@ -305,3 +305,21 @@ def create_input_processor(self, model_config: "ModelConfig"): specific model. """ return functools.partial(self.process_input, model_config) + + @staticmethod + def _resolve_processor_kwargs( + inputs: LLMInputs, processor: InputProcessor, + model_config: "ModelConfig") -> Dict[str, Any]: + + # Filter inference time multimodal processor kwargs provided + runtime_mm_kwargs = get_allowed_kwarg_only_overrides( + processor, overrides=inputs.get("mm_processor_kwargs")) + + # Filter init time multimodal processor kwargs provided + init_mm_kwargs = get_allowed_kwarg_only_overrides( + processor, overrides=model_config.mm_processor_kwargs) + + # Merge the final processor kwargs, prioritizing inference + # time values over the initialization time values. 
+ mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs} + return mm_processor_kwargs From bbec69b294c8abddf8bfbe0e9bc8bff473dc9bb9 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 2 Oct 2024 06:10:34 -0400 Subject: [PATCH 03/18] Push runtime processor kwargs through preproc Signed-off-by: Alex-Brooks --- vllm/inputs/data.py | 16 +++++++++ vllm/inputs/preprocess.py | 76 +++++++++++++++++++++++++++++++-------- vllm/inputs/registry.py | 3 ++ 3 files changed, 80 insertions(+), 15 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index e6f671177ab7..79f9fc45c1a2 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -19,6 +19,14 @@ class TextPrompt(TypedDict): if the model supports it. """ + mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + class TokensPrompt(TypedDict): """Schema for a tokenized prompt.""" @@ -32,6 +40,14 @@ class TokensPrompt(TypedDict): if the model supports it. """ + mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + SingletonPrompt = Union[str, TextPrompt, TokensPrompt] """ diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index d4474a10f542..97c6168ff20d 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing_extensions import assert_never @@ -20,9 +20,11 @@ logger = init_logger(__name__) PromptComponents = Tuple[Optional[str], List[int], - Optional["MultiModalDataDict"]] + Optional["MultiModalDataDict"], Optional[Dict[str, + Any]]] DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], - Optional["MultiModalDataDict"]] + Optional["MultiModalDataDict"], + Optional[Dict[str, Any]]] class InputPreprocessor: @@ -227,6 +229,7 @@ def _extract_prompt_components( * prompt * prompt_token_ids * multi_modal_data + * mm_processor_kwargs (request-level input processor/mapper overrides) ''' parsed = parse_singleton_prompt(prompt) @@ -239,10 +242,12 @@ def _extract_prompt_components( lora_request=lora_request, ) multi_modal_data = None + mm_processor_kwargs = None elif parsed["type"] == "tokens": prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") elif parsed["type"] == "text": prompt_text = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( @@ -251,10 +256,12 @@ def _extract_prompt_components( lora_request=lora_request, ) multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return (prompt_text, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) async def _extract_prompt_components_async( self, @@ -273,10 +280,12 @@ async def 
_extract_prompt_components_async( lora_request=lora_request, ) multi_modal_data = None + mm_processor_kwargs = None elif parsed["type"] == "tokens": prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") elif parsed["type"] == "text": prompt_text = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( @@ -285,18 +294,21 @@ async def _extract_prompt_components_async( lora_request=lora_request, ) multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return (prompt_text, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) def _build_enc_dec_llm_inputs( self, encoder_comps: PromptComponents, decoder_comps: DecoderPromptComponents, + mm_processor_kwargs: Dict[str, Any], ) -> EncoderDecoderLLMInputs: - encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps + encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps + decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps if decoder_mm_data is not None: raise ValueError( @@ -314,6 +326,7 @@ def _build_enc_dec_llm_inputs( prompt_token_ids=decoder_prompt_ids, prompt=decoder_prompt, multi_modal_data=decoder_mm_data, + mm_processor_kwargs=mm_processor_kwargs, encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt=encoder_prompt, encoder_multi_modal_data=encoder_mm_data, @@ -367,7 +380,7 @@ def _process_encoder_decoder_prompt( ) if (decoder_input := prompt["decoder_prompt"]) is None: - decoder_comps = None, None, None + decoder_comps = None, None, None, None else: decoder_comps = self._extract_prompt_components( decoder_input, @@ -379,9 +392,16 @@ def _process_encoder_decoder_prompt( request_id=request_id, ) - decoder_comps = None, None, None + decoder_comps = None, None, None, None - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + mm_processor_kwargs = self._get_encoder_decoder_processor_kwargs( + prompt, encoder_comps, decoder_comps) + + return self._build_enc_dec_llm_inputs( + encoder_comps, + decoder_comps, + mm_processor_kwargs, + ) async def _process_encoder_decoder_prompt_async( self, @@ -400,7 +420,7 @@ async def _process_encoder_decoder_prompt_async( if (decoder_input := prompt["decoder_prompt"]) is None: encoder_comps = await encoder_task - decoder_comps = None, None, None + decoder_comps = None, None, None, None else: decoder_task = self._extract_prompt_components_async( decoder_input, @@ -415,23 +435,49 @@ async def _process_encoder_decoder_prompt_async( request_id=request_id, ) - decoder_comps = None, None, None + decoder_comps = None, None, None, None + + mm_processor_kwargs = self._get_encoder_decoder_processor_kwargs( + prompt, encoder_comps, decoder_comps) + return self._build_enc_dec_llm_inputs( + encoder_comps, + decoder_comps, + mm_processor_kwargs, + ) - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + @staticmethod + def _get_encoder_decoder_processor_kwargs(prompt, encoder_comps, + decoder_comps): + mm_processor_kwargs = prompt.get("mm_processor_kwargs") + # Because of the common logic and types with decoder-only models, + # mm_processor_kwargs can technically be passed to individual prompts; + # users should instead pass these at 
the top level of the prompt, since + # the mm mapper/processor are component agnostic. + enc_kwargs = encoder_comps[-1] + dec_kwargs = decoder_comps[-1] + if enc_kwargs is not None or dec_kwargs is not None: + logger.warning( + "mm_processor_kwargs are encoder / decoder agnostic ", + "and should be passed as a top-level option for ", + "explicit encoder / decoder prompts; the provided ", + "values will not be used.") + return mm_processor_kwargs def _build_decoder_only_llm_inputs( self, prompt_comps: PromptComponents, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> LLMInputs: - prompt, prompt_token_ids, multi_modal_data = prompt_comps + (prompt, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) = prompt_comps prompt_token_ids = self._apply_prompt_adapter( prompt_token_ids, prompt_adapter_request=prompt_adapter_request) return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs) def _process_decoder_only_prompt( self, diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 9b69c89ca765..6bd790f29e5a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -293,6 +293,9 @@ def process_input(self, model_config: "ModelConfig", model_cls, _ = get_model_architecture(model_config) processor = self._get_model_input_processor(model_cls) + # Handle multimodal processor kwargs with priority: + # Inference kwargs -> Init kwargs -> {} + # If it's empty, it'll fall back to the default kwarg values mm_processor_kwargs = self._resolve_processor_kwargs( inputs, processor, model_config) From 7977ae88e44f00e58c1165565675b4ce99dbdd01 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 2 Oct 2024 06:11:04 -0400 Subject: [PATCH 04/18] Add mm processor kwargs to offline chat Signed-off-by: Alex-Brooks --- vllm/entrypoints/llm.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 439f3769f9fb..07e417c5ea40 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -503,6 +503,7 @@ def chat( add_generation_prompt: bool = True, continue_final_message: bool = False, tools: Optional[List[Dict[str, Any]]] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> List[RequestOutput]: """ Generate responses for a chat conversation. @@ -532,6 +533,8 @@ def chat( continue_final_message: If True, continues the final message in the conversation instead of starting a new one. Cannot be `True` if `add_generation_prompt` is also `True`. + mm_processor_kwargs: Multimodal processor kwarg overrides for this + chat request. Only used for offline requests. Returns: A list of ``RequestOutput`` objects containing the generated @@ -553,6 +556,9 @@ def chat( tokenizer = self.get_tokenizer() model_config = self.llm_engine.get_model_config() + # NOTE: _parse_chat_message_content_parts() currently doesn't + # handle mm_processor_kwargs, since there is no implementation in + # the chat message parsing for it. 
conversation, mm_data = parse_chat_messages( msgs, model_config, tokenizer) @@ -585,6 +591,9 @@ def chat( if mm_data is not None: prompt["multi_modal_data"] = mm_data + if mm_processor_kwargs is not None: + prompt["mm_processor_kwargs"] = mm_processor_kwargs + prompts.append(prompt) return self.generate( From f132873272ff64869282faffcf64d6d7be4c6795 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 2 Oct 2024 17:54:01 -0400 Subject: [PATCH 05/18] Forward mm processor kwargs through sequence info -> mapper Signed-off-by: Alex-Brooks --- vllm/core/scheduler.py | 1 + vllm/engine/llm_engine.py | 7 +++++++ vllm/inputs/registry.py | 23 +++-------------------- vllm/multimodal/base.py | 21 ++++++++++++++++----- vllm/multimodal/registry.py | 13 +++++++++---- vllm/sequence.py | 14 ++++++++++++++ vllm/utils.py | 17 +++++++++++++++++ vllm/worker/cpu_model_runner.py | 8 +++++--- vllm/worker/model_runner.py | 4 +++- vllm/worker/neuron_model_runner.py | 5 ++++- vllm/worker/openvino_model_runner.py | 6 +++++- 11 files changed, 84 insertions(+), 35 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c57e6cd71640..1c7eaac91c17 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1308,6 +1308,7 @@ def schedule( # `multi_modal_data` will be None. multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, + mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) else: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6372d4b5d211..510ffac6f689 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -811,6 +811,13 @@ def add_request( ) processed_inputs = self.input_processor(preprocessed_inputs) + # This is a bit of a hack - copy the mm_processor_kwargs that were + # used in the input processor to the processed output, since these + # kwargs are presumed to be immutable and the values should be aligned + # between the input processor (here) and the input mapper. + processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get( + "mm_processor_kwargs") + self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 6bd790f29e5a..36018bb6301d 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -9,7 +9,8 @@ from typing_extensions import TypeVar from vllm.logger import init_logger -from vllm.utils import get_allowed_kwarg_only_overrides, print_warning_once +from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, + resolve_mm_processor_kwargs) from .data import LLMInputs @@ -296,7 +297,7 @@ def process_input(self, model_config: "ModelConfig", # Handle multimodal processor kwargs with priority: # Inference kwargs -> Init kwargs -> {} # If it's empty, it'll fall back to the default kwarg values - mm_processor_kwargs = self._resolve_processor_kwargs( + mm_processor_kwargs = resolve_mm_processor_kwargs( inputs, processor, model_config) return processor(InputContext(model_config), inputs, @@ -308,21 +309,3 @@ def create_input_processor(self, model_config: "ModelConfig"): specific model. 
""" return functools.partial(self.process_input, model_config) - - @staticmethod - def _resolve_processor_kwargs( - inputs: LLMInputs, processor: InputProcessor, - model_config: "ModelConfig") -> Dict[str, Any]: - - # Filter inference time multimodal processor kwargs provided - runtime_mm_kwargs = get_allowed_kwarg_only_overrides( - processor, overrides=inputs.get("mm_processor_kwargs")) - - # Filter init time multimodal processor kwargs provided - init_mm_kwargs = get_allowed_kwarg_only_overrides( - processor, overrides=model_config.mm_processor_kwargs) - - # Merge the final processor kwargs, prioritizing inference - # time values over the initialization time values. - mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs} - return mm_processor_kwargs diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 8bcb38ef241e..b976768c20b7 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,7 +1,7 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type, +from typing import (Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypedDict, TypeVar, Union, cast, final) import numpy as np @@ -243,7 +243,8 @@ def wrapper(model_cls: N) -> N: return wrapper def map_input(self, model_config: ModelConfig, - data: MultiModalData[object]) -> MultiModalInputs: + data: MultiModalData[object], + mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs: """ Transform the data into a dictionary of model inputs using the input mapper registered for that model. @@ -263,9 +264,19 @@ def map_input(self, model_config: ModelConfig, model_cls, _ = get_model_architecture(model_config) mapper = self._input_mappers.get(model_cls) - # Only get processor kwargs at mapping time if we are not using the - # input mapper; no overrides are used on the default here because they - # should be passed to the huggingface resource at initialization time. + + # There's a nasty edge-case here if the default mapper is being used + # and the underlying huggingface resource has init time kwargs that + # do not line up with its inference time kwargs - we probably need to + # warn with a fallback that rebuilds a mapper based on the model_cls + # to reinitialize the HF resource, otherwise things are pretty likely + # to crash with cryptic errors, like placeholder mismatches, from + # correctly handing it in the input processor and not in the input + # mapper. 
+        if mm_processor_kwargs is not None:
+            raise NotImplementedError(
+                "TODO - need to implement runtime processor kwarg merging")
+
         if mapper is not None and mapper != self._default_input_mapper:
             mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                 mapper, overrides=model_config.mm_processor_kwargs)
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 3940e1671b57..5e9b8bd518de 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -1,6 +1,6 @@
 import functools
 from collections import UserDict
-from typing import Dict, Mapping, Optional, Sequence
+from typing import Any, Dict, Mapping, Optional, Sequence
 
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
@@ -96,8 +96,12 @@ def register_image_input_mapper(
         """
         return self.register_input_mapper("image", mapper)
 
-    def map_input(self, model_config: ModelConfig,
-                  data: MultiModalDataDict) -> MultiModalInputs:
+    def map_input(
+        self,
+        model_config: ModelConfig,
+        data: MultiModalDataDict,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> MultiModalInputs:
         """
         Apply an input mapper to the data passed to the model.
 
@@ -123,7 +127,8 @@ def map_input(self, model_config: ModelConfig,
                     f"`--limit-mm-per-prompt`, but found {num_items} items "
                     "in the same prompt.")
 
-            input_dict = plugin.map_input(model_config, data_value)
+            input_dict = plugin.map_input(model_config, data_value,
+                                          mm_processor_kwargs)
             for input_key, input_tensor in input_dict.items():
                 if input_key in merged_dict:
                     raise ValueError(f"The input mappers (keys={set(data)}) "
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 9116408a001f..0c27ffca36cf 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -481,6 +481,10 @@ def multi_modal_data(self) -> "MultiModalDataDict":
                 EncoderDecoderLLMInputs,
                 inputs).get("encoder_multi_modal_data")) or {}
 
+    @property
+    def mm_processor_kwargs(self) -> Dict[str, Any]:
+        return self.inputs.get("mm_processor_kwargs") or {}
+
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
@@ -710,6 +714,14 @@ def multi_modal_data(self) -> "MultiModalDataDict":
         # We use the multi-modal data of an arbitrary sequence.
         return self.seqs[0].multi_modal_data
 
+    @property
+    def mm_processor_kwargs(self) -> Dict[str, Any]:
+        # As with multi-modal data, all sequences in the group should have the
+        # same processor kwargs (i.e., mm_processor_kwargs are optionally
+        # provided per request; note that they are independent of whether the
+        # model is decoder-only or an encoder-decoder).
+        return self.seqs[0].mm_processor_kwargs
+
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
@@ -949,6 +961,7 @@ class SequenceGroupMetadata(
             used in prefix caching.
         state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
+        mm_processor_kwargs: Multimodal input processor / mapper overrides.
         encoder_seq_data: Optional sequence data for encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
@@ -975,6 +988,7 @@ class SequenceGroupMetadata(
     # "MultiModalDataDict" types. We have to use Any due to msgspec
     # doesn't allow to have union of 2 different dicts.
multi_modal_data: Optional[Any] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None encoder_seq_data: Optional[SequenceData] = None cross_block_table: Optional[List[int]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None diff --git a/vllm/utils.py b/vllm/utils.py index 9c6f1a347fb8..9fd58bf521b5 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1286,6 +1286,23 @@ def supports_kw(callable: Callable[..., object], kw_name: str) -> bool: for param in params.values()) +def resolve_mm_processor_kwargs(inputs, processor, + model_config) -> Dict[str, Any]: + + # Filter inference time multimodal processor kwargs provided + runtime_mm_kwargs = get_allowed_kwarg_only_overrides( + processor, overrides=inputs.get("mm_processor_kwargs")) + + # Filter init time multimodal processor kwargs provided + init_mm_kwargs = get_allowed_kwarg_only_overrides( + processor, overrides=model_config.mm_processor_kwargs) + + # Merge the final processor kwargs, prioritizing inference + # time values over the initialization time values. + mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs} + return mm_processor_kwargs + + def get_allowed_kwarg_only_overrides( callable: Callable[..., object], overrides: Optional[Dict[str, Any]], diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index a03c56253217..f67b08679641 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -148,8 +148,9 @@ def build(self) -> ModelInputForCPU: ) def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, - computed_len: int): - mm_kwargs = self.multi_modal_input_mapper(mm_data) + computed_len: int, + mm_processor_kwargs: Dict[str, Any]): + mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs) # special processing for mrope position deltas. mrope_positions = None @@ -210,7 +211,8 @@ def _prepare_prompt( mrope_positions = None if (mm_data := seq_group_metadata.multi_modal_data): mm_kwargs, mrope_positions = self._compute_multi_modal_input( - seq_data, mm_data, computed_len) + seq_data, mm_data, computed_len, + seq_group_metadata.mm_processor_kwargs) multi_modal_inputs_list.append(mm_kwargs) # Token position ids diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 978443884198..0bd295881671 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -640,7 +640,9 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, if not mm_data: return - mm_kwargs = self.multi_modal_input_mapper(mm_data) + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs) inter_data.multi_modal_inputs = mm_kwargs # special processing for mrope position deltas. 
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 44d4845a838e..b8c760c4b539 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -153,7 +153,10 @@ def _prepare_prompt( mm_data = seq_group_metadata.multi_modal_data if mm_data: # Process multi-modal data - mm_kwargs = self.multi_modal_input_mapper(mm_data) + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs, + ) multi_modal_inputs_list.append(mm_kwargs) max_seq_len = max(seq_lens) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 77ee2eadf29a..de3088695dfe 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -172,7 +172,11 @@ def _prepare_model_input( mm_data = seq_group_metadata.multi_modal_data if mm_data: - mm_kwargs = self.multi_modal_input_mapper(mm_data) + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata. + mm_processor_kwargs, + ) multi_modal_inputs_list.append(mm_kwargs) block_table = seq_group_metadata.block_tables[seq_id] From 7c0644f7bf6ee536dd0703c9639c8eb59272bfe4 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 6 Oct 2024 06:43:08 -0400 Subject: [PATCH 06/18] Implement dynamic mm processor kwargs in mapper Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 5 ++++- vllm/multimodal/audio.py | 4 ++-- vllm/multimodal/base.py | 36 +++++++++++++++++------------------- vllm/multimodal/image.py | 24 ++++++++++++++++++------ vllm/multimodal/video.py | 24 +++++++++++++++++------- vllm/utils.py | 19 ++++++++++++------- 6 files changed, 70 insertions(+), 42 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 36018bb6301d..5bd3e1c86f66 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -298,7 +298,10 @@ def process_input(self, model_config: "ModelConfig", # Inference kwargs -> Init kwargs -> {} # If it's empty, it'll fall back to the default kwarg values mm_processor_kwargs = resolve_mm_processor_kwargs( - inputs, processor, model_config) + model_config.mm_processor_kwargs, + inputs.get("mm_processor_kwargs"), + processor, + ) return processor(InputContext(model_config), inputs, **mm_processor_kwargs) diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index b4bf4b4541db..04d71826f29f 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -8,8 +8,8 @@ class AudioPlugin(MultiModalPlugin): def get_data_key(self) -> str: return "audio" - def _default_input_mapper(self, ctx: InputContext, - data: object) -> MultiModalInputs: + def _default_input_mapper(self, ctx: InputContext, data: object, + **mm_processor_kwargs) -> MultiModalInputs: raise NotImplementedError("There is no default audio input mapper") def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index b976768c20b7..d36f36a932ae 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -15,7 +15,7 @@ from vllm.inputs import InputContext from vllm.logger import init_logger from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, - json_map_leaves) + json_map_leaves, resolve_mm_processor_kwargs) logger = init_logger(__name__) @@ -200,6 +200,7 @@ def _default_input_mapper( self, ctx: InputContext, data: MultiModalData[object], + **mm_processor_kwargs, ) -> MultiModalInputs: """ Return a dictionary to be passed as 
keyword arguments to @@ -265,28 +266,25 @@ def map_input(self, model_config: ModelConfig, mapper = self._input_mappers.get(model_cls) - # There's a nasty edge-case here if the default mapper is being used - # and the underlying huggingface resource has init time kwargs that - # do not line up with its inference time kwargs - we probably need to - # warn with a fallback that rebuilds a mapper based on the model_cls - # to reinitialize the HF resource, otherwise things are pretty likely - # to crash with cryptic errors, like placeholder mismatches, from - # correctly handing it in the input processor and not in the input - # mapper. - if mm_processor_kwargs is not None: - raise NotImplementedError( - "TODO - need to implement runtime processor kwarg merging") - - if mapper is not None and mapper != self._default_input_mapper: - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - mapper, overrides=model_config.mm_processor_kwargs) - else: - mm_processor_kwargs = {} - if mapper is None: raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") + # In the case of the default mapper, we have to get resource + # processor through its HuggingFace autoclass; since this goes + # through **kwargs, we can't inspect it the same way, so we allow + # drop mm_processor_kwargs based on signature inspection + # if we're using the default mapper. + # + # NOTE: In the future, adding retry logic of some kind might be a + # good idea, especially if this interface is exposed through + # the server somehow. + uses_default_mapper = mapper == self._default_input_mapper + mm_processor_kwargs = resolve_mm_processor_kwargs( + model_config.mm_processor_kwargs, + mm_processor_kwargs, + callable=None if uses_default_mapper else mapper, + ) return mapper(InputContext(model_config), data, **mm_processor_kwargs) @abstractmethod diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 7ca64152e481..5f74bcea65ce 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,4 +1,5 @@ from functools import lru_cache +from typing import Any, Dict, Optional import torch from PIL import Image @@ -23,11 +24,13 @@ class ImagePlugin(MultiModalPlugin): def get_data_key(self) -> str: return "image" - def _get_hf_image_processor(self, model_config: ModelConfig): - mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None - else model_config.mm_processor_kwargs) - # We don't explicitly check kwarg overrides to the HF class - # since the automodel just takes kwargs, so we can't inspect it + def _get_hf_image_processor( + self, + model_config: ModelConfig, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + ): + if mm_processor_kwargs is None: + mm_processor_kwargs = {} return cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, @@ -37,6 +40,7 @@ def _default_input_mapper( self, ctx: InputContext, data: MultiModalData[object], + **mm_processor_kwargs, ) -> MultiModalInputs: model_config = ctx.model_config @@ -46,12 +50,20 @@ def _default_input_mapper( # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): - image_processor = self._get_hf_image_processor(model_config) + image_processor = self._get_hf_image_processor( + model_config, + mm_processor_kwargs, + ) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") try: + # NOTE: It may make sense to forward the mm_processor_kwargs + # here too. 
For now, to keep it simple, we only allow it be + # used for the initialization call though, just in case the + # signatures of the preprocessor initializer don't match + # preprocess() batch_data = image_processor \ .preprocess(data, return_tensors="pt") \ .data diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 39e75dbaf687..4a9dbf20c8ec 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,5 +1,5 @@ from functools import lru_cache -from typing import List, Union +from typing import Any, Dict, List, Optional, Union import numpy as np @@ -36,11 +36,13 @@ class VideoPlugin(ImagePlugin): def get_data_key(self) -> str: return "video" - def _get_hf_video_processor(self, model_config: ModelConfig): - mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None - else model_config.mm_processor_kwargs) - # We don't explicitly check kwarg overrides to the HF class - # since the automodel just takes kwargs, so we can't inspect it + def _get_hf_video_processor( + self, + model_config: ModelConfig, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + ): + if mm_processor_kwargs is None: + mm_processor_kwargs = {} return cached_get_video_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, @@ -50,16 +52,24 @@ def _default_input_mapper( self, ctx: InputContext, data: MultiModalData[object], + **mm_processor_kwargs, ) -> MultiModalInputs: model_config = ctx.model_config # single video input as np.ndarray if isinstance(data, np.ndarray): - video_processor = self._get_hf_video_processor(model_config) + video_processor = self._get_hf_video_processor( + model_config, + mm_processor_kwargs, + ) if video_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") try: + # NOTE: Similar to image; it may be a good idea to filter and + # pass mm_processor_kwargs here too, but for now we don't to + # avoid extra complexity if the initializer and preprocess + # signatures of the processor don't align batch_data = video_processor(data, return_tensors="pt").data except Exception: logger.error("Failed to process image (%s)", data) diff --git a/vllm/utils.py b/vllm/utils.py index 9fd58bf521b5..d4aff1618eb9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1286,16 +1286,15 @@ def supports_kw(callable: Callable[..., object], kw_name: str) -> bool: for param in params.values()) -def resolve_mm_processor_kwargs(inputs, processor, - model_config) -> Dict[str, Any]: - +def resolve_mm_processor_kwargs(init_kwargs, inference_kwargs, + callable) -> Dict[str, Any]: # Filter inference time multimodal processor kwargs provided runtime_mm_kwargs = get_allowed_kwarg_only_overrides( - processor, overrides=inputs.get("mm_processor_kwargs")) + callable, overrides=inference_kwargs) # Filter init time multimodal processor kwargs provided - init_mm_kwargs = get_allowed_kwarg_only_overrides( - processor, overrides=model_config.mm_processor_kwargs) + init_mm_kwargs = get_allowed_kwarg_only_overrides(callable, + overrides=init_kwargs) # Merge the final processor kwargs, prioritizing inference # time values over the initialization time values. 
@@ -1304,7 +1303,7 @@ def resolve_mm_processor_kwargs(inputs, processor, def get_allowed_kwarg_only_overrides( - callable: Callable[..., object], + callable: Optional[Callable[..., object]], overrides: Optional[Dict[str, Any]], ) -> Dict[str, Any]: """ @@ -1317,6 +1316,7 @@ def get_allowed_kwarg_only_overrides( Args: callable: Callable which takes 0 or more keyword only arguments. + If None is provided, all overrides names are allowed. overrides: Potential overrides to be used when invoking the callable. Returns: @@ -1327,6 +1327,11 @@ def get_allowed_kwarg_only_overrides( if not overrides: return {} + # In some situations, the real callable might be wrapped, e.g., the init of + # a class received through a HF auto class. In such cases, allow anything. + if callable is None: + return overrides + allowed_override_names = [ name for name, param in inspect.signature(callable).parameters.items() if param.kind == inspect.Parameter.KEYWORD_ONLY From 12379d45ae017a21243d473b43f0697eca85ff5c Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 6 Oct 2024 06:43:39 -0400 Subject: [PATCH 07/18] Add happy path tests for inference time mm proc kwargs Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor_kwargs.py | 98 ++++++++++++++--------- 1 file changed, 62 insertions(+), 36 deletions(-) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 5529ccd4fa57..85db5f969616 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -74,11 +74,11 @@ def mm_model_cls(): # lambda whose signature matches max token calcs extra & mapper + extra kwargs get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { - "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) + "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) } -### Test for default processor logic & mm_processor_kwargs wrapping +### Tests for default processor logic & mm_processor_kwargs wrapping def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" dummy_registry = InputRegistry() @@ -89,23 +89,39 @@ def test_default_processor_is_a_noop(): assert proc_inputs is proc_outputs -@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) -def test_processor_default_kwargs(use_processor_mock, num_crops): +# @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ + (None, None), + (NUM_CROPS_OVERRIDE, None), + (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), +]) +def test_input_processor_kwargs(use_processor_mock, init_num_crops, + inference_num_crops): """Ensure input processors can use processor kwargs.""" dummy_registry = InputRegistry() # If we have a value for num_crops, pass the override value and make # sure we get that value as a return-value from out mock processor, # otherwise fall back to the default value - mm_processor_kwargs = None if num_crops is None else { - "num_crops": num_crops + init_kwargs = None if init_num_crops is None else { + "num_crops": init_num_crops } - expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops - ctx = build_model_context(DUMMY_MODEL_ID, - mm_processor_kwargs=mm_processor_kwargs) + inference_kwargs = None if inference_num_crops is None else { + "num_crops": inference_num_crops + } + if inference_num_crops is not None: + expected_seq_count = inference_num_crops + elif 
init_num_crops is not None: + expected_seq_count = init_num_crops + else: + expected_seq_count = DEFAULT_NUM_CROPS + + ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) - - num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) - assert num_crops_val == expected_num_crops + num_crops_val = processor( + LLMInputs(prompt_token_ids=[], + prompt="", + mm_processor_kwargs=inference_kwargs)) + assert num_crops_val == expected_seq_count @pytest.mark.parametrize( @@ -271,32 +287,45 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1 -@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) -def test_custom_mapper_kwarg_overrides(image_assets, num_crops): +@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ + (None, None), + (NUM_CROPS_OVERRIDE, None), + (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), +]) +def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, + inference_num_crops): """Ensure custom mappers can use processor kwargs.""" - mm_processor_kwargs = None if num_crops is None else { - "num_crops": num_crops + init_kwargs = None if init_num_crops is None else { + "num_crops": init_num_crops } - expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + inference_kwargs = None if inference_num_crops is None else { + "num_crops": inference_num_crops + } + # Priority: inference -> init -> model config + if inference_num_crops is not None: + expected_seq_count = inference_num_crops + elif init_num_crops is not None: + expected_seq_count = init_num_crops + else: + expected_seq_count = DEFAULT_NUM_CROPS + ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - mm_processor_kwargs=mm_processor_kwargs, + mm_processor_kwargs=init_kwargs, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) - # Patch the image registry for phi3v with our lambda that is compatible - # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. image = image_assets[0].pil_image mm_inputs = {"image": image} - with patch.object( - mm_registry._get_plugin("image"), - "_default_input_mapper", - {mm_model_cls(): custom_mapper}, - ): - mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the mm_processor_kwargs. + mm_registry._get_plugin("image").register_input_mapper(custom_mapper)( + mm_model_cls()) + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs, + inference_kwargs) assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1 @@ -323,17 +352,14 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets, mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) - # Patch the image registry for phi3v with our lambda that is compatible - # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. 
image = image_assets[0].pil_image mm_inputs = {"image": image} - with patch.object( - mm_registry._get_plugin("image"), - "_default_input_mapper", - {mm_model_cls(): custom_mapper}, - ): - mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the mm_processor_kwargs. + mm_registry._get_plugin("image").register_input_mapper(custom_mapper)( + mm_model_cls()) + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 From b35d18a750253f5c92b5fcdb2762847a5284dd3a Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 6 Oct 2024 06:50:50 -0400 Subject: [PATCH 08/18] Add inference time kwargs to sad path tests Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor_kwargs.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 85db5f969616..76e600969547 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -140,11 +140,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, mm_processor_kwargs): """Ensure that input processors filter out invalid mm_processor_kwargs""" dummy_registry = InputRegistry() + # Should filter out the init time kwargs ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) - num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) + # Should filter out the inference time kwargs + num_crops_val = processor( + LLMInputs(prompt_token_ids=[], + prompt="", + mm_processor_kwargs=mm_processor_kwargs)) assert num_crops_val == DEFAULT_NUM_CROPS @@ -345,6 +350,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, def test_custom_mapper_with_sad_kwarg_overrides(image_assets, mm_processor_kwargs): """Ensure that custom mappers filters out invalid mm_processor_kwargs""" + # Should filter out the init time kwargs ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, mm_processor_kwargs=mm_processor_kwargs, @@ -360,6 +366,8 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets, # our num_crops value back from the mm_processor_kwargs. 
mm_registry._get_plugin("image").register_input_mapper(custom_mapper)( mm_model_cls()) - mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) + # Should filter out the inference time kwargs + mapped_inputs = mm_registry.map_input( + ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs) assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 From e71a3f47505aa933a589b799a9e203c2bb5ea4eb Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 6 Oct 2024 14:48:59 -0400 Subject: [PATCH 09/18] Update example to add comment for per request processor kwargs Signed-off-by: Alex-Brooks --- examples/offline_inference_vision_language.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index b94ef537d783..b5232b05e2b9 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -105,6 +105,7 @@ def run_phi3v(question, modality): trust_remote_code=True, max_model_len=4096, max_num_seqs=2, + # Note - mm_processor_kwargs can also be passed to generate/chat calls mm_processor_kwargs={"num_crops": 16}, ) stop_token_ids = None From 00fbdb8183da08789c3d0a65bffd1745f3beb4a7 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 7 Oct 2024 00:24:12 -0400 Subject: [PATCH 10/18] Removed old test parametrize Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor_kwargs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 76e600969547..f5fad69374eb 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -89,7 +89,6 @@ def test_default_processor_is_a_noop(): assert proc_inputs is proc_outputs -# @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) @pytest.mark.parametrize("init_num_crops,inference_num_crops", [ (None, None), (NUM_CROPS_OVERRIDE, None), From d977ce27ec3c591995f58e3ba323e51ac1b3e322 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 7 Oct 2024 00:34:25 -0400 Subject: [PATCH 11/18] Refactor mm processor kwarg tests Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor_kwargs.py | 41 +++++++++++------------ 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index f5fad69374eb..efc6903c373b 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -89,15 +89,8 @@ def test_default_processor_is_a_noop(): assert proc_inputs is proc_outputs -@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ - (None, None), - (NUM_CROPS_OVERRIDE, None), - (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), -]) -def test_input_processor_kwargs(use_processor_mock, init_num_crops, - inference_num_crops): - """Ensure input processors can use processor kwargs.""" - dummy_registry = InputRegistry() +def _get_num_crops_info(init_num_crops: int, inference_num_crops: int): + """Get the init / inference kwargs and expected num_crops for this test.""" # If we have a value for num_crops, pass the override value and make # sure we get that value as a return-value from out mock processor, # otherwise fall back to the default value @@ -113,6 +106,21 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops, expected_seq_count = init_num_crops else: expected_seq_count = DEFAULT_NUM_CROPS + return init_kwargs, 
inference_kwargs, expected_seq_count + + +@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ + (None, None), + (NUM_CROPS_OVERRIDE, None), + (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), +]) +def test_input_processor_kwargs(use_processor_mock, init_num_crops, + inference_num_crops): + """Ensure input processors can use processor kwargs.""" + dummy_registry = InputRegistry() + + init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info( + init_num_crops, inference_num_crops) ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) @@ -299,19 +307,8 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, inference_num_crops): """Ensure custom mappers can use processor kwargs.""" - init_kwargs = None if init_num_crops is None else { - "num_crops": init_num_crops - } - inference_kwargs = None if inference_num_crops is None else { - "num_crops": inference_num_crops - } - # Priority: inference -> init -> model config - if inference_num_crops is not None: - expected_seq_count = inference_num_crops - elif init_num_crops is not None: - expected_seq_count = init_num_crops - else: - expected_seq_count = DEFAULT_NUM_CROPS + init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info( + init_num_crops, inference_num_crops) ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, From 03f0ea011e76823ea345b0ca26411268c40ad458 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 7 Oct 2024 00:52:50 -0400 Subject: [PATCH 12/18] Add missing docstring to mm proc kwarg merger Signed-off-by: Alex-Brooks --- vllm/utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index d4aff1618eb9..b8fbd907a118 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1286,8 +1286,19 @@ def supports_kw(callable: Callable[..., object], kw_name: str) -> bool: for param in params.values()) -def resolve_mm_processor_kwargs(init_kwargs, inference_kwargs, - callable) -> Dict[str, Any]: +def resolve_mm_processor_kwargs( + init_kwargs: Optional[Dict[str, Any]], + inference_kwargs: Optional[Dict[str, Any]], + callable: Optional[Callable], +) -> Dict[str, Any]: + """Applies filtering to eliminate invalid mm_processor_kwargs, i.e., + those who are not explicit keywords to the given callable (of one is + given; otherwise no filtering is done), then merges the kwarg dicts, + giving priority to inference_kwargs if there are any collisions. + + In the case that no kwarg overrides are provided, returns an empty + dict so that it can still be kwarg expanded into the callable later on. + """ # Filter inference time multimodal processor kwargs provided runtime_mm_kwargs = get_allowed_kwarg_only_overrides( callable, overrides=inference_kwargs) From 21ac0822b3262172eb436764f568945393b004b4 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 7 Oct 2024 04:40:57 -0400 Subject: [PATCH 13/18] Add mm processor kwargs to zip utils for enc/dec Signed-off-by: Alex-Brooks --- vllm/inputs/data.py | 43 +++++++++++++++++++++++++++++++-------- vllm/inputs/preprocess.py | 35 +++++++++---------------------- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 79f9fc45c1a2..724cdd2e6e80 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -19,7 +19,7 @@ class TextPrompt(TypedDict): if the model supports it. 
""" - mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] + mm_processor_kwargs: NotRequired[Dict[str, Any]] """ Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. Note that if multiple modalities @@ -40,7 +40,7 @@ class TokensPrompt(TypedDict): if the model supports it. """ - mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] + mm_processor_kwargs: NotRequired[Dict[str, Any]] """ Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. Note that if multiple modalities @@ -90,7 +90,9 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): according to any of the :class:`SingletonPrompt` schemas, and are not required to have the same schema. - Only the encoder prompt may have multi-modal data. + Only the encoder prompt may have multi-modal data. mm_processor_kwargs + should be at the top-level, and should not be set in the encoder/decoder + prompts, since they are agnostic to the encoder/decoder. Note that an :class:`ExplicitEncoderDecoderPrompt` may not be used as an input to a decoder-only model, @@ -103,6 +105,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] + mm_processor_kwargs: NotRequired[Dict[str, Any]] + PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] """ @@ -176,22 +180,43 @@ class EncoderDecoderLLMInputs(LLMInputs): def build_explicit_enc_dec_prompt( encoder_prompt: _T1, decoder_prompt: Optional[_T2], + mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: - return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, - decoder_prompt=decoder_prompt) + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + return ExplicitEncoderDecoderPrompt( + encoder_prompt=encoder_prompt, + decoder_prompt=decoder_prompt, + mm_processor_kwargs=mm_processor_kwargs) def zip_enc_dec_prompts( enc_prompts: Iterable[_T1], dec_prompts: Iterable[Optional[_T2]], + mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]], + Dict[str, Any]]] = None, ) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]: """ Zip encoder and decoder prompts together into a list of - :class:`ExplicitEncoderDecoderPrompt` instances. - """ + :class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs + may also be provided; if a dict is passed, the same dictionary will be + used for every encoder/decoder prompt. If an iterable is provided, it will + be zipped with the encoder/decoder prompts. 
+ """ + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + if isinstance(mm_processor_kwargs, Dict): + return [ + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, + mm_processor_kwargs) + for (encoder_prompt, + decoder_prompt) in zip(enc_prompts, dec_prompts) + ] return [ - build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt) - for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts) + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, + mm_proc_kwargs) + for (encoder_prompt, decoder_prompt, mm_proc_kwargs + ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs) ] diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 97c6168ff20d..07a5bbadee6c 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -386,17 +386,18 @@ def _process_encoder_decoder_prompt( decoder_input, request_id=request_id, ) + mm_processor_kwargs = prompt["mm_processor_kwargs"] else: encoder_comps = self._extract_prompt_components( prompt, request_id=request_id, ) - + # If there are no decoder components, we assume the + # mm_processor_kwargs are in the encoder prompt + mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ + -1] is not None else {} decoder_comps = None, None, None, None - mm_processor_kwargs = self._get_encoder_decoder_processor_kwargs( - prompt, encoder_comps, decoder_comps) - return self._build_enc_dec_llm_inputs( encoder_comps, decoder_comps, @@ -429,40 +430,24 @@ async def _process_encoder_decoder_prompt_async( encoder_comps, decoder_comps = await asyncio.gather( encoder_task, decoder_task) + mm_processor_kwargs = prompt["mm_processor_kwargs"] else: encoder_comps = await self._extract_prompt_components_async( prompt, request_id=request_id, ) - + # If there are no decoder components, we assume the + # mm_processor_kwargs are in the encoder prompt + mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ + -1] is not None else {} decoder_comps = None, None, None, None - mm_processor_kwargs = self._get_encoder_decoder_processor_kwargs( - prompt, encoder_comps, decoder_comps) return self._build_enc_dec_llm_inputs( encoder_comps, decoder_comps, mm_processor_kwargs, ) - @staticmethod - def _get_encoder_decoder_processor_kwargs(prompt, encoder_comps, - decoder_comps): - mm_processor_kwargs = prompt.get("mm_processor_kwargs") - # Because of the common logic and types with decoder-only models, - # mm_processor_kwargs can technically be passed to individual prompts; - # users should instead pass these at the top level of the prompt, since - # the mm mapper/processor are component agnostic. 
- enc_kwargs = encoder_comps[-1] - dec_kwargs = decoder_comps[-1] - if enc_kwargs is not None or dec_kwargs is not None: - logger.warning( - "mm_processor_kwargs are encoder / decoder agnostic ", - "and should be passed as a top-level option for ", - "explicit encoder / decoder prompts; the provided ", - "values will not be used.") - return mm_processor_kwargs - def _build_decoder_only_llm_inputs( self, prompt_comps: PromptComponents, From 2bea3636d4c7a7f3aca7283990b9b0bcea0fec94 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 7 Oct 2024 04:59:31 -0400 Subject: [PATCH 14/18] Add test for zip_enc_dec_prompts Signed-off-by: Alex-Brooks --- tests/test_inputs.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 3725d8687f25..0093e3179e66 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -2,6 +2,7 @@ import pytest +from vllm.inputs import zip_enc_dec_prompts from vllm.inputs.parse import parse_and_batch_prompt STRING_INPUTS = [ @@ -51,3 +52,40 @@ def test_parse_single_batch_token_consistent(token_input: List[int]): def test_parse_single_batch_string_slice(inputs_slice: slice): assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \ == parse_and_batch_prompt(STRING_INPUTS[inputs_slice]) + + +@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [ + (None, [{}, {}]), + ({}, [{}, {}]), + ({ + "foo": 100 + }, [{ + "foo": 100 + }, { + "foo": 100 + }]), + ([{ + "foo": 100 + }, { + "bar": 200 + }], [{ + "foo": 100 + }, { + "bar": 200 + }]), +]) +def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): + """Test mm_processor_kwargs init for zipping enc/dec prompts.""" + encoder_prompts = ['An encoder prompt', 'Another encoder prompt'] + decoder_prompts = ['A decoder prompt', 'Another decoder prompt'] + zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts, + mm_processor_kwargs) + assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts) + for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts, + expected_mm_kwargs, + zipped_prompts): + assert isinstance(zipped, dict) + assert len(zipped.keys()) == 3 + assert zipped['encoder_prompt'] == enc + assert zipped['decoder_prompt'] == dec + assert zipped['mm_processor_kwargs'] == exp_kwargs From 6c71243e30273526e2220b7af0d745b0b6c52fb8 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 7 Oct 2024 06:09:56 -0400 Subject: [PATCH 15/18] Fix keyerror in encoder/decoder prompt preproc Signed-off-by: Alex-Brooks --- vllm/inputs/preprocess.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 07a5bbadee6c..22adb1631d41 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -386,7 +386,8 @@ def _process_encoder_decoder_prompt( decoder_input, request_id=request_id, ) - mm_processor_kwargs = prompt["mm_processor_kwargs"] + # Handle this carefully in case it was directly initialized by user + mm_processor_kwargs = prompt.get("mm_processor_kwargs", {}) else: encoder_comps = self._extract_prompt_components( prompt, From 843702700d90e28d843ec7a4f79be60e472a8ace Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 8 Oct 2024 03:39:47 -0400 Subject: [PATCH 16/18] Disable formatting in zip enc dec prompt tests Signed-off-by: Alex-Brooks --- tests/test_inputs.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/tests/test_inputs.py 
index 0093e3179e66..fff7c5fc0428 100644
--- a/tests/test_inputs.py
+++ b/tests/test_inputs.py
@@ -54,26 +54,14 @@ def test_parse_single_batch_string_slice(inputs_slice: slice):
         == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
 
 
+# yapf: disable
 @pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [
     (None, [{}, {}]),
     ({}, [{}, {}]),
-    ({
-        "foo": 100
-    }, [{
-        "foo": 100
-    }, {
-        "foo": 100
-    }]),
-    ([{
-        "foo": 100
-    }, {
-        "bar": 200
-    }], [{
-        "foo": 100
-    }, {
-        "bar": 200
-    }]),
+    ({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
+    ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
 ])
+# yapf: enable
 def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
     """Test mm_processor_kwargs init for zipping enc/dec prompts."""
     encoder_prompts = ['An encoder prompt', 'Another encoder prompt']

From 3fb3b616f2b33b6339de9fa184f623089bcde7a1 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Tue, 8 Oct 2024 03:49:32 -0400
Subject: [PATCH 17/18] Fix hack for default mapper signature inspection

Signed-off-by: Alex-Brooks
---
 vllm/multimodal/base.py |  8 ++--
 vllm/utils.py           | 82 +++++++++++++++++++++++++++++------------
 2 files changed, 63 insertions(+), 27 deletions(-)

diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index d36f36a932ae..84e71cbf60df 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -276,14 +276,14 @@ def map_input(self, model_config: ModelConfig,
         # drop mm_processor_kwargs based on signature inspection
         # if we're using the default mapper.
         #
-        # NOTE: In the future, adding retry logic of some kind might be a
-        #       good idea, especially if this interface is exposed through
-        #       the server somehow.
+        # This should be safe in general due to the sanitation, since the
+        # transformers resource should filter unused kwargs anyway.
         uses_default_mapper = mapper == self._default_input_mapper
         mm_processor_kwargs = resolve_mm_processor_kwargs(
             model_config.mm_processor_kwargs,
             mm_processor_kwargs,
-            callable=None if uses_default_mapper else mapper,
+            callable=mapper,
+            allow_var_kwargs=uses_default_mapper,
         )

         return mapper(InputContext(model_config), data, **mm_processor_kwargs)
diff --git a/vllm/utils.py b/vllm/utils.py
index b8fbd907a118..98e7c25593cc 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1277,19 +1277,54 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
         return await task(*args, **kwargs)
 
 
-def supports_kw(callable: Callable[..., object], kw_name: str) -> bool:
+def supports_kw(
+    callable: Callable[..., object],
+    kw_name: str,
+    requires_kw_only: bool = False,
+    allow_var_kwargs: bool = True,
+) -> bool:
+    """Check if a keyword is a valid kwarg for a callable; if requires_kw_only
+    is set, kwarg names that can also be passed positionally are disallowed.
+    """
     params = inspect.signature(callable).parameters
-    if kw_name in params:
-        return True
+    if not params:
+        return False
 
-    return any(param.kind == inspect.Parameter.VAR_KEYWORD
-               for param in params.values())
+    param_val = params.get(kw_name)
+
+    # Types where it may be valid, i.e., explicitly defined & nonvariadic
+    passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY,
+                             inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                             inspect.Parameter.KEYWORD_ONLY))
+
+    if param_val:
+        is_sig_param = param_val.kind in passable_kw_types
+        # We want kwargs only, but this is passable as a positional arg
+        if (requires_kw_only and is_sig_param
+                and param_val.kind != inspect.Parameter.KEYWORD_ONLY):
+            return False
+        if ((requires_kw_only
+             and param_val.kind == inspect.Parameter.KEYWORD_ONLY)
+                or (not requires_kw_only and is_sig_param)):
+            return True
+
+    # If we're okay with var-kwargs, it's supported as long as
+    # the kw_name isn't something like *args, **kwargs
+    if allow_var_kwargs:
+        # Get the last param; type is ignored here because params is a proxy
+        # mapping, but it wraps an ordered dict, and they appear in order.
+        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
+        last_param = params[next(reversed(params))]  # type: ignore
+        return (last_param.kind == inspect.Parameter.VAR_KEYWORD
+                and last_param.name != kw_name)
+    return False
 
 
 def resolve_mm_processor_kwargs(
     init_kwargs: Optional[Dict[str, Any]],
     inference_kwargs: Optional[Dict[str, Any]],
-    callable: Optional[Callable],
+    callable: Callable[..., object],
+    allow_var_kwargs: bool = False,
 ) -> Dict[str, Any]:
     """Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
     those who are not explicit keywords to the given callable (of one is
@@ -1298,14 +1333,20 @@ def resolve_mm_processor_kwargs(
     In the case that no kwarg overrides are provided, returns an empty
     dict so that it can still be kwarg expanded into the callable later on.
+
+    If allow_var_kwargs=True, allows for things that can be expanded into
+    kwargs as long as they aren't naming collisions for var_kwargs or
+    potential positional arguments.
     """
     # Filter inference time multimodal processor kwargs provided
     runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
-        callable, overrides=inference_kwargs)
+        callable,
+        overrides=inference_kwargs,
+        allow_var_kwargs=allow_var_kwargs)
 
     # Filter init time multimodal processor kwargs provided
-    init_mm_kwargs = get_allowed_kwarg_only_overrides(callable,
-                                                      overrides=init_kwargs)
+    init_mm_kwargs = get_allowed_kwarg_only_overrides(
+        callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs)
 
     # Merge the final processor kwargs, prioritizing inference
     # time values over the initialization time values.
@@ -1314,8 +1355,9 @@
 
 
 def get_allowed_kwarg_only_overrides(
-    callable: Optional[Callable[..., object]],
+    callable: Callable[..., object],
     overrides: Optional[Dict[str, Any]],
+    allow_var_kwargs: bool = False,
 ) -> Dict[str, Any]:
     """
     Given a callable which has one or more keyword only params and a dict
@@ -1329,6 +1371,7 @@ def get_allowed_kwarg_only_overrides(
         callable: Callable which takes 0 or more keyword only arguments.
                   If None is provided, all overrides names are allowed.
         overrides: Potential overrides to be used when invoking the callable.
+        allow_var_kwargs: Allows overrides that are expandable for var kwargs.
 
     Returns:
         Dictionary containing the kwargs to be leveraged which may be used
@@ -1338,22 +1381,15 @@ def get_allowed_kwarg_only_overrides(
     if not overrides:
         return {}
 
-    # In some situations, the real callable might be wrapped, e.g., the init of
-    # a class received through a HF auto class. In such cases, allow anything.
-    if callable is None:
-        return overrides
-
-    allowed_override_names = [
-        name for name, param in inspect.signature(callable).parameters.items()
-        if param.kind == inspect.Parameter.KEYWORD_ONLY
-    ]
-
-    # Drop any mm_processor_kwargs provided by the user that are
-    # not kwarg names accepted by the provided input processor.
+    # Drop any mm_processor_kwargs provided by the user that
+    # are not kwargs, unless it can fit the var_kwargs param
     filtered_overrides = {
         kwarg_name: val
         for kwarg_name, val in overrides.items()
-        if kwarg_name in allowed_override_names
+        if supports_kw(callable,
+                       kwarg_name,
+                       requires_kw_only=True,
+                       allow_var_kwargs=allow_var_kwargs)
     }
 
     # If anything is dropped, log a warning

From f8ebc9f4fe549c4979abaae4b51d5ee162054959 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Tue, 8 Oct 2024 03:49:48 -0400
Subject: [PATCH 18/18] Add supports_kw tests

Signed-off-by: Alex-Brooks
---
 tests/test_utils.py | 32 +++++++++++++++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index f3017a8582ea..268e6f8194ab 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -7,7 +7,7 @@
 import pytest
 
 from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
-                        get_open_port, merge_async_iterators)
+                        get_open_port, merge_async_iterators, supports_kw)
 
 from .utils import error_on_warning
 
@@ -236,3 +236,33 @@ def test_no_model_tag(parser_with_config):
     with pytest.raises(ValueError):
         parser_with_config.parse_args(
             ['serve', '--config', './data/test_config.yaml'])
+
+
+# yapf: enable
+@pytest.mark.parametrize(
+    "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
+    [
+        # Tests for positional argument support
+        (lambda foo: None, "foo", True, True, False),
+        (lambda foo: None, "foo", False, True, True),
+        # Tests for positional or keyword / keyword only
+        (lambda foo=100: None, "foo", True, True, False),
+        (lambda *, foo: None, "foo", False, True, True),
+        # Tests to make sure the names of variadic params are NOT supported
+        (lambda *args: None, "args", False, True, False),
+        (lambda **kwargs: None, "kwargs", False, True, False),
+        # Tests for if we allow var kwargs to add support
+        (lambda foo: None, "something_else", False, True, False),
+        (lambda foo, **kwargs: None, "something_else", False, True, True),
+        (lambda foo, **kwargs: None, "kwargs", True, True, False),
+        (lambda foo, **kwargs: None, "foo", True, True, False),
+    ])
+# yapf: disable
+def test_supports_kw(callable,kw_name,requires_kw_only,
+                     allow_var_kwargs,is_supported):
+    assert supports_kw(
+        callable=callable,
+        kw_name=kw_name,
+        requires_kw_only=requires_kw_only,
+        allow_var_kwargs=allow_var_kwargs
+    ) == is_supported
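
The behavior these last patches aim for can be summarized as: an override is kept only if it names a keyword the mapper/processor can actually accept (keyword-only when requires_kw_only is used, or anything a trailing **kwargs could absorb when allow_var_kwargs is set), and request-time kwargs win over init-time kwargs when both are given. The standalone Python sketch below illustrates those rules under stated assumptions; it is not the code in vllm/utils.py, and the names keyword_is_supported, merge_processor_kwargs, demo_processor, and the num_crops kwarg are made up for the example.

import inspect
from typing import Any, Callable, Dict, Optional


def keyword_is_supported(func: Callable[..., object], name: str,
                         allow_var_kwargs: bool = False) -> bool:
    # Keyword-only params are always acceptable overrides; a trailing
    # **kwargs catch-all only counts when allow_var_kwargs is True.
    params = inspect.signature(func).parameters
    param = params.get(name)
    if param is not None:
        return param.kind == inspect.Parameter.KEYWORD_ONLY
    if allow_var_kwargs and params:
        last = list(params.values())[-1]
        return (last.kind == inspect.Parameter.VAR_KEYWORD
                and last.name != name)
    return False


def merge_processor_kwargs(init_kwargs: Optional[Dict[str, Any]],
                           request_kwargs: Optional[Dict[str, Any]],
                           func: Callable[..., object],
                           allow_var_kwargs: bool = False) -> Dict[str, Any]:
    # Filter each source independently, then let request-time values win.
    def _filtered(overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        return {
            k: v
            for k, v in (overrides or {}).items()
            if keyword_is_supported(func, k, allow_var_kwargs)
        }

    return {**_filtered(init_kwargs), **_filtered(request_kwargs)}


def demo_processor(ctx, inputs, *, num_crops: int = 4):
    # Stand-in for a registered multimodal input processor.
    return num_crops


if __name__ == "__main__":
    merged = merge_processor_kwargs({"num_crops": 8, "bogus": 1},
                                    {"num_crops": 16}, demo_processor)
    # "bogus" is dropped (not a keyword of demo_processor); the request-time
    # num_crops=16 overrides the init-time num_crops=8.
    assert merged == {"num_crops": 16}
    print(merged)

Keeping the request-time dict on the right-hand side of the merge is what gives per-request overrides priority without mutating the init-time configuration.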