From 550378be96448232a1e47f4a3af774fc3eca3cb7 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 16 Sep 2024 17:29:14 -0400 Subject: [PATCH 01/35] Allow for processor kwarg overrides Signed-off-by: Alex-Brooks --- tests/engine/test_arg_utils.py | 21 ++++++++++ vllm/config.py | 6 ++- vllm/engine/arg_utils.py | 8 ++++ vllm/engine/llm_engine.py | 3 +- vllm/entrypoints/llm.py | 2 + vllm/inputs/registry.py | 71 ++++++++++++++++++++++++++++++---- 6 files changed, 102 insertions(+), 9 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 8dd200b35d0f..fabf37aa2a68 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -40,3 +40,24 @@ def test_limit_mm_per_prompt_parser(arg, expected): def test_bad_nullable_kvs(arg): with pytest.raises(ArgumentTypeError): nullable_kvs(arg) + + +@pytest.mark.parametrize(("arg", "expected"), [ + (None, None), + ("{}", {}), + ('{"num_crops": 4}', { + "num_crops": 4 + }), + ('{"foo": {"bar": "baz"}}', { + "foo": { + "bar": "baz" + } + }), +]) +def test_processor_kwargs_prompt_parser(arg, expected): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + if arg is None: + args = parser.parse_args([]) + else: + args = parser.parse_args(["--processor-kwargs", arg]) + assert args.processor_kwargs == expected diff --git a/vllm/config.py b/vllm/config.py index 7a15606836dc..94552a22cc25 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -122,6 +122,8 @@ class ModelConfig: can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. + processor_kwargs: Arguments to be forwarded to the model's processor, + e.g., tokenizer, image processor, or custom processor callable. """ def __init__(self, @@ -150,7 +152,8 @@ def __init__(self, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, use_async_output_proc: bool = True, override_neuron_config: Optional[Dict[str, Any]] = None, - config_format: ConfigFormat = ConfigFormat.AUTO) -> None: + config_format: ConfigFormat = ConfigFormat.AUTO, + processor_kwargs: Optional[Dict[str, Any]] = None) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -184,6 +187,7 @@ def __init__(self, self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc + self.processor_kwargs = processor_kwargs # Set enforce_eager to False if the value is unset. if self.enforce_eager is None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4139eca9c183..ca1f334de535 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -175,6 +175,7 @@ class EngineArgs: collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False override_neuron_config: Optional[Dict[str, Any]] = None + processor_kwargs: Optional[Dict[str, Any]] = None def __post_init__(self): if self.tokenizer is None: @@ -513,6 +514,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'e.g.: `image=16,video=2` allows a maximum of 16 ' 'images and 2 videos per prompt. Defaults to 1 for ' 'each modality.')) + parser.add_argument( + '--processor-kwargs', + default=None, + type=json.loads, + help=('Overrides for the model processor, e.g., tokenizer or ' + 'image processor. 
For example: {"num_crops": 4}.')) # LoRA related configs parser.add_argument('--enable-lora', @@ -822,6 +829,7 @@ def create_model_config(self) -> ModelConfig: use_async_output_proc=not self.disable_async_output_proc, override_neuron_config=self.override_neuron_config, config_format=self.config_format, + processor_kwargs=self.processor_kwargs, ) def create_load_config(self) -> LoadConfig: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2743d5c7d228..a482cbbe2009 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -235,7 +235,7 @@ def __init__( "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s)", + "use_async_output_proc=%s, processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -268,6 +268,7 @@ def __init__( scheduler_config.num_scheduler_steps, cache_config.enable_prefix_caching, model_config.use_async_output_proc, + model_config.processor_kwargs, ) # TODO(woosuk): Print more configs in debug mode. from vllm.plugins import load_general_plugins diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 248b070611cd..6304851233ce 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -134,6 +134,7 @@ def __init__( max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, + processor_kwargs=None, **kwargs, ) -> None: ''' @@ -174,6 +175,7 @@ def __init__( max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, disable_async_output_proc=disable_async_output_proc, + processor_kwargs=processor_kwargs, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ae6c6c05d9f7..eb816baa6e8c 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,4 +1,5 @@ import functools +import inspect from array import array from collections import UserDict from dataclasses import dataclass @@ -245,6 +246,34 @@ def process_input(self, model_config: "ModelConfig", See also: :ref:`input_processing_pipeline` """ + processor = self._get_model_input_processor(model_config) + return processor(InputContext(model_config), inputs) + + def create_input_processor(self, model_config: "ModelConfig"): + """ + Create an input processor (see :meth:`process_input`) for a + specific model. + """ + # Determine which kwargs can be leveraged for the input processor + # and drop + warn for kwargs that are unimplemented. + processor_kwargs = self._get_allowed_kwarg_overrides( + callable=self._get_model_input_processor(model_config), + overrides=model_config.processor_kwargs, + ) + return functools.partial(self.process_input, model_config, + **processor_kwargs) + + def _get_model_input_processor(self, + model_config: "ModelConfig") -> Callable: + """Grabs the input processor for the provided model. + + Args: + model_config: Config whose model architecture we can leverage to + grab the callable input processor. + + Returns: + Callable input processor for this model. 
+ """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture @@ -252,12 +281,40 @@ def process_input(self, model_config: "ModelConfig", processor = self._input_processors_by_model_type \ .get(model_cls, self._default_input_processor) + return processor - return processor(InputContext(model_config), inputs) - - def create_input_processor(self, model_config: "ModelConfig"): - """ - Create an input processor (see :meth:`process_input`) for a - specific model. + def _get_allowed_kwarg_overrides( + self, + callable: Callable, + overrides: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + """Given a callable processor, determine which kwarg overrides provided + via the model config are valid keyword arguments, and drop any that + are not. + + Args: + processor: Callable processor which takes 0 or more kwargs. + model_config: Config which may contain init time processor kwargs. + + Returns: + Dictionary containing the processor kwargs to be wrapped when + creating the callable processor partial. """ - return functools.partial(self.process_input, model_config) + if not isinstance(overrides, dict): + return {} + allowed_kwargs = list(inspect.signature(callable).parameters.keys()) + # Drop any processor_kwargs provided by the user that are + # not kwarg names accepted by the provided input processor. + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if kwarg_name in allowed_kwargs + } + + # If anything is dropped, log a warning + dropped_keys = set(overrides) - set(filtered_overrides) + if dropped_keys: + logger.warning( + "The following kwarg overrides are not implemented " + "by the input processor and will be dropped: %s", dropped_keys) + return filtered_overrides From 190606f4d619de75b81af9ff6bba5031b5837393 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 17 Sep 2024 15:29:40 -0400 Subject: [PATCH 02/35] Pass processor through to partial Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index eb816baa6e8c..55d3aaec271f 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -212,7 +212,7 @@ def _default_input_processor(self, ctx: InputContext, """The default input processor is a no-op.""" return inputs - def register_input_processor(self, processor: InputProcessor): + def register_input_processor(self, processor: InputProcessor) -> Callable: """ Register an input processor to a model class. @@ -236,36 +236,42 @@ def wrapper(model_cls: N) -> N: return wrapper - def process_input(self, model_config: "ModelConfig", - inputs: LLMInputs) -> LLMInputs: + def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", + processor: Callable, **processor_kwargs) -> LLMInputs: """ - Apply an input processor to an instance of model inputs. + Apply an input processor to an instance of model inputs. This will + usually not be invoked be directly, and instead will be wrapped in + a functools partial once the processor is created. The model is identified by ``model_config``. 
See also: :ref:`input_processing_pipeline` """ - processor = self._get_model_input_processor(model_config) - return processor(InputContext(model_config), inputs) + return processor(InputContext(model_config), inputs, + **processor_kwargs) - def create_input_processor(self, model_config: "ModelConfig"): + def create_input_processor(self, model_config: "ModelConfig") -> Callable: """ - Create an input processor (see :meth:`process_input`) for a + Create an input processor (see :meth:`_process_input`) for a specific model. """ # Determine which kwargs can be leveraged for the input processor # and drop + warn for kwargs that are unimplemented. + processor = self._get_model_input_processor(model_config) processor_kwargs = self._get_allowed_kwarg_overrides( - callable=self._get_model_input_processor(model_config), + callable=processor, overrides=model_config.processor_kwargs, ) - return functools.partial(self.process_input, model_config, + return functools.partial(self._process_input, + model_config=model_config, + processor=processor, **processor_kwargs) def _get_model_input_processor(self, model_config: "ModelConfig") -> Callable: - """Grabs the input processor for the provided model. + """ + Grabs the input processor for the provided model. Args: model_config: Config whose model architecture we can leverage to @@ -288,7 +294,8 @@ def _get_allowed_kwarg_overrides( callable: Callable, overrides: Optional[Dict[str, Any]], ) -> Dict[str, Any]: - """Given a callable processor, determine which kwarg overrides provided + """ + Given a callable processor, determine which kwarg overrides provided via the model config are valid keyword arguments, and drop any that are not. From b1ca0417dcd0548dd7a86e0a9f5f2db5a218f1c6 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 17 Sep 2024 15:30:46 -0400 Subject: [PATCH 03/35] Add default & processor kwarg override tests Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 91 ++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 tests/multimodal/test_processor.py diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py new file mode 100644 index 000000000000..f7ecb0b01ddb --- /dev/null +++ b/tests/multimodal/test_processor.py @@ -0,0 +1,91 @@ +import pytest +from vllm.inputs.registry import InputRegistry +from vllm.config import ModelConfig +from unittest.mock import patch +from vllm.inputs import InputContext, LLMInputs + +DUMMY_MODEL_ID = "facebook/opt-125m" +# For processor kwargs - we test overrides by defining a callable with a +# default for the `num_crops`, then override the value through the processor +# kwargs +DEFAULT_NUM_CROPS = 4 +NUM_CROPS_OVERRIDE = 16 + +@pytest.fixture +def processor_mock(): + """Patches the internal model input processor with an override callable.""" + def custom_processor(ctx: InputContext, llm_inputs: LLMInputs, num_crops=DEFAULT_NUM_CROPS): + # For testing purposes, we don't worry about the llm inputs / return + # type validation, and just return the value of the kwarg that we + # clobber. + return num_crops + with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", return_value=custom_processor): + yield + + +def get_model_config(processor_kwargs=None): + """Creates a handle to a model config, which may have processor kwargs.""" + # NOTE - values / architecture don't matter too much here since we patch + # the return values for stuff like the input processor anyway. 
+ return ModelConfig( + DUMMY_MODEL_ID, + DUMMY_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs + ) + + +def test_default_processor_is_a_noop(): + """Ensure that by default, there is no processor override.""" + dummy_registry = InputRegistry() + model_config = get_model_config() + processor = dummy_registry.create_input_processor(model_config) + proc_inputs = LLMInputs(prompt="foobar") + proc_outputs = processor(inputs=proc_inputs) + # We should get the same object back since this is a no-op by default + assert proc_inputs is proc_outputs + + + +def test_processor_default_kwargs(processor_mock): + """Ensure we can call a processor that has extra kwargs & no overrides.""" + dummy_registry = InputRegistry() + model_config = get_model_config() + processor = dummy_registry.create_input_processor(model_config) + # The patched fixture patches the processor to return the value of + # num_crops in the processor call, which should be 4 by default. + num_crops_val = processor(LLMInputs(prompt="foobar")) + assert num_crops_val == DEFAULT_NUM_CROPS + + +def test_processor_default_kwargs_with_override(processor_mock): + """Ensure we can call a processor that has extra kwargs & no overrides.""" + dummy_registry = InputRegistry() + # Create processor_kwargs to override the value used + # for num_crops in the patched processor callable + model_config = get_model_config( + processor_kwargs={"num_crops": NUM_CROPS_OVERRIDE} + ) + processor = dummy_registry.create_input_processor(model_config) + num_crops_val = processor(LLMInputs(prompt="foobar")) + # Since the patched processor is an echo, we should get the + # override value we passed to processor_kwargs instead. + assert num_crops_val == NUM_CROPS_OVERRIDE + + +def test_processor_with_sad_kwarg_overrides(processor_mock): + """Ensure that processor kwargs that are unused do not fail.""" + dummy_registry = InputRegistry() + # Since the processor does not take `does_not_exist` as an arg, + # it will be filtered, then warn + drop it from the callable + # to prevent the processor from failing. 
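+    # The call below should therefore still succeed and simply fall back
+    # to DEFAULT_NUM_CROPS, which the final assertion verifies.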
+ model_config = get_model_config( + processor_kwargs={"does_not_exist": 100}, + ) + + processor = dummy_registry.create_input_processor(model_config) + num_crops_val = processor(LLMInputs(prompt="foobar")) + assert num_crops_val == DEFAULT_NUM_CROPS From 195e31ccdcf93d871a0acce3d354c7ff82dcc98d Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 17 Sep 2024 15:49:54 -0400 Subject: [PATCH 04/35] Don't allow ctx or inputs as kwargs Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 52 +++++++++++++++++++----------- vllm/inputs/registry.py | 16 ++++++++- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index f7ecb0b01ddb..6a9f88be50b4 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -1,8 +1,10 @@ +from unittest.mock import patch + import pytest -from vllm.inputs.registry import InputRegistry + from vllm.config import ModelConfig -from unittest.mock import patch from vllm.inputs import InputContext, LLMInputs +from vllm.inputs.registry import InputRegistry DUMMY_MODEL_ID = "facebook/opt-125m" # For processor kwargs - we test overrides by defining a callable with a @@ -11,15 +13,21 @@ DEFAULT_NUM_CROPS = 4 NUM_CROPS_OVERRIDE = 16 + @pytest.fixture def processor_mock(): """Patches the internal model input processor with an override callable.""" - def custom_processor(ctx: InputContext, llm_inputs: LLMInputs, num_crops=DEFAULT_NUM_CROPS): + + def custom_processor(ctx: InputContext, + llm_inputs: LLMInputs, + num_crops=DEFAULT_NUM_CROPS): # For testing purposes, we don't worry about the llm inputs / return # type validation, and just return the value of the kwarg that we # clobber. return num_crops - with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", return_value=custom_processor): + + with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", + return_value=custom_processor): yield @@ -27,15 +35,13 @@ def get_model_config(processor_kwargs=None): """Creates a handle to a model config, which may have processor kwargs.""" # NOTE - values / architecture don't matter too much here since we patch # the return values for stuff like the input processor anyway. 
- return ModelConfig( - DUMMY_MODEL_ID, - DUMMY_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=False, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs - ) + return ModelConfig(DUMMY_MODEL_ID, + DUMMY_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs) def test_default_processor_is_a_noop(): @@ -49,7 +55,6 @@ def test_default_processor_is_a_noop(): assert proc_inputs is proc_outputs - def test_processor_default_kwargs(processor_mock): """Ensure we can call a processor that has extra kwargs & no overrides.""" dummy_registry = InputRegistry() @@ -67,8 +72,7 @@ def test_processor_default_kwargs_with_override(processor_mock): # Create processor_kwargs to override the value used # for num_crops in the patched processor callable model_config = get_model_config( - processor_kwargs={"num_crops": NUM_CROPS_OVERRIDE} - ) + processor_kwargs={"num_crops": NUM_CROPS_OVERRIDE}) processor = dummy_registry.create_input_processor(model_config) num_crops_val = processor(LLMInputs(prompt="foobar")) # Since the patched processor is an echo, we should get the @@ -82,10 +86,20 @@ def test_processor_with_sad_kwarg_overrides(processor_mock): # Since the processor does not take `does_not_exist` as an arg, # it will be filtered, then warn + drop it from the callable # to prevent the processor from failing. - model_config = get_model_config( - processor_kwargs={"does_not_exist": 100}, - ) + model_config = get_model_config(processor_kwargs={"does_not_exist": 100}, ) processor = dummy_registry.create_input_processor(model_config) num_crops_val = processor(LLMInputs(prompt="foobar")) assert num_crops_val == DEFAULT_NUM_CROPS + + +def test_processor_kwargs_cannot_clobber_reserved_kwargs(processor_mock): + """Ensure that special kwargs cannot be overridden.""" + dummy_registry = InputRegistry() + model_config = get_model_config(processor_kwargs={"ctx": + "something bad"}, ) + processor = dummy_registry.create_input_processor(model_config) + # It's good enough to make sure this is callable, because if we had + # an override pushed through, we'd run into issues with multiple + # values provided for a single argument + processor(LLMInputs(prompt="foobar")) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 55d3aaec271f..305a0daca04a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -258,11 +258,13 @@ def create_input_processor(self, model_config: "ModelConfig") -> Callable: """ # Determine which kwargs can be leveraged for the input processor # and drop + warn for kwargs that are unimplemented. + # NOTE: we don't allow override values for ctx/inputs, since doing + # so can lead to value collisions etc. processor = self._get_model_input_processor(model_config) processor_kwargs = self._get_allowed_kwarg_overrides( callable=processor, overrides=model_config.processor_kwargs, - ) + immutable_kwargs=("ctx", "inputs")) return functools.partial(self._process_input, model_config=model_config, processor=processor, @@ -293,6 +295,7 @@ def _get_allowed_kwarg_overrides( self, callable: Callable, overrides: Optional[Dict[str, Any]], + immutable_kwargs: Optional[Tuple[str, ...]], ) -> Dict[str, Any]: """ Given a callable processor, determine which kwarg overrides provided @@ -302,6 +305,7 @@ def _get_allowed_kwarg_overrides( Args: processor: Callable processor which takes 0 or more kwargs. model_config: Config which may contain init time processor kwargs. 
+ immutable_kwargs: Reserved kwarg keys that can't be overridden. Returns: Dictionary containing the processor kwargs to be wrapped when @@ -309,6 +313,15 @@ def _get_allowed_kwarg_overrides( """ if not isinstance(overrides, dict): return {} + + if immutable_kwargs: + for name in immutable_kwargs: + if name in overrides: + logger.warning( + "%s is a reserved kwarg and will be dropped " + "from the input processor overrides", name) + del overrides[name] + allowed_kwargs = list(inspect.signature(callable).parameters.keys()) # Drop any processor_kwargs provided by the user that are # not kwarg names accepted by the provided input processor. @@ -324,4 +337,5 @@ def _get_allowed_kwarg_overrides( logger.warning( "The following kwarg overrides are not implemented " "by the input processor and will be dropped: %s", dropped_keys) + return filtered_overrides From 1472d0438edc6ddebfc7fc8991c8504598d49718 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 17 Sep 2024 17:08:42 -0400 Subject: [PATCH 05/35] Add kwarg override for processor to dummy data factories Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 305a0daca04a..37ded8edd694 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -74,12 +74,16 @@ def __call__( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], + **processor_kwargs: Any, ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data to be inputted into the model. Note: :data:`InputProcessor` is not applied to the dummy data. + The processor_kwargs are overrides provided at initialization + time to values in the config whose values may affect the number + of tokens per instance. """ ... @@ -185,10 +189,17 @@ def dummy_data_for_profiling( .get(model_cls, self._default_dummy_data_factory) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + # Check to see if this model expects additional processor kwargs; + # even though the processor isn't used on the dummy data, values + # passed to it that override the config may have implications on + # the number dummy data, e.g., the number of image tokens per instance. + df_kwargs = self._get_dummy_factory_processor_kwargs( + model_config, dummy_factory) seq_data, mm_data = dummy_factory( InputContext(model_config), seq_len, _MultiModalCounts(mm_counts), + **df_kwargs, ) # Having more tokens is over-conservative but otherwise fine @@ -207,6 +218,21 @@ def dummy_data_for_profiling( return seq_data, mm_data + def _get_dummy_factory_processor_kwargs( + self, model_config: "ModelConfig", + dummy_factory: Callable) -> Dict[str, Any]: + # Dummy factory takes no additional kwargs; presumably this means that + # image processor kwargs have either not been implemented, or they have + # no affect on the token counts. 
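+        # A factory that only accepts the standard (ctx, seq_len, mm_counts)
+        # parameters has nothing to override, so skip the filtering entirely.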
+ if len(inspect.signature(dummy_factory).parameters) < 4: + return {} + # Otherwise we may have overrides; filter them in the + # same way we filter the input processor overrides + return self._get_allowed_kwarg_overrides( + callable=dummy_factory, + overrides=model_config.processor_kwargs, + immutable_kwargs=("ctx", "seq_len", "mm_counts")) + def _default_input_processor(self, ctx: InputContext, inputs: LLMInputs) -> LLMInputs: """The default input processor is a no-op.""" From f10601fae34d71dbae12cecbec5bd88361aae392 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 01:13:12 -0400 Subject: [PATCH 06/35] Add kwarg override forr processor to max token calc Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 102 ++++++++++++++++++++-------------------- vllm/multimodal/base.py | 8 +++- 2 files changed, 58 insertions(+), 52 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 37ded8edd694..7393a883778d 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -28,6 +28,55 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" +def get_allowed_kwarg_overrides( + callable: Callable, + overrides: Optional[Dict[str, Any]], + immutable_kwargs: Optional[Tuple[str, ...]], +) -> Dict[str, Any]: + """ + Given a callable processor, determine which kwarg overrides provided + via the model config are valid keyword arguments, and drop any that + are not. + + Args: + processor: Callable processor which takes 0 or more kwargs. + model_config: Config which may contain init time processor kwargs. + immutable_kwargs: Reserved kwarg keys that can't be overridden. + + Returns: + Dictionary containing the processor kwargs to be wrapped when + creating the callable processor partial. + """ + if not isinstance(overrides, dict): + return {} + + if immutable_kwargs: + for name in immutable_kwargs: + if name in overrides: + logger.warning( + "%s is a reserved kwarg and will be dropped " + "from the input processor overrides", name) + del overrides[name] + + allowed_kwargs = list(inspect.signature(callable).parameters.keys()) + # Drop any processor_kwargs provided by the user that are + # not kwarg names accepted by the provided input processor. + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if kwarg_name in allowed_kwargs + } + + # If anything is dropped, log a warning + dropped_keys = set(overrides) - set(filtered_overrides) + if dropped_keys: + logger.warning( + "The following kwarg overrides are not implemented " + "by the input processor and will be dropped: %s", dropped_keys) + + return filtered_overrides + + @dataclass(frozen=True) class InputContext: """ @@ -228,7 +277,7 @@ def _get_dummy_factory_processor_kwargs( return {} # Otherwise we may have overrides; filter them in the # same way we filter the input processor overrides - return self._get_allowed_kwarg_overrides( + return get_allowed_kwarg_overrides( callable=dummy_factory, overrides=model_config.processor_kwargs, immutable_kwargs=("ctx", "seq_len", "mm_counts")) @@ -287,7 +336,7 @@ def create_input_processor(self, model_config: "ModelConfig") -> Callable: # NOTE: we don't allow override values for ctx/inputs, since doing # so can lead to value collisions etc. 
processor = self._get_model_input_processor(model_config) - processor_kwargs = self._get_allowed_kwarg_overrides( + processor_kwargs = get_allowed_kwarg_overrides( callable=processor, overrides=model_config.processor_kwargs, immutable_kwargs=("ctx", "inputs")) @@ -316,52 +365,3 @@ def _get_model_input_processor(self, processor = self._input_processors_by_model_type \ .get(model_cls, self._default_input_processor) return processor - - def _get_allowed_kwarg_overrides( - self, - callable: Callable, - overrides: Optional[Dict[str, Any]], - immutable_kwargs: Optional[Tuple[str, ...]], - ) -> Dict[str, Any]: - """ - Given a callable processor, determine which kwarg overrides provided - via the model config are valid keyword arguments, and drop any that - are not. - - Args: - processor: Callable processor which takes 0 or more kwargs. - model_config: Config which may contain init time processor kwargs. - immutable_kwargs: Reserved kwarg keys that can't be overridden. - - Returns: - Dictionary containing the processor kwargs to be wrapped when - creating the callable processor partial. - """ - if not isinstance(overrides, dict): - return {} - - if immutable_kwargs: - for name in immutable_kwargs: - if name in overrides: - logger.warning( - "%s is a reserved kwarg and will be dropped " - "from the input processor overrides", name) - del overrides[name] - - allowed_kwargs = list(inspect.signature(callable).parameters.keys()) - # Drop any processor_kwargs provided by the user that are - # not kwarg names accepted by the provided input processor. - filtered_overrides = { - kwarg_name: val - for kwarg_name, val in overrides.items() - if kwarg_name in allowed_kwargs - } - - # If anything is dropped, log a warning - dropped_keys = set(overrides) - set(filtered_overrides) - if dropped_keys: - logger.warning( - "The following kwarg overrides are not implemented " - "by the input processor and will be dropped: %s", dropped_keys) - - return filtered_overrides diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 032964fe0ac4..0623de1d523d 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -13,6 +13,7 @@ from vllm.config import ModelConfig from vllm.inputs import InputContext +from vllm.inputs.registry import get_allowed_kwarg_overrides from vllm.logger import init_logger from vllm.utils import JSONTree, is_list_of, json_map_leaves @@ -333,7 +334,12 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: f"for model class {model_cls.__name__} in {self}.") if callable(max_mm_tokens): - max_mm_tokens = max_mm_tokens(InputContext(model_config)) + processor_kwargs = get_allowed_kwarg_overrides( + callable=max_mm_tokens, + overrides=model_config.processor_kwargs, + immutable_kwargs=("ctx",)) + max_mm_tokens = max_mm_tokens(InputContext(model_config), + **processor_kwargs) self._validate_max_multimodal_tokens(max_mm_tokens) From 429097aa71c2cd5984a4e26df895e88863c0a36b Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 01:18:04 -0400 Subject: [PATCH 07/35] Move kwarg only override func to utils Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 54 +++-------------------------------------- vllm/multimodal/base.py | 8 +++--- vllm/utils.py | 50 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 55 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 7393a883778d..36682adaed40 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -11,6 +11,7 @@ from typing_extensions import 
TypeVar from vllm.logger import init_logger +from vllm.utils import get_allowed_kwarg_only_overrides from .data import LLMInputs @@ -28,55 +29,6 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" -def get_allowed_kwarg_overrides( - callable: Callable, - overrides: Optional[Dict[str, Any]], - immutable_kwargs: Optional[Tuple[str, ...]], -) -> Dict[str, Any]: - """ - Given a callable processor, determine which kwarg overrides provided - via the model config are valid keyword arguments, and drop any that - are not. - - Args: - processor: Callable processor which takes 0 or more kwargs. - model_config: Config which may contain init time processor kwargs. - immutable_kwargs: Reserved kwarg keys that can't be overridden. - - Returns: - Dictionary containing the processor kwargs to be wrapped when - creating the callable processor partial. - """ - if not isinstance(overrides, dict): - return {} - - if immutable_kwargs: - for name in immutable_kwargs: - if name in overrides: - logger.warning( - "%s is a reserved kwarg and will be dropped " - "from the input processor overrides", name) - del overrides[name] - - allowed_kwargs = list(inspect.signature(callable).parameters.keys()) - # Drop any processor_kwargs provided by the user that are - # not kwarg names accepted by the provided input processor. - filtered_overrides = { - kwarg_name: val - for kwarg_name, val in overrides.items() - if kwarg_name in allowed_kwargs - } - - # If anything is dropped, log a warning - dropped_keys = set(overrides) - set(filtered_overrides) - if dropped_keys: - logger.warning( - "The following kwarg overrides are not implemented " - "by the input processor and will be dropped: %s", dropped_keys) - - return filtered_overrides - - @dataclass(frozen=True) class InputContext: """ @@ -277,7 +229,7 @@ def _get_dummy_factory_processor_kwargs( return {} # Otherwise we may have overrides; filter them in the # same way we filter the input processor overrides - return get_allowed_kwarg_overrides( + return get_allowed_kwarg_only_overrides( callable=dummy_factory, overrides=model_config.processor_kwargs, immutable_kwargs=("ctx", "seq_len", "mm_counts")) @@ -336,7 +288,7 @@ def create_input_processor(self, model_config: "ModelConfig") -> Callable: # NOTE: we don't allow override values for ctx/inputs, since doing # so can lead to value collisions etc. 
processor = self._get_model_input_processor(model_config) - processor_kwargs = get_allowed_kwarg_overrides( + processor_kwargs = get_allowed_kwarg_only_overrides( callable=processor, overrides=model_config.processor_kwargs, immutable_kwargs=("ctx", "inputs")) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 0623de1d523d..7af50d4a55aa 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -13,9 +13,9 @@ from vllm.config import ModelConfig from vllm.inputs import InputContext -from vllm.inputs.registry import get_allowed_kwarg_overrides from vllm.logger import init_logger -from vllm.utils import JSONTree, is_list_of, json_map_leaves +from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, + json_map_leaves) logger = init_logger(__name__) @@ -334,10 +334,10 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: f"for model class {model_cls.__name__} in {self}.") if callable(max_mm_tokens): - processor_kwargs = get_allowed_kwarg_overrides( + processor_kwargs = get_allowed_kwarg_only_overrides( callable=max_mm_tokens, overrides=model_config.processor_kwargs, - immutable_kwargs=("ctx",)) + immutable_kwargs=("ctx", )) max_mm_tokens = max_mm_tokens(InputContext(model_config), **processor_kwargs) diff --git a/vllm/utils.py b/vllm/utils.py index 060b387ec783..45db573809bb 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -4,6 +4,7 @@ import datetime import enum import gc +import inspect import os import random import socket @@ -1237,6 +1238,55 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, return await task(*args, **kwargs) +def get_allowed_kwarg_only_overrides( + callable: Callable, + overrides: Optional[Dict[str, Any]], + immutable_kwargs: Optional[Tuple[str, ...]], +) -> Dict[str, Any]: + """ + Given a callable processor, determine which kwarg overrides provided + via the model config are valid keyword arguments, and drop any that + are not. + + Args: + processor: Callable processor which takes 0 or more kwargs. + model_config: Config which may contain init time processor kwargs. + immutable_kwargs: Reserved kwarg keys that can't be overridden. + + Returns: + Dictionary containing the processor kwargs to be wrapped when + creating the callable processor partial. + """ + if not isinstance(overrides, dict): + return {} + + if immutable_kwargs: + for name in immutable_kwargs: + if name in overrides: + logger.warning( + "%s is a reserved kwarg and will be dropped " + "from the input processor overrides", name) + del overrides[name] + + allowed_kwargs = list(inspect.signature(callable).parameters.keys()) + # Drop any processor_kwargs provided by the user that are + # not kwarg names accepted by the provided input processor. + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if kwarg_name in allowed_kwargs + } + + # If anything is dropped, log a warning + dropped_keys = set(overrides) - set(filtered_overrides) + if dropped_keys: + logger.warning( + "The following kwarg overrides are not implemented " + "by the input processor and will be dropped: %s", dropped_keys) + + return filtered_overrides + + # Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. # In particular, the FakeScalarType is not supported for earlier versions of # PyTorch which breaks dynamo for any ops registered using ScalarType. 
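
A minimal sketch of how the overrides added so far are meant to flow end to
end, assuming the `num_crops` option that the tests exercise through mocks:

    from vllm import LLM

    # Unrecognized keys are filtered out with a warning instead of raising.
    llm = LLM(model="microsoft/Phi-3.5-vision-instruct",
              trust_remote_code=True,
              processor_kwargs={"num_crops": 16})

The same override can be supplied on the command line via
--processor-kwargs '{"num_crops": 16}'.
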
From 159cfc26d5a2153e4ef1b450611b99475d83a19c Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 01:41:21 -0400 Subject: [PATCH 08/35] Force processor kwargs to be keyword-only Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 2 ++ vllm/inputs/registry.py | 8 ++----- vllm/multimodal/base.py | 3 +-- vllm/utils.py | 38 ++++++++++++++---------------- 4 files changed, 23 insertions(+), 28 deletions(-) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index 6a9f88be50b4..c38bf9078356 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -18,8 +18,10 @@ def processor_mock(): """Patches the internal model input processor with an override callable.""" + # NOTE: processor kwargs must be keyword-only. def custom_processor(ctx: InputContext, llm_inputs: LLMInputs, + *, num_crops=DEFAULT_NUM_CROPS): # For testing purposes, we don't worry about the llm inputs / return # type validation, and just return the value of the kwarg that we diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 36682adaed40..7d669c311591 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -230,9 +230,7 @@ def _get_dummy_factory_processor_kwargs( # Otherwise we may have overrides; filter them in the # same way we filter the input processor overrides return get_allowed_kwarg_only_overrides( - callable=dummy_factory, - overrides=model_config.processor_kwargs, - immutable_kwargs=("ctx", "seq_len", "mm_counts")) + callable=dummy_factory, overrides=model_config.processor_kwargs) def _default_input_processor(self, ctx: InputContext, inputs: LLMInputs) -> LLMInputs: @@ -289,9 +287,7 @@ def create_input_processor(self, model_config: "ModelConfig") -> Callable: # so can lead to value collisions etc. processor = self._get_model_input_processor(model_config) processor_kwargs = get_allowed_kwarg_only_overrides( - callable=processor, - overrides=model_config.processor_kwargs, - immutable_kwargs=("ctx", "inputs")) + callable=processor, overrides=model_config.processor_kwargs) return functools.partial(self._process_input, model_config=model_config, processor=processor, diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 7af50d4a55aa..b0118c71c26a 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -336,8 +336,7 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: if callable(max_mm_tokens): processor_kwargs = get_allowed_kwarg_only_overrides( callable=max_mm_tokens, - overrides=model_config.processor_kwargs, - immutable_kwargs=("ctx", )) + overrides=model_config.processor_kwargs) max_mm_tokens = max_mm_tokens(InputContext(model_config), **processor_kwargs) diff --git a/vllm/utils.py b/vllm/utils.py index 45db573809bb..22c6804246a5 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1241,48 +1241,46 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, def get_allowed_kwarg_only_overrides( callable: Callable, overrides: Optional[Dict[str, Any]], - immutable_kwargs: Optional[Tuple[str, ...]], ) -> Dict[str, Any]: """ - Given a callable processor, determine which kwarg overrides provided - via the model config are valid keyword arguments, and drop any that - are not. + Given a callable which has one or more keyword only params and a dict + mapping param names to values, drop values that can be not be kwarg + expanded to overwrite one or more keyword-only args. 
This is used in a + few places to handle custom processor overrides for multimodal models, + e.g., for profiling when processor options provided by the user + may affect the number of mm tokens per instance. Args: - processor: Callable processor which takes 0 or more kwargs. - model_config: Config which may contain init time processor kwargs. - immutable_kwargs: Reserved kwarg keys that can't be overridden. + callable: Callable which takes 0 or more keyword only arguments. + overrides: Potential overrides to be used when invoking the callable. Returns: - Dictionary containing the processor kwargs to be wrapped when - creating the callable processor partial. + Dictionary containing the kwargs to be leveraged which may be used + to overwrite one or more keyword only arguments when invoking the + callable. """ if not isinstance(overrides, dict): return {} - if immutable_kwargs: - for name in immutable_kwargs: - if name in overrides: - logger.warning( - "%s is a reserved kwarg and will be dropped " - "from the input processor overrides", name) - del overrides[name] + allowed_override_names = [ + name for name, param in inspect.signature(callable).parameters.items() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ] - allowed_kwargs = list(inspect.signature(callable).parameters.keys()) # Drop any processor_kwargs provided by the user that are # not kwarg names accepted by the provided input processor. filtered_overrides = { kwarg_name: val for kwarg_name, val in overrides.items() - if kwarg_name in allowed_kwargs + if kwarg_name in allowed_override_names } # If anything is dropped, log a warning dropped_keys = set(overrides) - set(filtered_overrides) if dropped_keys: logger.warning( - "The following kwarg overrides are not implemented " - "by the input processor and will be dropped: %s", dropped_keys) + "The following intended overrides are not keyword-only args " + "and and will be dropped: %s", dropped_keys) return filtered_overrides From af919301fc9eec8c6962394907d2abc9971d2967 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 03:05:04 -0400 Subject: [PATCH 09/35] Pass unfiltered processor kwargs to default mapper Signed-off-by: Alex-Brooks --- vllm/multimodal/image.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 6cdde949bc2b..137a574c7d1a 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,4 +1,5 @@ from functools import lru_cache +from typing import Any, Dict import torch from PIL import Image @@ -22,10 +23,12 @@ class ImagePlugin(MultiModalPlugin): def get_data_key(self) -> str: return "image" - def _get_hf_image_processor(self, model_config: ModelConfig): + def _get_hf_image_processor(self, model_config: ModelConfig, + processor_kwargs: Dict[str, Any]): return cached_get_image_processor( model_config.model, - trust_remote_code=model_config.trust_remote_code) + trust_remote_code=model_config.trust_remote_code, + **processor_kwargs) def _default_input_mapper( self, @@ -36,7 +39,13 @@ def _default_input_mapper( # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): - image_processor = self._get_hf_image_processor(model_config) + processor_kwargs = ({} if model_config.processor_kwargs is None + else model_config.processor_kwargs) + + image_processor = self._get_hf_image_processor( + model_config, + processor_kwargs, + ) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image 
object") From 9adad10623580ad048ad17ae236d06b60c668969 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 03:57:25 -0400 Subject: [PATCH 10/35] Add hack for mapper preprocessor kwargs Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 12 +++++------- vllm/multimodal/base.py | 5 ++++- vllm/multimodal/registry.py | 8 ++++++++ vllm/utils.py | 4 ++-- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 7d669c311591..d422ddedd38a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -194,14 +194,12 @@ def dummy_data_for_profiling( # even though the processor isn't used on the dummy data, values # passed to it that override the config may have implications on # the number dummy data, e.g., the number of image tokens per instance. - df_kwargs = self._get_dummy_factory_processor_kwargs( + processor_kwargs = self._get_dummy_factory_processor_kwargs( model_config, dummy_factory) - seq_data, mm_data = dummy_factory( - InputContext(model_config), - seq_len, - _MultiModalCounts(mm_counts), - **df_kwargs, - ) + + seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **processor_kwargs) # Having more tokens is over-conservative but otherwise fine num_tokens = seq_data.prompt_token_ids diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index b0118c71c26a..d0d11e6cbda7 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -257,11 +257,14 @@ def map_input(self, model_config: ModelConfig, model_cls, _ = get_model_architecture(model_config) mapper = self._input_mappers.get(model_cls) + processor_kwargs = get_allowed_kwarg_only_overrides( + callable=mapper, overrides=model_config.processor_kwargs) + if mapper is None: raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") - return mapper(InputContext(model_config), data) + return mapper(InputContext(model_config), data, **processor_kwargs) @abstractmethod def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 745fc715caf4..f1c56226e044 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -138,6 +138,14 @@ def create_input_mapper(self, model_config: ModelConfig): """ Create an input mapper (see :meth:`map_input`) for a specific model. """ + # TODO - there is a bit of weirdness here in the way mapper handles + # the args, because for the HF one, we pass processor_kwargs at init + # time and don't need them at func time, for the function's we are + # wrapping in processor like interfaces, we pass them at the time + # of invocation. + # + # Currently it works, but warns when the default processor is used, + # which is bad. return functools.partial(self.map_input, model_config) def register_max_multimodal_tokens( diff --git a/vllm/utils.py b/vllm/utils.py index 22c6804246a5..e1b8ccfd6aad 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1239,7 +1239,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, def get_allowed_kwarg_only_overrides( - callable: Callable, + callable: Optional[Callable], overrides: Optional[Dict[str, Any]], ) -> Dict[str, Any]: """ @@ -1259,7 +1259,7 @@ def get_allowed_kwarg_only_overrides( to overwrite one or more keyword only arguments when invoking the callable. 
""" - if not isinstance(overrides, dict): + if not overrides or not callable: return {} allowed_override_names = [ From 9f7aed8e7e6db89146668c199a8b3926834e6797 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 18:38:18 -0400 Subject: [PATCH 11/35] Simplify dummy data processor kwarg & add tests Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 167 +++++++++++++++++++---------- vllm/inputs/registry.py | 22 +--- 2 files changed, 114 insertions(+), 75 deletions(-) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index c38bf9078356..cc56f26795a0 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -1,3 +1,5 @@ +from array import array +from typing import Mapping from unittest.mock import patch import pytest @@ -6,19 +8,35 @@ from vllm.inputs import InputContext, LLMInputs from vllm.inputs.registry import InputRegistry +from vllm.multimodal import MultiModalRegistry +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData + + + DUMMY_MODEL_ID = "facebook/opt-125m" -# For processor kwargs - we test overrides by defining a callable with a -# default for the `num_crops`, then override the value through the processor -# kwargs +# For processor_kwargs - we test overrides by defining mocks for each place +# it is used, and ensuring that we can pass processor kwargs an override value +# to receive the intended result for things like sequence length etc. DEFAULT_NUM_CROPS = 4 NUM_CROPS_OVERRIDE = 16 +def get_model_config(processor_kwargs=None): + """Creates a handle to a model config, which may have processor kwargs.""" + # NOTE - values / architecture don't matter too much here since we patch + # the return values for stuff like the input processor anyway. + return ModelConfig(DUMMY_MODEL_ID, + DUMMY_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs) +# Mocks for all of the places that we use the processor_kwargs +# to override values in different callables @pytest.fixture -def processor_mock(): +def use_processor_mock(): """Patches the internal model input processor with an override callable.""" - - # NOTE: processor kwargs must be keyword-only. def custom_processor(ctx: InputContext, llm_inputs: LLMInputs, *, @@ -32,76 +50,111 @@ def custom_processor(ctx: InputContext, return_value=custom_processor): yield - -def get_model_config(processor_kwargs=None): - """Creates a handle to a model config, which may have processor kwargs.""" - # NOTE - values / architecture don't matter too much here since we patch - # the return values for stuff like the input processor anyway. 
- return ModelConfig(DUMMY_MODEL_ID, - DUMMY_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=False, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs) +@pytest.fixture +def use_dummy_data_mock(): + """Patches the internal model input processor with an override callable.""" + def custom_dummy_data_factory(self, + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + num_crops=DEFAULT_NUM_CROPS): + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) + return seq_data, None + + with patch("vllm.inputs.registry.InputRegistry._default_dummy_data_factory", + custom_dummy_data_factory): + yield +### Test for default processor logic & processor_kwargs wrapping def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" dummy_registry = InputRegistry() model_config = get_model_config() processor = dummy_registry.create_input_processor(model_config) - proc_inputs = LLMInputs(prompt="foobar") + proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") proc_outputs = processor(inputs=proc_inputs) - # We should get the same object back since this is a no-op by default assert proc_inputs is proc_outputs - -def test_processor_default_kwargs(processor_mock): - """Ensure we can call a processor that has extra kwargs & no overrides.""" +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_processor_default_kwargs(use_processor_mock, num_crops): + """Ensure that we can override processor kwargs.""" dummy_registry = InputRegistry() - model_config = get_model_config() + # If we have a value for num_crops, pass the override value and make + # sure we get that value as a return-value from out mock processor, + # otherwise fall back to the default value + processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops + model_config = get_model_config(processor_kwargs=processor_kwargs) processor = dummy_registry.create_input_processor(model_config) - # The patched fixture patches the processor to return the value of - # num_crops in the processor call, which should be 4 by default. - num_crops_val = processor(LLMInputs(prompt="foobar")) - assert num_crops_val == DEFAULT_NUM_CROPS - -def test_processor_default_kwargs_with_override(processor_mock): - """Ensure we can call a processor that has extra kwargs & no overrides.""" - dummy_registry = InputRegistry() - # Create processor_kwargs to override the value used - # for num_crops in the patched processor callable - model_config = get_model_config( - processor_kwargs={"num_crops": NUM_CROPS_OVERRIDE}) - processor = dummy_registry.create_input_processor(model_config) - num_crops_val = processor(LLMInputs(prompt="foobar")) - # Since the patched processor is an echo, we should get the - # override value we passed to processor_kwargs instead. 
- assert num_crops_val == NUM_CROPS_OVERRIDE + num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) + assert num_crops_val == expected_num_crops -def test_processor_with_sad_kwarg_overrides(processor_mock): - """Ensure that processor kwargs that are unused do not fail.""" +@pytest.mark.parametrize("processor_kwargs", + [ + {"does_not_exist": 100}, # Not part of the signature + {"ctx": "something bad"} # Part of the signature, not keyword only + ] +) +def test_processor_with_sad_kwarg_overrides(use_processor_mock, + processor_kwargs): + """Ensure invalid processor_kwargs can't be used in the input processor.""" dummy_registry = InputRegistry() - # Since the processor does not take `does_not_exist` as an arg, - # it will be filtered, then warn + drop it from the callable - # to prevent the processor from failing. - model_config = get_model_config(processor_kwargs={"does_not_exist": 100}, ) + + model_config = get_model_config(processor_kwargs=processor_kwargs) processor = dummy_registry.create_input_processor(model_config) - num_crops_val = processor(LLMInputs(prompt="foobar")) + num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) assert num_crops_val == DEFAULT_NUM_CROPS -def test_processor_kwargs_cannot_clobber_reserved_kwargs(processor_mock): - """Ensure that special kwargs cannot be overridden.""" +### Test overrides for the dummy data +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): + processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops dummy_registry = InputRegistry() - model_config = get_model_config(processor_kwargs={"ctx": - "something bad"}, ) - processor = dummy_registry.create_input_processor(model_config) - # It's good enough to make sure this is callable, because if we had - # an override pushed through, we'd run into issues with multiple - # values provided for a single argument - processor(LLMInputs(prompt="foobar")) + model_config = get_model_config( + processor_kwargs=processor_kwargs, + ) + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + + # NOTE: seq_len is thrown away here since this will leverage the + # default dummy data factory that we have patched in, whose seq + # len is solely dependent on the value of the processor_kwargs. + seq_data, _ = dummy_registry.dummy_data_for_profiling( + model_config, + seq_len=-1, + mm_registry=mm_registry + ) + assert len(seq_data.prompt_token_ids) == expected_seq_count + + +@pytest.mark.parametrize("processor_kwargs", + [ + {"does_not_exist": 100}, # Not part of the signature + {"ctx": "something bad"} # Part of the signature, not keyword only + ] +) +def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, processor_kwargs): + """Ensure that dummy_data kwargs that are unused do not fail.""" + dummy_registry = InputRegistry() + model_config = get_model_config( + processor_kwargs=processor_kwargs, + ) + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + + # NOTE: seq_len is thrown away here since this will leverage the + # default dummy data factory that we have patched in, whose seq + # len is solely dependent on the value of the processor_kwargs. 
+ seq_data, _ = dummy_registry.dummy_data_for_profiling( + model_config, + seq_len=-1, + mm_registry=mm_registry + ) + assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index d422ddedd38a..7b74c24d7344 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -190,12 +190,10 @@ def dummy_data_for_profiling( .get(model_cls, self._default_dummy_data_factory) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - # Check to see if this model expects additional processor kwargs; - # even though the processor isn't used on the dummy data, values - # passed to it that override the config may have implications on - # the number dummy data, e.g., the number of image tokens per instance. - processor_kwargs = self._get_dummy_factory_processor_kwargs( - model_config, dummy_factory) + processor_kwargs = get_allowed_kwarg_only_overrides( + callable=dummy_factory, + overrides=model_config.processor_kwargs + ) seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, _MultiModalCounts(mm_counts), @@ -217,18 +215,6 @@ def dummy_data_for_profiling( return seq_data, mm_data - def _get_dummy_factory_processor_kwargs( - self, model_config: "ModelConfig", - dummy_factory: Callable) -> Dict[str, Any]: - # Dummy factory takes no additional kwargs; presumably this means that - # image processor kwargs have either not been implemented, or they have - # no affect on the token counts. - if len(inspect.signature(dummy_factory).parameters) < 4: - return {} - # Otherwise we may have overrides; filter them in the - # same way we filter the input processor overrides - return get_allowed_kwarg_only_overrides( - callable=dummy_factory, overrides=model_config.processor_kwargs) def _default_input_processor(self, ctx: InputContext, inputs: LLMInputs) -> LLMInputs: From ff59e44c351ed34f4c2e14ea136752bae7d93856 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 19:58:05 -0400 Subject: [PATCH 12/35] Add tests for max multimodal token kwarg overrides Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 146 +++++++++++++++++++++++------ 1 file changed, 115 insertions(+), 31 deletions(-) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index cc56f26795a0..f94a4f5abce6 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -7,19 +7,22 @@ from vllm.config import ModelConfig from vllm.inputs import InputContext, LLMInputs from vllm.inputs.registry import InputRegistry - from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.model_executor.models.phi3v import Phi3VForCausalLM - - +# Used for fast tests where the model doesn't matter DUMMY_MODEL_ID = "facebook/opt-125m" +# Used for tests that need a multimodal model +MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" + # For processor_kwargs - we test overrides by defining mocks for each place # it is used, and ensuring that we can pass processor kwargs an override value # to receive the intended result for things like sequence length etc. 
DEFAULT_NUM_CROPS = 4 NUM_CROPS_OVERRIDE = 16 + def get_model_config(processor_kwargs=None): """Creates a handle to a model config, which may have processor kwargs.""" # NOTE - values / architecture don't matter too much here since we patch @@ -32,11 +35,13 @@ def get_model_config(processor_kwargs=None): seed=0, processor_kwargs=processor_kwargs) + # Mocks for all of the places that we use the processor_kwargs # to override values in different callables @pytest.fixture def use_processor_mock(): """Patches the internal model input processor with an override callable.""" + def custom_processor(ctx: InputContext, llm_inputs: LLMInputs, *, @@ -50,23 +55,31 @@ def custom_processor(ctx: InputContext, return_value=custom_processor): yield + @pytest.fixture def use_dummy_data_mock(): """Patches the internal model input processor with an override callable.""" + def custom_dummy_data_factory(self, ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], *, num_crops=DEFAULT_NUM_CROPS): - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) return seq_data, None - with patch("vllm.inputs.registry.InputRegistry._default_dummy_data_factory", - custom_dummy_data_factory): + with patch( + "vllm.inputs.registry.InputRegistry._default_dummy_data_factory", + custom_dummy_data_factory): yield +# lambda whose signature matches max token calcs + extra kwargs +get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops + + ### Test for default processor logic & processor_kwargs wrapping def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" @@ -77,6 +90,7 @@ def test_default_processor_is_a_noop(): proc_outputs = processor(inputs=proc_inputs) assert proc_inputs is proc_outputs + @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_processor_default_kwargs(use_processor_mock, num_crops): """Ensure that we can override processor kwargs.""" @@ -93,12 +107,16 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): assert num_crops_val == expected_num_crops -@pytest.mark.parametrize("processor_kwargs", +@pytest.mark.parametrize( + "processor_kwargs", [ - {"does_not_exist": 100}, # Not part of the signature - {"ctx": "something bad"} # Part of the signature, not keyword only - ] -) + { + "does_not_exist": 100 + }, # Not part of the signature + { + "ctx": "something bad" + } # Part of the signature, not keyword only + ]) def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor_kwargs): """Ensure invalid processor_kwargs can't be used in the input processor.""" @@ -117,9 +135,7 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops dummy_registry = InputRegistry() - model_config = get_model_config( - processor_kwargs=processor_kwargs, - ) + model_config = get_model_config(processor_kwargs=processor_kwargs, ) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -127,25 +143,25 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the processor_kwargs. 
seq_data, _ = dummy_registry.dummy_data_for_profiling( - model_config, - seq_len=-1, - mm_registry=mm_registry - ) + model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == expected_seq_count -@pytest.mark.parametrize("processor_kwargs", +@pytest.mark.parametrize( + "processor_kwargs", [ - {"does_not_exist": 100}, # Not part of the signature - {"ctx": "something bad"} # Part of the signature, not keyword only - ] -) -def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, processor_kwargs): + { + "does_not_exist": 100 + }, # Not part of the signature + { + "ctx": "something bad" + } # Part of the signature, not keyword only + ]) +def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, + processor_kwargs): """Ensure that dummy_data kwargs that are unused do not fail.""" dummy_registry = InputRegistry() - model_config = get_model_config( - processor_kwargs=processor_kwargs, - ) + model_config = get_model_config(processor_kwargs=processor_kwargs, ) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -153,8 +169,76 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, processor_kwar # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the processor_kwargs. seq_data, _ = dummy_registry.dummy_data_for_profiling( - model_config, - seq_len=-1, - mm_registry=mm_registry - ) + model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS + + +### Test overrides for the max token count per multimodal instance +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_max_tokens_kwarg_overrides(num_crops): + processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + + model_config = ModelConfig( + MULTIMODAL_MODEL_ID, + MULTIMODAL_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the processor_kwargs. + with patch.object( + mm_registry._get_plugin("image"), + "_max_mm_tokens", + {Phi3VForCausalLM: get_num_crops}, + ): + max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( + model_config) + + assert expected_seq_count == max_multimodal_tokens + + +@pytest.mark.parametrize( + "processor_kwargs", + [ + { + "does_not_exist": 100 + }, # Not part of the signature + { + "ctx": "something bad" + } # Part of the signature, not keyword only + ]) +def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): + model_config = ModelConfig( + MULTIMODAL_MODEL_ID, + MULTIMODAL_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + + # Similar before, but since these kwargs get filtered, + # we always get our default value back. 
+ with patch.object( + mm_registry._get_plugin("image"), + "_max_mm_tokens", + {Phi3VForCausalLM: get_num_crops}, + ): + max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( + model_config) + + assert max_multimodal_tokens == DEFAULT_NUM_CROPS From 6b264547e55c943faf8924955d5e9635d6189d1d Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 20:52:15 -0400 Subject: [PATCH 13/35] Format registry Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 7b74c24d7344..08f516e17aef 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,5 +1,4 @@ import functools -import inspect from array import array from collections import UserDict from dataclasses import dataclass @@ -191,9 +190,7 @@ def dummy_data_for_profiling( mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) processor_kwargs = get_allowed_kwarg_only_overrides( - callable=dummy_factory, - overrides=model_config.processor_kwargs - ) + callable=dummy_factory, overrides=model_config.processor_kwargs) seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, _MultiModalCounts(mm_counts), @@ -215,7 +212,6 @@ def dummy_data_for_profiling( return seq_data, mm_data - def _default_input_processor(self, ctx: InputContext, inputs: LLMInputs) -> LLMInputs: """The default input processor is a no-op.""" From 0e2d53d9baacb6dc0c77e1299bbdd77d4c1c1257 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:06:34 -0400 Subject: [PATCH 14/35] Fix default mapper comparison Signed-off-by: Alex-Brooks --- vllm/multimodal/base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index d0d11e6cbda7..06cbb528c34b 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -257,8 +257,14 @@ def map_input(self, model_config: ModelConfig, model_cls, _ = get_model_architecture(model_config) mapper = self._input_mappers.get(model_cls) - processor_kwargs = get_allowed_kwarg_only_overrides( - callable=mapper, overrides=model_config.processor_kwargs) + # Only get processor kwargs at mapping time if we are not using the + # input mapper; no overrides are used on the default here because they + # should be passed to the huggingface resource at initialization time. 
+ if mapper != self._default_input_mapper: + processor_kwargs = get_allowed_kwarg_only_overrides( + callable=mapper, overrides=model_config.processor_kwargs) + else: + processor_kwargs = {} if mapper is None: raise KeyError(f"No input mapper in {self} is registered for " From 5a3341bd75c3d1ef8b3d696736d34d927085b7d2 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:07:02 -0400 Subject: [PATCH 15/35] Move kwarg filtering into hf processor getter Signed-off-by: Alex-Brooks --- vllm/multimodal/image.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 137a574c7d1a..c2657a112173 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,5 +1,4 @@ from functools import lru_cache -from typing import Any, Dict import torch from PIL import Image @@ -23,8 +22,11 @@ class ImagePlugin(MultiModalPlugin): def get_data_key(self) -> str: return "image" - def _get_hf_image_processor(self, model_config: ModelConfig, - processor_kwargs: Dict[str, Any]): + def _get_hf_image_processor(self, model_config: ModelConfig): + processor_kwargs = ({} if model_config.processor_kwargs is None else + model_config.processor_kwargs) + # We don't explicitly check kwarg overrides to the HF class + # since the automodel just takes kwargs, so we can't inspect it return cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, @@ -39,13 +41,8 @@ def _default_input_mapper( # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): - processor_kwargs = ({} if model_config.processor_kwargs is None - else model_config.processor_kwargs) + image_processor = self._get_hf_image_processor(model_config) - image_processor = self._get_hf_image_processor( - model_config, - processor_kwargs, - ) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") From 3e1fe54acc605b7d4f62a5e0235dcb147881e8f5 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:08:01 -0400 Subject: [PATCH 16/35] Enable processor_kwargs in video processor Signed-off-by: Alex-Brooks --- vllm/multimodal/video.py | 7 ++++++- vllm/transformers_utils/image_processor.py | 9 ++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 4401d1315792..aff920977662 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -37,9 +37,14 @@ def get_data_key(self) -> str: return "video" def _get_hf_video_processor(self, model_config: ModelConfig): + processor_kwargs = ({} if model_config.processor_kwargs is None else + model_config.processor_kwargs) + # We don't explicitly check kwarg overrides to the HF class + # since the automodel just takes kwargs, so we can't inspect it return cached_get_video_processor( model_config.model, - trust_remote_code=model_config.trust_remote_code) + trust_remote_code=model_config.trust_remote_code, + **processor_kwargs) def _default_input_mapper( self, diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py index 4cffac3724ba..61b338972e0f 100644 --- a/vllm/transformers_utils/image_processor.py +++ b/vllm/transformers_utils/image_processor.py @@ -3,7 +3,9 @@ def get_video_processor( processor_name: str, + *args, trust_remote_code: bool = False, + **kwargs, ): """ Gets a processor for the given model name via HuggingFace. 
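Because the HuggingFace-side processors take their configuration as constructor keyword arguments, the overrides are simply forwarded when the processor object is built, with no signature inspection on the vLLM side. The effect is roughly that of the direct transformers call below; the model id and the num_crops kwarg are illustrative, and whether an unknown kwarg is ignored or rejected is left to transformers:

from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained(
    "microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    num_crops=16,
)
# Expected to reflect the override rather than the checkpoint default.
print(image_processor.num_crops)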
@@ -11,7 +13,12 @@ def get_video_processor( from transformers import AutoProcessor try: - processor = AutoProcessor.from_pretrained(processor_name) + processor = AutoProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs, + ) video_processor = processor.video_processor except ValueError as e: From feccfd7c7575b121ad18b6bcec4d27a003b80a46 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:08:38 -0400 Subject: [PATCH 17/35] Add tests for mapper processor_kwargs Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 113 ++++++++++++++++++++++++++++- 1 file changed, 112 insertions(+), 1 deletion(-) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index f94a4f5abce6..c86fa5d3c7e4 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -3,13 +3,14 @@ from unittest.mock import patch import pytest +import torch from vllm.config import ModelConfig from vllm.inputs import InputContext, LLMInputs from vllm.inputs.registry import InputRegistry +from vllm.model_executor.models.phi3v import Phi3VForCausalLM from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData -from vllm.model_executor.models.phi3v import Phi3VForCausalLM # Used for fast tests where the model doesn't matter DUMMY_MODEL_ID = "facebook/opt-125m" @@ -78,6 +79,9 @@ def custom_dummy_data_factory(self, # lambda whose signature matches max token calcs + extra kwargs get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops +custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { + "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) +} ### Test for default processor logic & processor_kwargs wrapping @@ -242,3 +246,110 @@ def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): model_config) assert max_multimodal_tokens == DEFAULT_NUM_CROPS + + +### Test overrides for the mapper +@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE]) +def test_default_mapper_with_processer_kwargs(image_assets, num_crops): + """Ensure that the mapper processor kwargs can fall back to HF models.""" + # NOTE - we don't validate bad inputs for the default mapper, because it's + # through the automodel interface in transformers, so we can't easily + # inspect what kwargs are or are not allowed. 
+ model_config = ModelConfig( + MULTIMODAL_MODEL_ID, + MULTIMODAL_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=0, + processor_kwargs={"num_crops": num_crops}, + limit_mm_per_prompt={"image": 1}, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + + image = image_assets[0].pil_image + mm_inputs = {"image": image} + + mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336] + assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1 + + +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_custom_mapper_kwarg_overrides(image_assets, num_crops): + """Ensure that custom mappers can consume processor_kwargs.""" + processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + + model_config = ModelConfig( + MULTIMODAL_MODEL_ID, + MULTIMODAL_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the processor_kwargs. + image = image_assets[0].pil_image + mm_inputs = {"image": image} + + with patch.object( + mm_registry._get_plugin("image"), + "_default_input_mapper", + {Phi3VForCausalLM: custom_mapper}, + ): + mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + + assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1 + + +@pytest.mark.parametrize( + "processor_kwargs", + [ + { + "does_not_exist": 100 + }, # Not part of the signature + { + "ctx": "something bad" + } # Part of the signature, not keyword only + ]) +def test_custom_mapper_with_sad_kwarg_overrides(image_assets, + processor_kwargs): + """Ensure that custom mappers can filter out invalid processor_kwargs.""" + + model_config = ModelConfig( + MULTIMODAL_MODEL_ID, + MULTIMODAL_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the processor_kwargs. 
+ image = image_assets[0].pil_image + mm_inputs = {"image": image} + + with patch.object( + mm_registry._get_plugin("image"), + "_default_input_mapper", + {Phi3VForCausalLM: custom_mapper}, + ): + mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + + assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 From 3ada64de23250759389a7d49000ae87048b1efae Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:21:23 -0400 Subject: [PATCH 18/35] Update mapper not on multimodal processor kwargs Signed-off-by: Alex-Brooks --- vllm/multimodal/registry.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index f1c56226e044..3940e1671b57 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -138,14 +138,15 @@ def create_input_mapper(self, model_config: ModelConfig): """ Create an input mapper (see :meth:`map_input`) for a specific model. """ - # TODO - there is a bit of weirdness here in the way mapper handles - # the args, because for the HF one, we pass processor_kwargs at init - # time and don't need them at func time, for the function's we are - # wrapping in processor like interfaces, we pass them at the time - # of invocation. + # NOTE - we currently make the assumption that if a model has multiple + # supported modalities, they take the same kwargs. For the default, + # this could be an issue in the future if it falls back to two HF + # resources and we can't inspect the signature easily since it's + # getting initialized through the autoclass. # - # Currently it works, but warns when the default processor is used, - # which is bad. + # If this is a problem in the future, we should revisit it, but since + # it potentially introduces a lot of complexity for a currently + # uncommon case, we do not for simplicity of both use & implementation return functools.partial(self.map_input, model_config) def register_max_multimodal_tokens( From 58dcc63ce4b11fd38b212e9fc65d893da6e8a336 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:36:53 -0400 Subject: [PATCH 19/35] processor kwarg test cleanup Signed-off-by: Alex-Brooks --- ..._processor.py => test_processor_kwargs.py} | 152 +++++++----------- 1 file changed, 60 insertions(+), 92 deletions(-) rename tests/multimodal/{test_processor.py => test_processor_kwargs.py} (74%) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor_kwargs.py similarity index 74% rename from tests/multimodal/test_processor.py rename to tests/multimodal/test_processor_kwargs.py index c86fa5d3c7e4..1c84cd265e26 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -24,17 +24,18 @@ NUM_CROPS_OVERRIDE = 16 -def get_model_config(processor_kwargs=None): +def get_model_config(model_name, trust_remote_code=False, processor_kwargs=None, limit_mm_per_prompt=None): """Creates a handle to a model config, which may have processor kwargs.""" # NOTE - values / architecture don't matter too much here since we patch # the return values for stuff like the input processor anyway. 
- return ModelConfig(DUMMY_MODEL_ID, - DUMMY_MODEL_ID, + return ModelConfig(model_name, + model_name, tokenizer_mode="auto", - trust_remote_code=False, + trust_remote_code=trust_remote_code, dtype="float16", seed=0, - processor_kwargs=processor_kwargs) + processor_kwargs=processor_kwargs, + limit_mm_per_prompt=limit_mm_per_prompt) # Mocks for all of the places that we use the processor_kwargs @@ -77,7 +78,7 @@ def custom_dummy_data_factory(self, yield -# lambda whose signature matches max token calcs + extra kwargs +# lambda whose signature matches max token calcs + extra kwargs & mapper respectively get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) @@ -88,7 +89,7 @@ def custom_dummy_data_factory(self, def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" dummy_registry = InputRegistry() - model_config = get_model_config() + model_config = get_model_config(DUMMY_MODEL_ID) processor = dummy_registry.create_input_processor(model_config) proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") proc_outputs = processor(inputs=proc_inputs) @@ -97,14 +98,15 @@ def test_default_processor_is_a_noop(): @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_processor_default_kwargs(use_processor_mock, num_crops): - """Ensure that we can override processor kwargs.""" + """Ensure input processors can use processor kwargs.""" dummy_registry = InputRegistry() # If we have a value for num_crops, pass the override value and make # sure we get that value as a return-value from out mock processor, # otherwise fall back to the default value processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops - model_config = get_model_config(processor_kwargs=processor_kwargs) + model_config = get_model_config(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) processor = dummy_registry.create_input_processor(model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) @@ -114,19 +116,18 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): @pytest.mark.parametrize( "processor_kwargs", [ - { - "does_not_exist": 100 - }, # Not part of the signature - { - "ctx": "something bad" - } # Part of the signature, not keyword only + # Not part of the signature + {"does_not_exist": 100}, + # Part of the signature, not keyword only + {"ctx": "something bad"} ]) def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor_kwargs): - """Ensure invalid processor_kwargs can't be used in the input processor.""" + """Ensure that input processors filter out invalid processor_kwargs.""" dummy_registry = InputRegistry() - model_config = get_model_config(processor_kwargs=processor_kwargs) + model_config = get_model_config(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) processor = dummy_registry.create_input_processor(model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) @@ -136,10 +137,12 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, ### Test overrides for the dummy data @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): + """Ensure dummy data factories can use processor kwargs.""" processor_kwargs = None if num_crops is None else {"num_crops": 
num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops dummy_registry = InputRegistry() - model_config = get_model_config(processor_kwargs=processor_kwargs, ) + model_config = get_model_config(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -154,18 +157,17 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): @pytest.mark.parametrize( "processor_kwargs", [ - { - "does_not_exist": 100 - }, # Not part of the signature - { - "ctx": "something bad" - } # Part of the signature, not keyword only + # Not part of the signature + {"does_not_exist": 100}, + # Part of the signature, not keyword only + {"ctx": "something bad"} ]) def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, processor_kwargs): - """Ensure that dummy_data kwargs that are unused do not fail.""" + """Ensure that dummy data factory filters out invalid processor_kwargs.""" dummy_registry = InputRegistry() - model_config = get_model_config(processor_kwargs=processor_kwargs, ) + model_config = get_model_config(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -180,19 +182,14 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, ### Test overrides for the max token count per multimodal instance @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_max_tokens_kwarg_overrides(num_crops): + """Ensure max token calcs can use processor kwargs.""" processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops - model_config = ModelConfig( - MULTIMODAL_MODEL_ID, - MULTIMODAL_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}, - ) + model_config = get_model_config(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -213,24 +210,17 @@ def test_max_tokens_kwarg_overrides(num_crops): @pytest.mark.parametrize( "processor_kwargs", [ - { - "does_not_exist": 100 - }, # Not part of the signature - { - "ctx": "something bad" - } # Part of the signature, not keyword only + # Not part of the signature + {"does_not_exist": 100}, + # Part of the signature, not keyword only + {"ctx": "something bad"} ]) def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): - model_config = ModelConfig( - MULTIMODAL_MODEL_ID, - MULTIMODAL_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}, - ) + """Ensure that max token calcs filters out invalid processor_kwargs.""" + model_config = get_model_config(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -255,16 +245,10 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): # NOTE - we don't validate bad inputs for the default mapper, because it's # through the automodel interface in transformers, so we can't easily # inspect what kwargs are or are not allowed. 
- model_config = ModelConfig( - MULTIMODAL_MODEL_ID, - MULTIMODAL_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=0, - processor_kwargs={"num_crops": num_crops}, - limit_mm_per_prompt={"image": 1}, - ) + model_config = get_model_config(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs={"num_crops": num_crops}, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -279,20 +263,13 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_custom_mapper_kwarg_overrides(image_assets, num_crops): - """Ensure that custom mappers can consume processor_kwargs.""" + """Ensure custom mappers can use processor kwargs.""" processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops - - model_config = ModelConfig( - MULTIMODAL_MODEL_ID, - MULTIMODAL_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}, - ) + model_config = get_model_config(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -315,27 +292,18 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): @pytest.mark.parametrize( "processor_kwargs", [ - { - "does_not_exist": 100 - }, # Not part of the signature - { - "ctx": "something bad" - } # Part of the signature, not keyword only + # Not part of the signature + {"does_not_exist": 100}, + # Part of the signature, not keyword only + {"ctx": "something bad"} ]) def test_custom_mapper_with_sad_kwarg_overrides(image_assets, processor_kwargs): - """Ensure that custom mappers can filter out invalid processor_kwargs.""" - - model_config = ModelConfig( - MULTIMODAL_MODEL_ID, - MULTIMODAL_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}, - ) + """Ensure that custom mappers filters out invalid processor_kwargs.""" + model_config = get_model_config(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) From 1cee21558d1eff37bfa81cda3825de262a24d07a Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 16:59:16 -0400 Subject: [PATCH 20/35] Move context builder to test utils Signed-off-by: Alex-Brooks --- .../decoder_only/vision_language/test_qwen.py | 29 +---------------- tests/models/utils.py | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_qwen.py b/tests/models/decoder_only/vision_language/test_qwen.py index e4f79092b760..638fb68b8f87 100644 --- a/tests/models/decoder_only/vision_language/test_qwen.py +++ b/tests/models/decoder_only/vision_language/test_qwen.py @@ -5,14 +5,13 @@ import torch from PIL.Image import Image -from vllm.config import ModelConfig from vllm.inputs import InputContext, LLMInputs from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size from ....conftest import (IMAGE_ASSETS, 
HfRunner, ImageAsset, PromptImageInput, VllmRunner, _ImageAssets) -from ...utils import check_logprobs_close +from ...utils import build_model_context, check_logprobs_close text_only_models = [ "Qwen/Qwen-7B-Chat" # Has no visual component @@ -42,32 +41,6 @@ IMG_SIZE = 448 -def build_model_context(model_name: str, - tokenizer_name: Optional[str] = None, - trust_remote_code: bool = False): - """Creates an InputContext for a given model. - - Args: - model_name: Name of the model being considered. - tokenizer_name: Name of the tokenizer being considered. - trust_remote_code: Whether or not to allow loading remote code. - - Returns: - InputContext for the model being considered. - """ - if tokenizer_name is None: - tokenizer_name = model_name - model_config = ModelConfig( - model_name, - tokenizer_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - dtype="float32", - seed=0, - ) - return InputContext(model_config) - - @pytest.fixture() def input_mapper_for_qwen(): # Lazy import to avoid initializing CUDA during test collection diff --git a/tests/models/utils.py b/tests/models/utils.py index 8e31a1d6eefe..0c3e876dd6cd 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,6 +1,8 @@ import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union +from vllm.config import ModelConfig +from vllm.inputs import InputContext from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs TokensText = Tuple[List[int], str] @@ -240,3 +242,33 @@ def check_logprobs_close( warnings.simplefilter("always") warnings.warn(fail_msg, stacklevel=2) + + +def build_model_context(model_name: str, + tokenizer_name: Optional[str] = None, + trust_remote_code: bool = False, + processor_kwargs: Optional[Dict] = None): + """Creates an InputContext for a given model. + + Args: + model_name: Name of the model being considered. + tokenizer_name: Name of the tokenizer being considered. + trust_remote_code: Whether or not to allow loading remote code. + processor_kwargs: optional processor kwargs for to be leveraged + in the input processor, mapper, dummy data creation, etc. + + Returns: + InputContext for the model being considered. + """ + if tokenizer_name is None: + tokenizer_name = model_name + model_config = ModelConfig( + model_name, + tokenizer_name, + tokenizer_mode="auto", + trust_remote_code=trust_remote_code, + dtype="float32", + seed=0, + processor_kwargs=processor_kwargs, + ) + return InputContext(model_config) From d5f9efa94a80e2a4751a69f109027df0334789c7 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:55:42 -0400 Subject: [PATCH 21/35] Use common context builder in processor kwarg tests Signed-off-by: Alex-Brooks --- tests/models/utils.py | 5 +- tests/multimodal/test_processor_kwargs.py | 144 +++++++++++----------- 2 files changed, 77 insertions(+), 72 deletions(-) diff --git a/tests/models/utils.py b/tests/models/utils.py index 0c3e876dd6cd..77a7e054bf68 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -247,7 +247,8 @@ def check_logprobs_close( def build_model_context(model_name: str, tokenizer_name: Optional[str] = None, trust_remote_code: bool = False, - processor_kwargs: Optional[Dict] = None): + processor_kwargs: Optional[Dict] = None, + limit_mm_per_prompt: Optional[Dict] = None): """Creates an InputContext for a given model. Args: @@ -256,6 +257,7 @@ def build_model_context(model_name: str, trust_remote_code: Whether or not to allow loading remote code. 
processor_kwargs: optional processor kwargs for to be leveraged in the input processor, mapper, dummy data creation, etc. + limit_mm_per_prompt: Multimodal limits. Returns: InputContext for the model being considered. @@ -270,5 +272,6 @@ def build_model_context(model_name: str, dtype="float32", seed=0, processor_kwargs=processor_kwargs, + limit_mm_per_prompt=limit_mm_per_prompt, ) return InputContext(model_config) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 1c84cd265e26..35df3fe1492e 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -5,13 +5,14 @@ import pytest import torch -from vllm.config import ModelConfig from vllm.inputs import InputContext, LLMInputs from vllm.inputs.registry import InputRegistry from vllm.model_executor.models.phi3v import Phi3VForCausalLM from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from ..models.utils import build_model_context + # Used for fast tests where the model doesn't matter DUMMY_MODEL_ID = "facebook/opt-125m" # Used for tests that need a multimodal model @@ -24,20 +25,6 @@ NUM_CROPS_OVERRIDE = 16 -def get_model_config(model_name, trust_remote_code=False, processor_kwargs=None, limit_mm_per_prompt=None): - """Creates a handle to a model config, which may have processor kwargs.""" - # NOTE - values / architecture don't matter too much here since we patch - # the return values for stuff like the input processor anyway. - return ModelConfig(model_name, - model_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt=limit_mm_per_prompt) - - # Mocks for all of the places that we use the processor_kwargs # to override values in different callables @pytest.fixture @@ -78,7 +65,7 @@ def custom_dummy_data_factory(self, yield -# lambda whose signature matches max token calcs + extra kwargs & mapper respectively +# lambda whose signature matches max token calcs extra & mapper + extra kwargs get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) @@ -89,8 +76,8 @@ def custom_dummy_data_factory(self, def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" dummy_registry = InputRegistry() - model_config = get_model_config(DUMMY_MODEL_ID) - processor = dummy_registry.create_input_processor(model_config) + ctx = build_model_context(DUMMY_MODEL_ID) + processor = dummy_registry.create_input_processor(ctx.model_config) proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") proc_outputs = processor(inputs=proc_inputs) assert proc_inputs is proc_outputs @@ -105,9 +92,9 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): # otherwise fall back to the default value processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops - model_config = get_model_config(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) - processor = dummy_registry.create_input_processor(model_config) + ctx = build_model_context(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) + processor = dummy_registry.create_input_processor(ctx.model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) assert 
num_crops_val == expected_num_crops @@ -117,19 +104,22 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): "processor_kwargs", [ # Not part of the signature - {"does_not_exist": 100}, + { + "does_not_exist": 100 + }, # Part of the signature, not keyword only - {"ctx": "something bad"} + { + "ctx": "something bad" + } ]) def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor_kwargs): """Ensure that input processors filter out invalid processor_kwargs.""" dummy_registry = InputRegistry() + ctx = build_model_context(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) - model_config = get_model_config(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) - - processor = dummy_registry.create_input_processor(model_config) + processor = dummy_registry.create_input_processor(ctx.model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) assert num_crops_val == DEFAULT_NUM_CROPS @@ -141,16 +131,16 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops dummy_registry = InputRegistry() - model_config = get_model_config(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + ctx = build_model_context(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the processor_kwargs. seq_data, _ = dummy_registry.dummy_data_for_profiling( - model_config, seq_len=-1, mm_registry=mm_registry) + ctx.model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == expected_seq_count @@ -158,24 +148,28 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): "processor_kwargs", [ # Not part of the signature - {"does_not_exist": 100}, + { + "does_not_exist": 100 + }, # Part of the signature, not keyword only - {"ctx": "something bad"} + { + "ctx": "something bad" + } ]) def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, processor_kwargs): """Ensure that dummy data factory filters out invalid processor_kwargs.""" dummy_registry = InputRegistry() - model_config = get_model_config(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + ctx = build_model_context(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the processor_kwargs. 
seq_data, _ = dummy_registry.dummy_data_for_profiling( - model_config, seq_len=-1, mm_registry=mm_registry) + ctx.model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS @@ -186,13 +180,13 @@ def test_max_tokens_kwarg_overrides(num_crops): processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops - model_config = get_model_config(MULTIMODAL_MODEL_ID, - trust_remote_code=True, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos # our num_crops value back from the processor_kwargs. @@ -202,7 +196,7 @@ def test_max_tokens_kwarg_overrides(num_crops): {Phi3VForCausalLM: get_num_crops}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( - model_config) + ctx.model_config) assert expected_seq_count == max_multimodal_tokens @@ -211,19 +205,23 @@ def test_max_tokens_kwarg_overrides(num_crops): "processor_kwargs", [ # Not part of the signature - {"does_not_exist": 100}, + { + "does_not_exist": 100 + }, # Part of the signature, not keyword only - {"ctx": "something bad"} + { + "ctx": "something bad" + } ]) def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): """Ensure that max token calcs filters out invalid processor_kwargs.""" - model_config = get_model_config(MULTIMODAL_MODEL_ID, - trust_remote_code=True, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Similar before, but since these kwargs get filtered, # we always get our default value back. @@ -233,7 +231,7 @@ def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): {Phi3VForCausalLM: get_num_crops}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( - model_config) + ctx.model_config) assert max_multimodal_tokens == DEFAULT_NUM_CROPS @@ -245,18 +243,18 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): # NOTE - we don't validate bad inputs for the default mapper, because it's # through the automodel interface in transformers, so we can't easily # inspect what kwargs are or are not allowed. 
- model_config = get_model_config(MULTIMODAL_MODEL_ID, - trust_remote_code=True, - processor_kwargs={"num_crops": num_crops}, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs={"num_crops": num_crops}, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) image = image_assets[0].pil_image mm_inputs = {"image": image} - mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336] assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1 @@ -266,13 +264,13 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): """Ensure custom mappers can use processor kwargs.""" processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops - model_config = get_model_config(MULTIMODAL_MODEL_ID, - trust_remote_code=True, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos # our num_crops value back from the processor_kwargs. @@ -284,7 +282,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): "_default_input_mapper", {Phi3VForCausalLM: custom_mapper}, ): - mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1 @@ -293,20 +291,24 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): "processor_kwargs", [ # Not part of the signature - {"does_not_exist": 100}, + { + "does_not_exist": 100 + }, # Part of the signature, not keyword only - {"ctx": "something bad"} + { + "ctx": "something bad" + } ]) def test_custom_mapper_with_sad_kwarg_overrides(image_assets, processor_kwargs): """Ensure that custom mappers filters out invalid processor_kwargs.""" - model_config = get_model_config(MULTIMODAL_MODEL_ID, - trust_remote_code=True, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos # our num_crops value back from the processor_kwargs. 
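In a real model integration, the lambdas patched in by these tests would be replaced by registered callables whose signatures expose the override as a keyword-only argument. A hedged sketch of such a max-token calculator, with invented names and token math used purely for illustration:

from vllm.inputs import InputContext

DEFAULT_NUM_CROPS = 4

def get_max_my_model_image_tokens(ctx: InputContext, *,
                                  num_crops: int = DEFAULT_NUM_CROPS) -> int:
    # Assume a fixed per-crop token cost for illustration only; the real
    # value would come from the model's image embedding scheme.
    tokens_per_crop = 144
    return (num_crops + 1) * tokens_per_crop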
@@ -318,6 +320,6 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets, "_default_input_mapper", {Phi3VForCausalLM: custom_mapper}, ): - mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 From b5d434b5c9c906ded90f56b10372c434e94751dc Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:05:26 -0600 Subject: [PATCH 22/35] Update vllm/entrypoints/llm.py Co-authored-by: Cyrus Leung --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 6304851233ce..d27ea214ff37 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -134,7 +134,7 @@ def __init__( max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, - processor_kwargs=None, + processor_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: ''' From a0963014c73941bf2fa75f059dd96f2c68133731 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:06:41 -0600 Subject: [PATCH 23/35] Update vllm/inputs/registry.py Co-authored-by: Cyrus Leung --- vllm/inputs/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 08f516e17aef..e9d528ad9067 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -217,7 +217,7 @@ def _default_input_processor(self, ctx: InputContext, """The default input processor is a no-op.""" return inputs - def register_input_processor(self, processor: InputProcessor) -> Callable: + def register_input_processor(self, processor: InputProcessor): """ Register an input processor to a model class. From 79962e02e55d70c2fa4672c863454e898a1a71a3 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:06:54 -0600 Subject: [PATCH 24/35] Update vllm/inputs/registry.py Co-authored-by: Cyrus Leung --- vllm/inputs/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index e9d528ad9067..c0043f35711b 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -256,7 +256,7 @@ def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", return processor(InputContext(model_config), inputs, **processor_kwargs) - def create_input_processor(self, model_config: "ModelConfig") -> Callable: + def create_input_processor(self, model_config: "ModelConfig"): """ Create an input processor (see :meth:`_process_input`) for a specific model. From 2cb1f72c0b8c53d6c7278959113a0a6a8ead2b05 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:07:10 -0600 Subject: [PATCH 25/35] Update vllm/inputs/registry.py Co-authored-by: Cyrus Leung --- vllm/inputs/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index c0043f35711b..410ae3021f4e 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -274,7 +274,7 @@ def create_input_processor(self, model_config: "ModelConfig"): **processor_kwargs) def _get_model_input_processor(self, - model_config: "ModelConfig") -> Callable: + model_config: "ModelConfig"): """ Grabs the input processor for the provided model. 
From 37eb5324855268587217011ebbd6e2410c57d721 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:07:25 -0600 Subject: [PATCH 26/35] Update vllm/inputs/registry.py Co-authored-by: Cyrus Leung --- vllm/inputs/registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 410ae3021f4e..a5b5cddc6c6c 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -81,7 +81,8 @@ def __call__( Note: :data:`InputProcessor` is not applied to the dummy data. - The processor_kwargs are overrides provided at initialization + + The :code:`processor_kwargs` are overrides provided at initialization time to values in the config whose values may affect the number of tokens per instance. """ From a4c7c3dea4753684d08a42bc8a97476518a95aa0 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:08:31 -0600 Subject: [PATCH 27/35] Update vllm/inputs/registry.py Co-authored-by: Cyrus Leung --- vllm/inputs/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index a5b5cddc6c6c..caaaa0cfc713 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -243,7 +243,7 @@ def wrapper(model_cls: N) -> N: return wrapper def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", - processor: Callable, **processor_kwargs) -> LLMInputs: + processor: InputProcessor, **processor_kwargs: Any) -> LLMInputs: """ Apply an input processor to an instance of model inputs. This will usually not be invoked be directly, and instead will be wrapped in From 36dd2cba7cb5150ff53cdb1560e3c80b01eb3712 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 22 Sep 2024 02:12:25 -0400 Subject: [PATCH 28/35] Fix formatting Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index caaaa0cfc713..2a4a2250aba7 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -82,9 +82,9 @@ def __call__( Note: :data:`InputProcessor` is not applied to the dummy data. - The :code:`processor_kwargs` are overrides provided at initialization - time to values in the config whose values may affect the number - of tokens per instance. + The :code:`processor_kwargs` are overrides provided at + initialization time to values in the config whose values + may affect the number of tokens per instance. """ ... @@ -243,7 +243,8 @@ def wrapper(model_cls: N) -> N: return wrapper def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", - processor: InputProcessor, **processor_kwargs: Any) -> LLMInputs: + processor: InputProcessor, + **processor_kwargs: Any) -> LLMInputs: """ Apply an input processor to an instance of model inputs. This will usually not be invoked be directly, and instead will be wrapped in @@ -274,8 +275,7 @@ def create_input_processor(self, model_config: "ModelConfig"): processor=processor, **processor_kwargs) - def _get_model_input_processor(self, - model_config: "ModelConfig"): + def _get_model_input_processor(self, model_config: "ModelConfig"): """ Grabs the input processor for the provided model. 
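On the consumer side, a model-specific input processor opts into an override simply by declaring it as a keyword-only parameter with a default, mirroring the mock used in the tests above. A sketch with invented names; the decorator usage is assumed to follow the existing register_input_processor pattern:

from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs

def input_processor_for_my_model(ctx: InputContext,
                                 llm_inputs: LLMInputs,
                                 *,
                                 num_crops: int = 4) -> LLMInputs:
    # A real processor would expand multimodal placeholder tokens based
    # on num_crops; here the inputs are returned unchanged.
    return llm_inputs

# Assumed registration pattern:
# @INPUT_REGISTRY.register_input_processor(input_processor_for_my_model)
# class MyModelForCausalLM(nn.Module):
#     ...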
From f95c86f7798b2fa6925a7bc993fb5a39f662f5ee Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 22 Sep 2024 04:29:21 -0400 Subject: [PATCH 29/35] Rename processor kwargs to mm processor kwargs Signed-off-by: Alex-Brooks --- tests/engine/test_arg_utils.py | 6 +- tests/models/utils.py | 6 +- tests/multimodal/test_processor_kwargs.py | 74 +++++++++++++---------- vllm/config.py | 6 +- vllm/engine/arg_utils.py | 10 +-- vllm/engine/llm_engine.py | 4 +- vllm/entrypoints/llm.py | 4 +- vllm/inputs/registry.py | 20 +++--- vllm/multimodal/base.py | 14 ++--- vllm/multimodal/image.py | 6 +- vllm/multimodal/video.py | 6 +- vllm/utils.py | 2 +- 12 files changed, 83 insertions(+), 75 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index fabf37aa2a68..360ac1bfbad9 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -54,10 +54,10 @@ def test_bad_nullable_kvs(arg): } }), ]) -def test_processor_kwargs_prompt_parser(arg, expected): +def test_mm_processor_kwargs_prompt_parser(arg, expected): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: args = parser.parse_args([]) else: - args = parser.parse_args(["--processor-kwargs", arg]) - assert args.processor_kwargs == expected + args = parser.parse_args(["--mm-processor-kwargs", arg]) + assert args.mm_processor_kwargs == expected diff --git a/tests/models/utils.py b/tests/models/utils.py index 77a7e054bf68..eb6254f18182 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -247,7 +247,7 @@ def check_logprobs_close( def build_model_context(model_name: str, tokenizer_name: Optional[str] = None, trust_remote_code: bool = False, - processor_kwargs: Optional[Dict] = None, + mm_processor_kwargs: Optional[Dict] = None, limit_mm_per_prompt: Optional[Dict] = None): """Creates an InputContext for a given model. @@ -255,7 +255,7 @@ def build_model_context(model_name: str, model_name: Name of the model being considered. tokenizer_name: Name of the tokenizer being considered. trust_remote_code: Whether or not to allow loading remote code. - processor_kwargs: optional processor kwargs for to be leveraged + mm_processor_kwargs: optional processor kwargs for to be leveraged in the input processor, mapper, dummy data creation, etc. limit_mm_per_prompt: Multimodal limits. @@ -271,7 +271,7 @@ def build_model_context(model_name: str, trust_remote_code=trust_remote_code, dtype="float32", seed=0, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt=limit_mm_per_prompt, ) return InputContext(model_config) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 35df3fe1492e..d7fa32a7f214 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -18,14 +18,14 @@ # Used for tests that need a multimodal model MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" -# For processor_kwargs - we test overrides by defining mocks for each place +# For mm_processor_kwargs - we test overrides by defining mocks for each place # it is used, and ensuring that we can pass processor kwargs an override value # to receive the intended result for things like sequence length etc. 
DEFAULT_NUM_CROPS = 4 NUM_CROPS_OVERRIDE = 16 -# Mocks for all of the places that we use the processor_kwargs +# Mocks for all of the places that we use the mm_processor_kwargs # to override values in different callables @pytest.fixture def use_processor_mock(): @@ -72,7 +72,7 @@ def custom_dummy_data_factory(self, } -### Test for default processor logic & processor_kwargs wrapping +### Test for default processor logic & mm_processor_kwargs wrapping def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" dummy_registry = InputRegistry() @@ -90,10 +90,12 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): # If we have a value for num_crops, pass the override value and make # sure we get that value as a return-value from out mock processor, # otherwise fall back to the default value - processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops ctx = build_model_context(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + mm_processor_kwargs=mm_processor_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) @@ -101,7 +103,7 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): @pytest.mark.parametrize( - "processor_kwargs", + "mm_processor_kwargs", [ # Not part of the signature { @@ -113,11 +115,11 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): } ]) def test_processor_with_sad_kwarg_overrides(use_processor_mock, - processor_kwargs): - """Ensure that input processors filter out invalid processor_kwargs.""" + mm_processor_kwargs): + """Ensure that input processors filter out invalid mm_processor_kwargs""" dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + mm_processor_kwargs=mm_processor_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) @@ -128,24 +130,26 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): """Ensure dummy data factories can use processor kwargs.""" - processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + mm_processor_kwargs=mm_processor_kwargs) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq - # len is solely dependent on the value of the processor_kwargs. + # len is solely dependent on the value of the mm_processor_kwargs. 
seq_data, _ = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == expected_seq_count @pytest.mark.parametrize( - "processor_kwargs", + "mm_processor_kwargs", [ # Not part of the signature { @@ -157,17 +161,17 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): } ]) def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, - processor_kwargs): - """Ensure that dummy data factory filters out invalid processor_kwargs.""" + mm_processor_kwargs): + """Ensure the dummy data factory filters out invalid mm_processor_kwargs""" dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + mm_processor_kwargs=mm_processor_kwargs) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq - # len is solely dependent on the value of the processor_kwargs. + # len is solely dependent on the value of the mm_processor_kwargs. seq_data, _ = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS @@ -177,19 +181,21 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_max_tokens_kwarg_overrides(num_crops): """Ensure max token calcs can use processor kwargs.""" - processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the processor_kwargs. + # our num_crops value back from the mm_processor_kwargs. with patch.object( mm_registry._get_plugin("image"), "_max_mm_tokens", @@ -202,7 +208,7 @@ def test_max_tokens_kwarg_overrides(num_crops): @pytest.mark.parametrize( - "processor_kwargs", + "mm_processor_kwargs", [ # Not part of the signature { @@ -213,11 +219,11 @@ def test_max_tokens_kwarg_overrides(num_crops): "ctx": "something bad" } ]) -def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): - """Ensure that max token calcs filters out invalid processor_kwargs.""" +def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): + """Ensure that max token calcs filters out invalid mm_processor_kwargs""" ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() @@ -245,7 +251,7 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): # inspect what kwargs are or are not allowed. 
ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - processor_kwargs={"num_crops": num_crops}, + mm_processor_kwargs={"num_crops": num_crops}, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() @@ -262,18 +268,20 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_custom_mapper_kwarg_overrides(image_assets, num_crops): """Ensure custom mappers can use processor kwargs.""" - processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the processor_kwargs. + # our num_crops value back from the mm_processor_kwargs. image = image_assets[0].pil_image mm_inputs = {"image": image} @@ -288,7 +296,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): @pytest.mark.parametrize( - "processor_kwargs", + "mm_processor_kwargs", [ # Not part of the signature { @@ -300,18 +308,18 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): } ]) def test_custom_mapper_with_sad_kwarg_overrides(image_assets, - processor_kwargs): - """Ensure that custom mappers filters out invalid processor_kwargs.""" + mm_processor_kwargs): + """Ensure that custom mappers filters out invalid mm_processor_kwargs""" ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the processor_kwargs. + # our num_crops value back from the mm_processor_kwargs. image = image_assets[0].pil_image mm_inputs = {"image": image} diff --git a/vllm/config.py b/vllm/config.py index 94552a22cc25..c30867565132 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -122,7 +122,7 @@ class ModelConfig: can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. - processor_kwargs: Arguments to be forwarded to the model's processor, + mm_processor_kwargs: Arguments to be forwarded to the model's processor, e.g., tokenizer, image processor, or custom processor callable. 
""" @@ -153,7 +153,7 @@ def __init__(self, use_async_output_proc: bool = True, override_neuron_config: Optional[Dict[str, Any]] = None, config_format: ConfigFormat = ConfigFormat.AUTO, - processor_kwargs: Optional[Dict[str, Any]] = None) -> None: + mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -187,7 +187,7 @@ def __init__(self, self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc - self.processor_kwargs = processor_kwargs + self.mm_processor_kwargs = mm_processor_kwargs # Set enforce_eager to False if the value is unset. if self.enforce_eager is None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ca1f334de535..ca6034ddbe5c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -175,7 +175,7 @@ class EngineArgs: collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False override_neuron_config: Optional[Dict[str, Any]] = None - processor_kwargs: Optional[Dict[str, Any]] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None def __post_init__(self): if self.tokenizer is None: @@ -515,11 +515,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'images and 2 videos per prompt. Defaults to 1 for ' 'each modality.')) parser.add_argument( - '--processor-kwargs', + '--mm-processor-kwargs', default=None, type=json.loads, - help=('Overrides for the model processor, e.g., tokenizer or ' - 'image processor. For example: {"num_crops": 4}.')) + help=('Overrides for the multimodal input mapping/processing,' + 'e.g., image processor. For example: {"num_crops": 4}.')) # LoRA related configs parser.add_argument('--enable-lora', @@ -829,7 +829,7 @@ def create_model_config(self) -> ModelConfig: use_async_output_proc=not self.disable_async_output_proc, override_neuron_config=self.override_neuron_config, config_format=self.config_format, - processor_kwargs=self.processor_kwargs, + mm_processor_kwargs=self.mm_processor_kwargs, ) def create_load_config(self) -> LoadConfig: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a482cbbe2009..4d9696e464bc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -235,7 +235,7 @@ def __init__( "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, processor_kwargs=%s)", + "use_async_output_proc=%s, mm_processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -268,7 +268,7 @@ def __init__( scheduler_config.num_scheduler_steps, cache_config.enable_prefix_caching, model_config.use_async_output_proc, - model_config.processor_kwargs, + model_config.mm_processor_kwargs, ) # TODO(woosuk): Print more configs in debug mode. 
from vllm.plugins import load_general_plugins diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d27ea214ff37..5dd02d0f9a12 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -134,7 +134,7 @@ def __init__( max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, - processor_kwargs: Optional[Dict[str, Any]] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: ''' @@ -175,7 +175,7 @@ def __init__( max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, disable_async_output_proc=disable_async_output_proc, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 2a4a2250aba7..f6e53a08bb48 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -74,7 +74,7 @@ def __call__( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - **processor_kwargs: Any, + **mm_processor_kwargs: Any, ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data to be inputted into the model. @@ -82,7 +82,7 @@ def __call__( Note: :data:`InputProcessor` is not applied to the dummy data. - The :code:`processor_kwargs` are overrides provided at + The :code:`mm_processor_kwargs` are overrides provided at initialization time to values in the config whose values may affect the number of tokens per instance. """ @@ -190,12 +190,12 @@ def dummy_data_for_profiling( .get(model_cls, self._default_dummy_data_factory) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - processor_kwargs = get_allowed_kwarg_only_overrides( - callable=dummy_factory, overrides=model_config.processor_kwargs) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + callable=dummy_factory, overrides=model_config.mm_processor_kwargs) seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, _MultiModalCounts(mm_counts), - **processor_kwargs) + **mm_processor_kwargs) # Having more tokens is over-conservative but otherwise fine num_tokens = seq_data.prompt_token_ids @@ -244,7 +244,7 @@ def wrapper(model_cls: N) -> N: def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", processor: InputProcessor, - **processor_kwargs: Any) -> LLMInputs: + **mm_processor_kwargs: Any) -> LLMInputs: """ Apply an input processor to an instance of model inputs. This will usually not be invoked be directly, and instead will be wrapped in @@ -256,7 +256,7 @@ def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", :ref:`input_processing_pipeline` """ return processor(InputContext(model_config), inputs, - **processor_kwargs) + **mm_processor_kwargs) def create_input_processor(self, model_config: "ModelConfig"): """ @@ -268,12 +268,12 @@ def create_input_processor(self, model_config: "ModelConfig"): # NOTE: we don't allow override values for ctx/inputs, since doing # so can lead to value collisions etc. 
processor = self._get_model_input_processor(model_config) - processor_kwargs = get_allowed_kwarg_only_overrides( - callable=processor, overrides=model_config.processor_kwargs) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + callable=processor, overrides=model_config.mm_processor_kwargs) return functools.partial(self._process_input, model_config=model_config, processor=processor, - **processor_kwargs) + **mm_processor_kwargs) def _get_model_input_processor(self, model_config: "ModelConfig"): """ diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 06cbb528c34b..ee840caabe0b 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -261,16 +261,16 @@ def map_input(self, model_config: ModelConfig, # input mapper; no overrides are used on the default here because they # should be passed to the huggingface resource at initialization time. if mapper != self._default_input_mapper: - processor_kwargs = get_allowed_kwarg_only_overrides( - callable=mapper, overrides=model_config.processor_kwargs) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + callable=mapper, overrides=model_config.mm_processor_kwargs) else: - processor_kwargs = {} + mm_processor_kwargs = {} if mapper is None: raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") - return mapper(InputContext(model_config), data, **processor_kwargs) + return mapper(InputContext(model_config), data, **mm_processor_kwargs) @abstractmethod def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: @@ -343,11 +343,11 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: f"for model class {model_cls.__name__} in {self}.") if callable(max_mm_tokens): - processor_kwargs = get_allowed_kwarg_only_overrides( + mm_processor_kwargs = get_allowed_kwarg_only_overrides( callable=max_mm_tokens, - overrides=model_config.processor_kwargs) + overrides=model_config.mm_processor_kwargs) max_mm_tokens = max_mm_tokens(InputContext(model_config), - **processor_kwargs) + **mm_processor_kwargs) self._validate_max_multimodal_tokens(max_mm_tokens) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index c2657a112173..d71e24d71f2e 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -23,14 +23,14 @@ def get_data_key(self) -> str: return "image" def _get_hf_image_processor(self, model_config: ModelConfig): - processor_kwargs = ({} if model_config.processor_kwargs is None else - model_config.processor_kwargs) + mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None + else model_config.mm_processor_kwargs) # We don't explicitly check kwarg overrides to the HF class # since the automodel just takes kwargs, so we can't inspect it return cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, - **processor_kwargs) + **mm_processor_kwargs) def _default_input_mapper( self, diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index aff920977662..75216df451b3 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -37,14 +37,14 @@ def get_data_key(self) -> str: return "video" def _get_hf_video_processor(self, model_config: ModelConfig): - processor_kwargs = ({} if model_config.processor_kwargs is None else - model_config.processor_kwargs) + mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None + else model_config.mm_processor_kwargs) # We don't explicitly check kwarg overrides to the HF class # since the automodel just takes 
kwargs, so we can't inspect it return cached_get_video_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, - **processor_kwargs) + **mm_processor_kwargs) def _default_input_mapper( self, diff --git a/vllm/utils.py b/vllm/utils.py index e1b8ccfd6aad..3369a8672909 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1267,7 +1267,7 @@ def get_allowed_kwarg_only_overrides( if param.kind == inspect.Parameter.KEYWORD_ONLY ] - # Drop any processor_kwargs provided by the user that are + # Drop any mm_processor_kwargs provided by the user that are # not kwarg names accepted by the provided input processor. filtered_overrides = { kwarg_name: val From 229604fc4c86fb1bbdd953a90b4201ded9ea544a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 22 Sep 2024 16:49:39 +0800 Subject: [PATCH 30/35] Update docstring --- vllm/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index c30867565132..fae2d44f174b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -122,8 +122,8 @@ class ModelConfig: can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. - mm_processor_kwargs: Arguments to be forwarded to the model's processor, - e.g., tokenizer, image processor, or custom processor callable. + mm_processor_kwargs: Arguments to be forwarded to the model's processor + for multi-modal data, e.g., image processor. """ def __init__(self, From b732d72ad84f9ef4db782feecd2709d0fa7e4f3b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 22 Sep 2024 14:24:44 +0000 Subject: [PATCH 31/35] Try to fix CUDA reinitialization error --- tests/multimodal/test_processor_kwargs.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index d7fa32a7f214..5529ccd4fa57 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -7,7 +7,6 @@ from vllm.inputs import InputContext, LLMInputs from vllm.inputs.registry import InputRegistry -from vllm.model_executor.models.phi3v import Phi3VForCausalLM from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -65,6 +64,13 @@ def custom_dummy_data_factory(self, yield +# Lazy import to avoid CUDA reinitialization error +def mm_model_cls(): + from vllm.model_executor.models.phi3v import Phi3VForCausalLM + + return Phi3VForCausalLM + + # lambda whose signature matches max token calcs extra & mapper + extra kwargs get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { @@ -199,7 +205,7 @@ def test_max_tokens_kwarg_overrides(num_crops): with patch.object( mm_registry._get_plugin("image"), "_max_mm_tokens", - {Phi3VForCausalLM: get_num_crops}, + {mm_model_cls(): get_num_crops}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( ctx.model_config) @@ -234,7 +240,7 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): with patch.object( mm_registry._get_plugin("image"), "_max_mm_tokens", - {Phi3VForCausalLM: get_num_crops}, + {mm_model_cls(): get_num_crops}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( ctx.model_config) @@ -288,7 +294,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): with patch.object( mm_registry._get_plugin("image"), 
"_default_input_mapper", - {Phi3VForCausalLM: custom_mapper}, + {mm_model_cls(): custom_mapper}, ): mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) @@ -326,7 +332,7 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets, with patch.object( mm_registry._get_plugin("image"), "_default_input_mapper", - {Phi3VForCausalLM: custom_mapper}, + {mm_model_cls(): custom_mapper}, ): mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) From 844524a9c15b717f2f23d44465c14073043b6307 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 23 Sep 2024 02:34:56 +0000 Subject: [PATCH 32/35] Consolidate processor loading --- vllm/multimodal/image.py | 2 +- vllm/multimodal/video.py | 2 +- vllm/transformers_utils/image_processor.py | 71 ---------------------- vllm/transformers_utils/processor.py | 65 ++++++++++++++++++-- 4 files changed, 63 insertions(+), 77 deletions(-) delete mode 100644 vllm/transformers_utils/image_processor.py diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index d71e24d71f2e..31b1c3f93411 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -6,7 +6,7 @@ from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger -from vllm.transformers_utils.image_processor import get_image_processor +from vllm.transformers_utils.processor import get_image_processor from vllm.utils import is_list_of from .base import MultiModalData, MultiModalInputs, MultiModalPlugin diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 75216df451b3..39e75dbaf687 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -6,7 +6,7 @@ from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger -from vllm.transformers_utils.image_processor import get_video_processor +from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import is_list_of diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py deleted file mode 100644 index 61b338972e0f..000000000000 --- a/vllm/transformers_utils/image_processor.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import cast - - -def get_video_processor( - processor_name: str, - *args, - trust_remote_code: bool = False, - **kwargs, -): - """ - Gets a processor for the given model name via HuggingFace. - """ - from transformers import AutoProcessor - - try: - processor = AutoProcessor.from_pretrained( - processor_name, - *args, - trust_remote_code=trust_remote_code, - **kwargs, - ) - video_processor = processor.video_processor - - except ValueError as e: - if not trust_remote_code: - err_msg = ( - "Failed to load the processor. 
If the processor is " - "a custom processor not yet available in the HuggingFace " - "transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - return video_processor - - -def get_image_processor( - processor_name: str, - *args, - trust_remote_code: bool = False, - **kwargs, -): - """Gets an image processor for the given model name via HuggingFace.""" - # don't put this import at the top level - # it will call torch.cuda.device_count() - from transformers import AutoImageProcessor - from transformers.image_processing_utils import BaseImageProcessor - - try: - processor = AutoImageProcessor.from_pretrained( - processor_name, - *args, - trust_remote_code=trust_remote_code, - **kwargs) - except ValueError as e: - # If the error pertains to the processor class not existing or not - # currently being imported, suggest using the --trust-remote-code flag. - # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors - if not trust_remote_code: - err_msg = ( - "Failed to load the image processor. If the image processor is " - "a custom processor not yet available in the HuggingFace " - "transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - - return cast(BaseImageProcessor, processor) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 2001746c5f7f..98663f7f0bd0 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,13 +1,13 @@ -from typing import cast +from typing import Any, cast def get_processor( processor_name: str, - *args, + *args: Any, trust_remote_code: bool = False, - **kwargs, + **kwargs: Any, ): - """Gets a processor for the given model name via HuggingFace.""" + """Load a processor for the given model name via HuggingFace.""" # don't put this import at the top level # it will call torch.cuda.device_count() from transformers import AutoProcessor @@ -35,3 +35,60 @@ def get_processor( raise e return cast(ProcessorMixin, processor) + + +def get_image_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + **kwargs: Any, +): + """Load an image processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoImageProcessor + from transformers.image_processing_utils import BaseImageProcessor + + try: + processor = AutoImageProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the image processor. 
If the image processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return cast(BaseImageProcessor, processor) + + +def get_video_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + **kwargs: Any, +): + """Load a video processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers.image_processing_utils import BaseImageProcessor + + processor = get_processor( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + return cast(BaseImageProcessor, processor.video_processor) From 2dd742b7e295760f354634ea98f715725675b8db Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 23 Sep 2024 03:00:59 +0000 Subject: [PATCH 33/35] Fix CUDA reinitialization error --- vllm/inputs/registry.py | 52 ++++++++++++----------------------------- 1 file changed, 15 insertions(+), 37 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 830b8781f058..576560bf2e0e 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -236,19 +236,27 @@ def wrapper(model_cls: N) -> N: return wrapper - def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", - processor: InputProcessor, - **mm_processor_kwargs: Any) -> LLMInputs: + def process_input(self, model_config: "ModelConfig", + inputs: LLMInputs) -> LLMInputs: """ - Apply an input processor to an instance of model inputs. This will - usually not be invoked be directly, and instead will be wrapped in - a functools partial once the processor is created. + Apply an input processor to an instance of model inputs. The model is identified by ``model_config``. See also: :ref:`input_processing_pipeline` """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + + processor = self._input_processors_by_model_type \ + .get(model_cls, self._default_input_processor) + + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + callable=processor, overrides=model_config.mm_processor_kwargs) + return processor(InputContext(model_config), inputs, **mm_processor_kwargs) @@ -257,34 +265,4 @@ def create_input_processor(self, model_config: "ModelConfig"): Create an input processor (see :meth:`_process_input`) for a specific model. """ - # Determine which kwargs can be leveraged for the input processor - # and drop + warn for kwargs that are unimplemented. - # NOTE: we don't allow override values for ctx/inputs, since doing - # so can lead to value collisions etc. - processor = self._get_model_input_processor(model_config) - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - callable=processor, overrides=model_config.mm_processor_kwargs) - return functools.partial(self._process_input, - model_config=model_config, - processor=processor, - **mm_processor_kwargs) - - def _get_model_input_processor(self, model_config: "ModelConfig"): - """ - Grabs the input processor for the provided model. - - Args: - model_config: Config whose model architecture we can leverage to - grab the callable input processor. - - Returns: - Callable input processor for this model. 
- """ - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - - processor = self._input_processors_by_model_type \ - .get(model_cls, self._default_input_processor) - return processor + return functools.partial(self.process_input, model_config) From ebc1c0267614fef840d319e15d5005908ea0449a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 23 Sep 2024 03:03:15 +0000 Subject: [PATCH 34/35] Simplify code --- vllm/inputs/registry.py | 4 ++-- vllm/multimodal/base.py | 7 +++---- vllm/utils.py | 6 +++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 576560bf2e0e..4d17fa7eb788 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -185,7 +185,7 @@ def dummy_data_for_profiling( mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) mm_processor_kwargs = get_allowed_kwarg_only_overrides( - callable=dummy_factory, overrides=model_config.mm_processor_kwargs) + dummy_factory, overrides=model_config.mm_processor_kwargs) seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, _MultiModalCounts(mm_counts), @@ -255,7 +255,7 @@ def process_input(self, model_config: "ModelConfig", .get(model_cls, self._default_input_processor) mm_processor_kwargs = get_allowed_kwarg_only_overrides( - callable=processor, overrides=model_config.mm_processor_kwargs) + processor, overrides=model_config.mm_processor_kwargs) return processor(InputContext(model_config), inputs, **mm_processor_kwargs) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index ee840caabe0b..87d3a4576f33 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -260,9 +260,9 @@ def map_input(self, model_config: ModelConfig, # Only get processor kwargs at mapping time if we are not using the # input mapper; no overrides are used on the default here because they # should be passed to the huggingface resource at initialization time. - if mapper != self._default_input_mapper: + if mapper is not None and mapper != self._default_input_mapper: mm_processor_kwargs = get_allowed_kwarg_only_overrides( - callable=mapper, overrides=model_config.mm_processor_kwargs) + mapper, overrides=model_config.mm_processor_kwargs) else: mm_processor_kwargs = {} @@ -344,8 +344,7 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: if callable(max_mm_tokens): mm_processor_kwargs = get_allowed_kwarg_only_overrides( - callable=max_mm_tokens, - overrides=model_config.mm_processor_kwargs) + max_mm_tokens, overrides=model_config.mm_processor_kwargs) max_mm_tokens = max_mm_tokens(InputContext(model_config), **mm_processor_kwargs) diff --git a/vllm/utils.py b/vllm/utils.py index 29627ab9b77c..b10a3152e99e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1239,7 +1239,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, def get_allowed_kwarg_only_overrides( - callable: Optional[Callable], + callable: Callable[..., object], overrides: Optional[Dict[str, Any]], ) -> Dict[str, Any]: """ @@ -1259,7 +1259,7 @@ def get_allowed_kwarg_only_overrides( to overwrite one or more keyword only arguments when invoking the callable. 
""" - if not overrides or not callable: + if not overrides: return {} allowed_override_names = [ @@ -1276,7 +1276,7 @@ def get_allowed_kwarg_only_overrides( } # If anything is dropped, log a warning - dropped_keys = set(overrides) - set(filtered_overrides) + dropped_keys = overrides.keys() - filtered_overrides.keys() if dropped_keys: logger.warning( "The following intended overrides are not keyword-only args " From a7f32f50515583aa3baf13736b7dbb0c49e0ec45 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 23 Sep 2024 03:57:18 +0000 Subject: [PATCH 35/35] Fix tests --- vllm/inputs/registry.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 4d17fa7eb788..6ab23d1c4b76 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -158,6 +158,10 @@ def wrapper(model_cls: N) -> N: return wrapper + def _get_dummy_data_factory(self, model_cls: Type[nn.Module]): + return self._dummy_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) + def dummy_data_for_profiling( self, model_config: "ModelConfig", @@ -180,10 +184,9 @@ def dummy_data_for_profiling( from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) - dummy_factory = self._dummy_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) - mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + dummy_factory = self._get_dummy_data_factory(model_cls) + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) mm_processor_kwargs = get_allowed_kwarg_only_overrides( dummy_factory, overrides=model_config.mm_processor_kwargs) @@ -236,6 +239,10 @@ def wrapper(model_cls: N) -> N: return wrapper + def _get_model_input_processor(self, model_cls: Type[nn.Module]): + return self._input_processors_by_model_type \ + .get(model_cls, self._default_input_processor) + def process_input(self, model_config: "ModelConfig", inputs: LLMInputs) -> LLMInputs: """ @@ -250,9 +257,7 @@ def process_input(self, model_config: "ModelConfig", from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) - - processor = self._input_processors_by_model_type \ - .get(model_cls, self._default_input_processor) + processor = self._get_model_input_processor(model_cls) mm_processor_kwargs = get_allowed_kwarg_only_overrides( processor, overrides=model_config.mm_processor_kwargs)