From 678c291c3708549329559be5a21c13377ad58ced Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Sat, 16 Nov 2024 21:22:36 +0000 Subject: [PATCH 01/23] Patch multi_modal_placeholders to RequestOutput * confirm that `offline_inference_vision_language.py` and `benchmark_throughput.py` runs * FIXME: the placeholders in output, however, is empty - will fix in next commit Signed-off-by: Linkun Chen --- vllm/outputs.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index badf50d0602d..205a479e8355 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -5,6 +5,7 @@ from typing import Union from vllm.lora.request import LoRARequest +from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.sampling_params import RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceGroupBase, SequenceStatus) @@ -95,6 +96,7 @@ def __init__( request_id: str, prompt: Optional[str], prompt_token_ids: Optional[List[int]], + multi_modal_placeholders: MultiModalPlaceholderDict, prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, @@ -107,6 +109,7 @@ def __init__( self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids + self.mutli_modal_placeholders = multi_modal_placeholders self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished @@ -141,6 +144,7 @@ def new( request_id=request_id, prompt=prompt, prompt_token_ids=prompt_token_ids, + multi_modal_placeholders=MultiModalPlaceholderDict(), prompt_logprobs=None, # TODO outputs=[completion_output], finished=finished, @@ -154,8 +158,7 @@ def from_seq_group( finished = seq_group.is_finished() if seq_group.request_id in seq_id_to_seq_group: - group: SequenceGroupBase = seq_id_to_seq_group[ - seq_group.request_id] + group: SequenceGroupBase = seq_id_to_seq_group[seq_group.request_id] if finished: group.finish_seq(seq_group) assembled_seq_group = group.maybe_assemble_group(seq_group) @@ -198,8 +201,8 @@ def from_seq_group( # num_cached_tokens should be the same for all the sequences num_cached_tokens = None for i, seq in enumerate(top_n_seqs): - output_text = seq.get_output_text_to_return( - text_buffer_length, delta) + output_text = seq.get_output_text_to_return(text_buffer_length, + delta) output_token_ids = seq.get_output_token_ids_to_return(delta) num_output_tokens = 1 if isinstance(output_token_ids, @@ -276,7 +279,8 @@ def from_seq_group( seq_group.set_finished_time(finished_time) init_args = (seq_group.request_id, prompt, prompt_token_ids, - prompt_logprobs, outputs, finished, seq_group.metrics, + seq_group.multi_modal_placeholders, prompt_logprobs, + outputs, finished, seq_group.metrics, seq_group.lora_request, encoder_prompt, encoder_prompt_token_ids, num_cached_tokens) @@ -293,6 +297,7 @@ def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"multi_modal_placeholders={self.mutli_modal_placeholders}, " f"encoder_prompt={self.encoder_prompt!r}, " f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " From a1cdcb32801a478ac2240546ddc896017bd54318 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Sun, 17 Nov 2024 23:47:28 +0000 Subject: [PATCH 02/23] pipe multi_modal_placeholders from intput to final output * add test for pixtral * fix a typo Signed-off-by: Linkun Chen --- .../vision_language/test_pixtral.py | 82 ++++++++++++++++++- vllm/model_executor/models/pixtral.py | 16 +++- vllm/outputs.py | 11 +-- 3 files changed, 102 insertions(+), 7 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index d8a98a0f84d3..bbae8138d44b 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -6,15 +6,20 @@ import uuid from dataclasses import asdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from transformers import AutoProcessor import pytest +from mistral_common.multimodal import download_image from mistral_common.protocol.instruct.messages import ImageURLChunk from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.multimodal import image_from_chunk -from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt +from vllm import (EngineArgs, LLMEngine, SamplingParams, TokensPrompt, + TextPrompt, RequestOutput) +from vllm.logger import init_logger from vllm.multimodal import MultiModalDataBuiltins +from vllm.multimodal.inputs import PlaceholderRange from vllm.sequence import Logprob, SampleLogprobs from ....utils import VLLM_PATH, large_gpu_test @@ -23,6 +28,8 @@ if TYPE_CHECKING: from _typeshed import StrPath +logger = init_logger(__name__) + MODELS = ["mistralai/Pixtral-12B-2409"] IMG_URLS = [ "https://picsum.photos/id/237/400/300", @@ -49,6 +56,20 @@ def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]: }] +def _create_msg_format_hf(urls: List[str]) -> List[Dict[str, Any]]: + return [{ + "role": + "user", + "content": [{ + "type": "text", + "content": PROMPT, + }, *({ + "type": "image", + "image": download_image(url) + } for url in urls)], + }] + + def _create_engine_inputs(urls: List[str]) -> TokensPrompt: msg = _create_msg_format(urls) @@ -70,6 +91,23 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt: return engine_inputs +def _create_engine_inputs_hf(urls: List[str]) -> TextPrompt: + msg = _create_msg_format_hf(urls) + + tokenizer = AutoProcessor.from_pretrained("mistral-community/pixtral-12b") + prompt = tokenizer.apply_chat_template(msg) + + images = [] + for chunk in msg[0]["content"]: + if chunk["type"] == "image": + images.append(chunk["image"]) + + mm_data = MultiModalDataBuiltins(image=images) + engine_inputs = TextPrompt(prompt=prompt, multi_modal_data=mm_data) + + return engine_inputs + + MSGS = [ _create_msg_format(IMG_URLS[:1]), _create_msg_format(IMG_URLS[:2]), @@ -191,3 +229,45 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None: outputs_1_lst=logprobs, name_0="h100_ref", name_1="output") + + +@large_gpu_test(min_gb=24) +@pytest.mark.parametrize( + "prompt,expected_ranges", + [(_create_engine_inputs_hf(IMG_URLS[:1]), [{ + "offset": 10, + "length": 494 + }]), + (_create_engine_inputs_hf(IMG_URLS[1:4]), [{ + "offset": 10, + "length": 266 + }, { + "offset": 276, + "length": 1056 + }, { + "offset": 1332, + "length": 418 + }])]) +def test_multi_modal_placeholders( + vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None: + with vllm_runner( + "mistral-community/pixtral-12b", + max_model_len=8192, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + ) as vllm_model: + outputs = vllm_model.model.generate(prompt) + + assert len(outputs) == 1, f"{len(outputs)=}" + output: RequestOutput = outputs[0] + assert hasattr(output, + "multi_modal_placeholders"), f"{output.__dict__=}" + assert "image" in output.multi_modal_placeholders, \ + f"{output.multi_modal_placeholders.keys()=}" + image_placeholder_ranges: list[ + PlaceholderRange] = output.multi_modal_placeholders["image"] + assert len(image_placeholder_ranges) == len( + expected_ranges), f"{image_placeholder_ranges=}" + for real_range, expected_range in zip(image_placeholder_ranges, + expected_ranges): + assert real_range == expected_range, \ + f"{real_range=} {expected_range=}" diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index a3e30ea2dd29..790a260d43ec 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -30,6 +30,7 @@ from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges) from vllm.sequence import IntermediateTensors, SequenceData @@ -773,15 +774,28 @@ def input_processor_for_pixtral_hf( replace_tokens[-1] = image_end_id replace_tokens_list.append(replace_tokens) + reverse_offsets: List[int] = [] # Backward iteration for replacement without affecting known indices for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices), reversed(replace_tokens_list)): + reverse_offsets.append( + len(new_token_ids) - placeholder_idx + len(replace_tokens)) new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens + placeholder_ranges: List[PlaceholderRange] = [] + for reverse_offset, replace_tokens in zip(reversed(reverse_offsets), + replace_tokens_list): + placeholder_ranges.append( + PlaceholderRange( + offset=len(new_token_ids) - reverse_offset, + length=len(replace_tokens), + )) + # NOTE: Create a defensive copy of the original inputs return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": placeholder_ranges}) class PixtralHFMLP(nn.Module): diff --git a/vllm/outputs.py b/vllm/outputs.py index 205a479e8355..a02f1e97b5b5 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -109,7 +109,7 @@ def __init__( self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids - self.mutli_modal_placeholders = multi_modal_placeholders + self.multi_modal_placeholders = multi_modal_placeholders self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished @@ -158,7 +158,8 @@ def from_seq_group( finished = seq_group.is_finished() if seq_group.request_id in seq_id_to_seq_group: - group: SequenceGroupBase = seq_id_to_seq_group[seq_group.request_id] + group: SequenceGroupBase = seq_id_to_seq_group[ + seq_group.request_id] if finished: group.finish_seq(seq_group) assembled_seq_group = group.maybe_assemble_group(seq_group) @@ -201,8 +202,8 @@ def from_seq_group( # num_cached_tokens should be the same for all the sequences num_cached_tokens = None for i, seq in enumerate(top_n_seqs): - output_text = seq.get_output_text_to_return(text_buffer_length, - delta) + output_text = seq.get_output_text_to_return( + text_buffer_length, delta) output_token_ids = seq.get_output_token_ids_to_return(delta) num_output_tokens = 1 if isinstance(output_token_ids, @@ -297,7 +298,7 @@ def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " - f"multi_modal_placeholders={self.mutli_modal_placeholders}, " + f"multi_modal_placeholders={self.multi_modal_placeholders}, " f"encoder_prompt={self.encoder_prompt!r}, " f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " From f60964a976b664ce9ddaf286e3a8cf8ba0850524 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 16 Nov 2024 10:45:26 -0800 Subject: [PATCH 03/23] [V1] Add code owners for V1 (#10397) Signed-off-by: Woosuk Kwon Signed-off-by: Linkun Chen --- .github/CODEOWNERS | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index cd721971d01d..3cb91fc0f823 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,13 +3,16 @@ # This lists cover the "core" components of vLLM that require careful review /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/core @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/engine/llm_engine.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/executor/executor_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/worker/worker_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/worker/worker.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/model_executor/layers/sampler.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -CMakeLists.txt @tlrmchlsmth @WoosukKwon +/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +CMakeLists.txt @tlrmchlsmth + +# vLLM V1 +/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic # Test ownership /tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo From 578e482b3eef1b816a7a32b943adf481b279018b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 16 Nov 2024 18:02:14 -0800 Subject: [PATCH 04/23] [2/N][torch.compile] make compilation cfg part of vllm cfg (#10383) Signed-off-by: youkaichao Signed-off-by: Linkun Chen --- tests/compile/piecewise/test_simple.py | 8 +- tests/compile/piecewise/test_toy_llama.py | 22 +- tests/compile/test_basic_correctness.py | 2 +- tests/compile/test_full_graph.py | 2 +- tests/compile/test_fusion.py | 2 +- tests/compile/test_wrapper.py | 4 +- tests/compile/utils.py | 2 +- .../model_executor/test_enabled_custom_ops.py | 52 ++--- tests/tpu/test_compilation.py | 2 +- tests/tpu/test_custom_dispatcher.py | 2 +- vllm/compilation/backends.py | 20 +- vllm/compilation/config.py | 159 --------------- vllm/compilation/decorators.py | 10 +- vllm/compilation/fusion.py | 2 +- vllm/compilation/inductor_pass.py | 2 +- vllm/compilation/levels.py | 8 - vllm/compilation/wrapper.py | 11 +- vllm/config.py | 189 ++++++++++++++++++ vllm/envs.py | 13 -- vllm/model_executor/custom_op.py | 27 +-- vllm/model_executor/model_loader/loader.py | 7 +- vllm/platforms/interface.py | 20 +- vllm/platforms/tpu.py | 21 +- vllm/plugins/__init__.py | 30 ++- vllm/v1/worker/gpu_model_runner.py | 10 +- vllm/worker/model_runner.py | 7 +- vllm/worker/tpu_model_runner.py | 8 +- 27 files changed, 359 insertions(+), 283 deletions(-) delete mode 100644 vllm/compilation/config.py delete mode 100644 vllm/compilation/levels.py diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index c631850ecded..45f56cbbd4b1 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -11,8 +11,8 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig +from vllm.plugins import set_current_vllm_config from vllm.utils import direct_register_custom_op global_counter = 0 @@ -82,7 +82,9 @@ def test_simple_piecewise_compile(): os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) - model = SillyModel(vllm_config=VllmConfig(), prefix='') + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = SillyModel(vllm_config=vllm_config, prefix='') inputs = torch.randn(100).cuda() diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index c363a587a818..8032304e9580 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -15,12 +15,10 @@ from torch.library import Library from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.config import CompilationConfig from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig -from vllm.plugins import set_compilation_config +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.plugins import set_compilation_config, set_current_vllm_config from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -272,9 +270,11 @@ def run_model(llama_config, CompilationLevel.NO_COMPILATION) set_compilation_config(None) - model = LlamaModel(config=llama_config, - vllm_config=VllmConfig(), - prefix="").eval().cuda() + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = LlamaModel(config=llama_config, + vllm_config=vllm_config, + prefix="").eval().cuda() B = 16 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() @@ -395,9 +395,11 @@ def benchmark(): else: set_compilation_config(None) - model = LlamaModel(config=llama_config, - vllm_config=VllmConfig(), - prefix="").eval().cuda().to(torch.bfloat16) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = LlamaModel(config=llama_config, + vllm_config=vllm_config, + prefix="").eval().cuda().to(torch.bfloat16) B = 256 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 833589ba5dc9..08747ebc58b7 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -3,7 +3,7 @@ import pytest -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from vllm.utils import cuda_device_count_stateless from ..utils import compare_all_settings diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index f00334934cb4..4dfdfe21a67d 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,6 +1,6 @@ import pytest -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from ..utils import fork_new_process_for_each_test from .utils import TEST_MODELS, check_full_graph_support diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index e4d3defafb95..4db79b070fd8 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -3,10 +3,10 @@ from compressed_tensors.quantization import FP8_DTYPE import vllm.envs as envs -from vllm.compilation.config import CompilationConfig from vllm.compilation.fusion import (FusionPass, find_auto_fn, find_auto_fn_maybe) from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.config import CompilationConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( apply_fp8_linear) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index 3668c1fab6b8..74f66baaa5ea 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -3,6 +3,7 @@ import torch from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import CompilationLevel class MyMod(torch.nn.Module): @@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher): def __init__(self, model): self.model = model compiled_callable = torch.compile(self.forward, backend="eager") - super().__init__(compiled_callable) + super().__init__(compiled_callable, + compilation_level=CompilationLevel.DYNAMO_ONCE) def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): # this is the function to be compiled diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 222c63a342a4..729f10676888 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,7 +4,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from vllm.platforms import current_platform TEST_MODELS = [ diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index af267f804ffa..c3219bc50646 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -3,11 +3,13 @@ import pytest +from vllm.config import CompilationConfig, VllmConfig from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.plugins import set_current_vllm_config # Registered subclass for test @@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation): ]) def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], default_on: bool): - os.environ["VLLM_CUSTOM_OPS"] = env os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + custom_ops=env.split(","))) + with set_current_vllm_config(vllm_config): + assert CustomOp.default_on() == default_on - # Reset default_on (computed once): - CustomOp.default_on.cache_clear() + ops_enabled = [bool(x) for x in ops_enabled] - assert CustomOp.default_on() == default_on + assert RMSNorm(1024).enabled() == ops_enabled[0] + assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] - ops_enabled = [bool(x) for x in ops_enabled] + assert SiluAndMul().enabled() == ops_enabled[1] + assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] - assert RMSNorm(1024).enabled() == ops_enabled[0] - assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] + assert GeluAndMul().enabled() == ops_enabled[2] + assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] - assert SiluAndMul().enabled() == ops_enabled[1] - assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] + # If registered, subclasses should follow their own name + assert Relu3().enabled() == ops_enabled[3] + assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] - assert GeluAndMul().enabled() == ops_enabled[2] - assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] + # Unregistered subclass + class SiluAndMul2(SiluAndMul): + pass - # If registered, subclasses should follow their own name - assert Relu3().enabled() == ops_enabled[3] - assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] - - # Unregistered subclass - class SiluAndMul2(SiluAndMul): - pass - - # Subclasses should not require registration - assert SiluAndMul2().enabled() == SiluAndMul().enabled() + # Subclasses should not require registration + assert SiluAndMul2().enabled() == SiluAndMul().enabled() @pytest.mark.parametrize( "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]) def test_enabled_ops_invalid(env: str): - os.environ["VLLM_CUSTOM_OPS"] = env - CustomOp.default_on.cache_clear() - - with pytest.raises(AssertionError): - RMSNorm(1024).enabled() + with pytest.raises(Exception): # noqa + vllm_config = VllmConfig(compilation_config=CompilationConfig( + custom_ops=env.split(","))) + with set_current_vllm_config(vllm_config): + RMSNorm(1024).enabled() diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 86d9af88e49e..941abe17a337 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -5,7 +5,7 @@ import depyf -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel # disable custom dispatcher, let Dynamo takes over # all the control diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 923d0f168080..53b10c06135a 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,6 +1,6 @@ import os -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from ..utils import compare_two_settings diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 5682faa15806..22c613931f08 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -10,13 +10,12 @@ import torch.fx as fx import vllm.envs as envs +from vllm.config import CompilationConfig, CompilationLevel from vllm.logger import init_logger from vllm.utils import combine_fx_passes, weak_ref_tensors -from .config import CompilationConfig from .counter import compilation_counter from .fusion import FusionPass -from .levels import CompilationLevel from .reshapes import RedundantReshapesPass logger = init_logger(__name__) @@ -392,7 +391,10 @@ class VllmBackend: sym_tensor_indices: List[int] input_buffers: List[torch.Tensor] - def __init__(self, post_grad_passes: Sequence[Callable] = ()): + def __init__( + self, + compilation_configs: CompilationConfig, + ): global global_graph_pool if global_graph_pool is None: global_graph_pool = torch.cuda.graph_pool_handle() @@ -401,11 +403,13 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()): # streams, it might not be safe to share a global pool. # only investigate this when we use multiple streams self.graph_pool = global_graph_pool - self.post_grad_passes = post_grad_passes + self.post_grad_passes = [] self.sym_tensor_indices = [] self.input_buffers = [] + self.compilation_configs = compilation_configs + # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -437,10 +441,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: assert not self._called, "VllmBackend can only be called once" self.graph = graph - # config is read now, because only here can + # config is updated now, because only here can # we get the sizes to capture for cudagraph # from compilation context - self.compilation_configs = CompilationConfig.select_and_init_config() + self.compilation_configs.init_during_runtime() self.add_passes_to_config() self.split_gm, self.piecewise_graphs = split_graph( @@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]: return backend_str assert level == CompilationLevel.PIECEWISE - return VllmBackend() + from vllm.plugins import get_current_vllm_config + compilation_config = get_current_vllm_config().compilation_config + return VllmBackend(compilation_config) diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py deleted file mode 100644 index 3e663505c627..000000000000 --- a/vllm/compilation/config.py +++ /dev/null @@ -1,159 +0,0 @@ -import copy -from pathlib import Path -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel, Field, PrivateAttr - -import vllm.envs as envs -from vllm.logger import init_logger - -from .compile_context import get_compile_context - -logger = init_logger(__name__) - - -class CompilationConfig(BaseModel): - """ - Configuration for compilation. - It has two parts: - - CudaGraph capture: - - use_cudagraph: whether to use cudagraph inside compilation. - - False: cudagraph inside compilation is not used. - - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses. - Note that this is orthogonal to the cudagraph capture out - side of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future. - - cudagraph_capture_sizes: sizes to capture cudagraph. - - None: capture sizes are inferred from compilation context. - - List[int]: capture sizes are specified. - - cudagraph_num_of_warmups: number of warmup runs for cudagraph. - It means the first several runs will be treated as warmup runs. - Only after that, the execution will be recorded, and the recorded - cudagraph will be used for subsequent runs. - - cudagraph_copy_inputs: whether to copy input tensors for - cudagraph. If the caller can guarantee that the same input buffers - are always used, it can set this to False. Otherwise, it should - set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False. - - Inductor compilation: - - use_inductor: whether to use inductor compilation. - - False: inductor compilation is not used. graph runs in eager. - - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for different sizes specified - in inductor_compile_sizes, using configurations - in inductor_compile_config. - - inductor_compile_sizes: sizes to compile for inductor. - - inductor_specialize_for_cudagraph_no_more_than: an optional integer - to specialize inductor for cudagraph sizes no more than the - specified size. It is useful when we want to specialize inductor - with a subset of cudagraph sizes. - - inductor_compile_config: additional configurations for inductor. - - None: use default configurations. - - inductor_passes: additional passes for inductor. It is a dictionary - from pass name to pass function qualified name. We use function - name because the config uses json format. If we pass the config - from Python, functions can also be passed directly via Python object - constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - - Custom inductor passes: - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). - - dump_graph_dir: directory to dump the graph. Default is . - - enable_fusion: whether to enable the custom fusion pass. - TODO better pass enabling system. - - Why we have different sizes for cudagraph and inductor: - - cudagraph: a cudagraph captured for a specific size can only be used - for the same size. We need to capture all the sizes we want to use. - - inductor: a graph compiled by inductor for a general shape can be used - for different sizes. Inductor can also compile for specific sizes, - where it can have more information to optimize the graph with fully - static shapes. However, we find the general shape compilation is - sufficient for most cases. It might be beneficial to compile for - certain small batchsizes, where inductor is good at optimizing. - """ - use_inductor: bool = True - inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None - inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) - inductor_compile_config: Dict = Field(default_factory=dict) - inductor_passes: Dict[str, str] = Field(default_factory=dict) - - use_cudagraph: bool = False - non_cudagraph_ops: List[str] = Field(default_factory=list) - cudagraph_num_of_warmups: int = 0 - cudagraph_capture_sizes: Optional[List[int]] = None - cudagraph_copy_inputs: bool = False - - dump_graph_stages: List[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True - - # not configurable, computed after init - compile_sizes: List[int] = PrivateAttr - capture_sizes: List[int] = PrivateAttr - - def model_post_init(self, __context: Any) -> None: - for k, v in self.inductor_passes.items(): - if not isinstance(v, str): - assert callable(v), ( - f"pass {k} should be a function or a qualified name") - self.inductor_compile_config[k] = v - continue - - # resolve function from qualified name - names = v.split(".") - module = ".".join(names[:-1]) - func_name = names[-1] - func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func - - def init_during_runtime(self): - """To complete the initialization of config, - we need to know the compile context, which is only available - during the first run of the model. - """ - context = get_compile_context() - context = copy.deepcopy(context) if context is not None else [] - sizes_to_specialize: List[int] = context - if self.cudagraph_capture_sizes is None: - self.capture_sizes = sizes_to_specialize - else: - self.capture_sizes = self.cudagraph_capture_sizes - logger.info(("cudagraph sizes specified by model runner" - " %s is overridden by config %s"), - sizes_to_specialize, self.cudagraph_capture_sizes) - if self.inductor_specialize_for_cudagraph_no_more_than is not None: - assert self.inductor_compile_sizes is None, ( - "inductor_compile_sizes should be None when " - "inductor_specialize_for_cudagraph_no_more_than is not None") - self.compile_sizes = [ - x for x in self.capture_sizes - if x <= self.inductor_specialize_for_cudagraph_no_more_than - ] - else: - assert self.inductor_compile_sizes is not None, ( - "inductor_compile_sizes should not be None when " - "inductor_specialize_for_cudagraph_no_more_than is None") - self.compile_sizes = self.inductor_compile_sizes - - @staticmethod - def select_and_init_config() -> "CompilationConfig": - """The order of selecting config is: - 1. Use the config specified in environment variable. - 2. Use the config specified in plugins. - 3. Use the default config. - """ - config_path = envs.VLLM_TORCH_COMPILE_CONFIG - if config_path is not None: - with open(config_path) as json_file: - config = CompilationConfig.model_validate_json( - json_file.read()) - else: - from vllm.plugins import get_compilation_config - predefined_config = get_compilation_config() - config = predefined_config if predefined_config is not None else ( - CompilationConfig()) - - config.init_during_runtime() - return config diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index ca1e96a33c01..4b78491bc5a4 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -3,10 +3,8 @@ import torch -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import supports_dynamo @@ -126,12 +124,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. - self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [ + self.do_not_compile = \ + vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS ] or not supports_dynamo() if self.do_not_compile: return - TorchCompileWrapperWithCustomDispatcher.__init__(self) + TorchCompileWrapperWithCustomDispatcher.__init__( + self, compilation_level=vllm_config.compilation_config.level) cls.__init__ = __init__ # type: ignore diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index eb43604b1399..e6a3afef85e1 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -6,8 +6,8 @@ from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) -from vllm.compilation.config import CompilationConfig from vllm.compilation.inductor_pass import InductorPass +from vllm.config import CompilationConfig from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index b23351fa1975..8082a08b4001 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -2,7 +2,7 @@ import torch -from vllm.compilation.config import CompilationConfig +from vllm.config import CompilationConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( diff --git a/vllm/compilation/levels.py b/vllm/compilation/levels.py deleted file mode 100644 index 19a3a2b52687..000000000000 --- a/vllm/compilation/levels.py +++ /dev/null @@ -1,8 +0,0 @@ -# constants for the levels of the compilation process - - -class CompilationLevel: - NO_COMPILATION = 0 - DYNAMO_AS_IS = 1 - DYNAMO_ONCE = 2 - PIECEWISE = 3 diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 7366ed4d16b0..2a1aecc11ce2 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -8,8 +8,7 @@ import torch import vllm.envs as envs - -from .levels import CompilationLevel +from vllm.config import CompilationLevel class TorchCompileWrapperWithCustomDispatcher: @@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher: `torch.compile` over the forward method. """ - def __init__(self, compiled_callable: Optional[Callable] = None): + def __init__(self, + compiled_callable: Optional[Callable] = None, + compilation_level: int = 0): if compiled_callable is None: # default compilation settings @@ -38,7 +39,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): backend = get_torch_compile_backend() if backend is None: from vllm.compilation.backends import select_default_backend - backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) + backend = select_default_backend(compilation_level) compiled_callable = torch.compile( self.forward, @@ -54,7 +55,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. self.use_custom_dispatcher: bool = \ - envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE + compilation_level >= CompilationLevel.DYNAMO_ONCE def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. diff --git a/vllm/config.py b/vllm/config.py index 64b2f75e092d..7e37edbe594b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,10 +3,12 @@ import json import warnings from dataclasses import dataclass, field, replace +from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List, Literal, Mapping, Optional, Set, Tuple, Type, Union) import torch +from pydantic import BaseModel, Field, PrivateAttr from transformers import PretrainedConfig import vllm.envs as envs @@ -2052,6 +2054,185 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") +class CompilationLevel: + # constants for the levels of the compilation process + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + PIECEWISE = 3 + + +class CompilationConfig(BaseModel): + """ + Configuration for compilation. + It has three parts: + - Top-level Compilation control: + - level: the level of compilation. + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation. + - custom_ops: fine-grained control over which custom ops to enable/disable. + Use 'all' to enable all, 'none' to disable all. + Also specify a list of custom op names to enable (prefixed with a '+'), + or disable (prefixed with a '-'). + Examples: + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + By default, all custom ops are enabled when running without Inductor + and disabled when running with Inductor (compile_level >= Inductor). + - CudaGraph capture: + - use_cudagraph: whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses. + Note that this is orthogonal to the cudagraph capture out + side of compilation. + TODO: move outside cudagraph logic into compilation. + torch.compile will handle cudagraph capture logic in the future. + - cudagraph_capture_sizes: sizes to capture cudagraph. + - None: capture sizes are inferred from compilation context. + - List[int]: capture sizes are specified. + - cudagraph_num_of_warmups: number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs. + - cudagraph_copy_inputs: whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False. + - Inductor compilation: + - use_inductor: whether to use inductor compilation. + - False: inductor compilation is not used. graph runs in eager. + - True: inductor compilation is used. one graph for symbolic shape + is compiled. In addition, compile for different sizes specified + in inductor_compile_sizes, using configurations + in inductor_compile_config. + - inductor_compile_sizes: sizes to compile for inductor. + - inductor_specialize_for_cudagraph_no_more_than: an optional integer + to specialize inductor for cudagraph sizes no more than the + specified size. It is useful when we want to specialize inductor + with a subset of cudagraph sizes. + - inductor_compile_config: additional configurations for inductor. + - None: use default configurations. + - inductor_passes: additional passes for inductor. It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses json format. If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` + - custom inductor passes: + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graph. Default is . + - enable_fusion: whether to enable the custom fusion pass. + TODO better pass enabling system. + + Why we have different sizes for cudagraph and inductor: + - cudagraph: a cudagraph captured for a specific size can only be used + for the same size. We need to capture all the sizes we want to use. + - inductor: a graph compiled by inductor for a general shape can be used + for different sizes. Inductor can also compile for specific sizes, + where it can have more information to optimize the graph with fully + static shapes. However, we find the general shape compilation is + sufficient for most cases. It might be beneficial to compile for + certain small batchsizes, where inductor is good at optimizing. + """ # noqa + level: int = 0 + custom_ops: List[str] = Field(default_factory=list) + + use_inductor: bool = True + inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None + inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) + inductor_compile_config: Dict = Field(default_factory=dict) + inductor_passes: Dict[str, str] = Field(default_factory=dict) + + use_cudagraph: bool = False + non_cudagraph_ops: List[str] = Field(default_factory=list) + cudagraph_num_of_warmups: int = 0 + cudagraph_capture_sizes: Optional[List[int]] = None + cudagraph_copy_inputs: bool = False + + dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + + # not configurable, computed after init + compile_sizes: List[int] = PrivateAttr + capture_sizes: List[int] = PrivateAttr + + def model_post_init(self, __context: Any) -> None: + self.level = envs.VLLM_TORCH_COMPILE_LEVEL + + count_none = self.custom_ops.count("none") + count_all = self.custom_ops.count("all") + assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + + for k, v in self.inductor_passes.items(): + if not isinstance(v, str): + assert callable(v), ( + f"pass {k} should be a function or a qualified name") + self.inductor_compile_config[k] = v + continue + + # resolve function from qualified name + names = v.split(".") + module = ".".join(names[:-1]) + func_name = names[-1] + func = __import__(module).__dict__[func_name] + self.inductor_compile_config[k] = func + + def init_during_runtime(self): + """To complete the initialization of config, + we need to know the compile context, which is only available + during the first run of the model. + """ + from vllm.compilation.compile_context import get_compile_context + context = get_compile_context() + context = copy.deepcopy(context) if context is not None else [] + sizes_to_specialize: List[int] = context + if self.cudagraph_capture_sizes is None: + self.capture_sizes = sizes_to_specialize + else: + self.capture_sizes = self.cudagraph_capture_sizes + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + sizes_to_specialize, self.cudagraph_capture_sizes) + if self.inductor_specialize_for_cudagraph_no_more_than is not None: + assert self.inductor_compile_sizes is None, ( + "inductor_compile_sizes should be None when " + "inductor_specialize_for_cudagraph_no_more_than is not None") + self.compile_sizes = [ + x for x in self.capture_sizes + if x <= self.inductor_specialize_for_cudagraph_no_more_than + ] + else: + assert self.inductor_compile_sizes is not None, ( + "inductor_compile_sizes should not be None when " + "inductor_specialize_for_cudagraph_no_more_than is None") + self.compile_sizes = self.inductor_compile_sizes + + @staticmethod + def select_and_init_config() -> "CompilationConfig": + """The order of selecting config is: + 1. Use the config specified in environment variable. + 2. Use the config specified in plugins. + 3. Use the default config. + """ + config_path = envs.VLLM_TORCH_COMPILE_CONFIG + if config_path is not None: + with open(config_path) as json_file: + config = CompilationConfig.model_validate_json( + json_file.read()) + else: + from vllm.plugins import get_compilation_config + predefined_config = get_compilation_config() + config = predefined_config if predefined_config is not None else ( + CompilationConfig()) + + return config + + @dataclass class VllmConfig: """Dataclass which contains all vllm-related configuration. This @@ -2073,6 +2254,8 @@ class VllmConfig: observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None + compilation_config: CompilationConfig = field(default=None, + init=True) # type: ignore @staticmethod def _get_quantization_config( @@ -2133,6 +2316,12 @@ def __post_init__(self): self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) + if self.compilation_config is None: + self.compilation_config = CompilationConfig.select_and_init_config( + ) + + current_platform.check_and_update_config(self) + def __str__(self): return ("model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " diff --git a/vllm/envs.py b/vllm/envs.py index f320e35971f9..716e835a555f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -69,7 +69,6 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None - VLLM_CUSTOM_OPS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = False @@ -217,18 +216,6 @@ def get_default_config_root(): "VLLM_TORCH_COMPILE_CONFIG": lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None), - # Fine-grained control over which custom ops to enable/disable. - # Use 'all' to enable all, 'none' to disable all. - # Also specify a list of custom op names to enable (prefixed with a '+'), - # or disable (prefixed with a '-'). - # Examples: - # - 'all,-op1' to enable all except op1 - # - 'none,+op1,+op2' to enable only op1 and op2 - # By default, all custom ops are enabled when running without Inductor - # and disabled when running with Inductor (compile_level >= Inductor). - "VLLM_CUSTOM_OPS": - lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","), - # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 24d75f4df4e0..6ae7d7cf6964 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,12 +1,10 @@ -from functools import lru_cache from typing import Dict, Type import torch.nn as nn -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.plugins import get_current_vllm_config from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -87,6 +85,8 @@ def dispatch_forward(self): @classmethod def enabled(cls) -> bool: # if no name, then it was not registered + compilation_config = get_current_vllm_config().compilation_config + custom_ops = compilation_config.custom_ops if not hasattr(cls, "name"): print_warning_once( f"Custom op {cls.__name__} was not registered, " @@ -94,22 +94,25 @@ def enabled(cls) -> bool: f"It will be enabled/disabled based on the global settings.") return CustomOp.default_on() - enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS - disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS + enabled = f"+{cls.name}" in custom_ops + disabled = f"-{cls.name}" in custom_ops assert not (enabled and disabled), f"Cannot enable and disable {cls.name}" return (CustomOp.default_on() or enabled) and not disabled - # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE - # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence. @staticmethod - @lru_cache def default_on() -> bool: - count_none = envs.VLLM_CUSTOM_OPS.count("none") - count_all = envs.VLLM_CUSTOM_OPS.count("all") - assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" - return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \ + """ + On by default if level < CompilationLevel.PIECEWISE + Specifying 'all' or 'none' in custom_op takes precedence. + """ + from vllm.config import CompilationLevel + compilation_config = get_current_vllm_config().compilation_config + custom_ops = compilation_config.custom_ops + count_none = custom_ops.count("none") + count_all = custom_ops.count("all") + return compilation_config.level < CompilationLevel.PIECEWISE and \ not count_none > 0 or count_all > 0 # Dictionary of all custom ops (classes, indexed by registered name). diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 140b61fe6d56..0f8b81c3ef40 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -42,6 +42,7 @@ safetensors_weights_iterator) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.plugins import set_current_vllm_config from vllm.utils import is_pin_memory_available @@ -97,7 +98,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: # new-style model class - return model_class(vllm_config=vllm_config, prefix=prefix) + with set_current_vllm_config(vllm_config): + return model_class(vllm_config=vllm_config, prefix=prefix) msg = ("vLLM model class should accept `vllm_config` and `prefix` as " "input arguments. Possibly you have an old-style model class" " registered from out of tree and it is used for new vLLM version. " @@ -121,7 +123,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: kwargs["lora_config"] = vllm_config.lora_config if "scheduler_config" in all_params: kwargs["scheduler_config"] = vllm_config.scheduler_config - return model_class(**kwargs) + with set_current_vllm_config(vllm_config): + return model_class(**kwargs) class BaseModelLoader(ABC): diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 81d8bdae2383..970c0d1be617 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,10 +1,15 @@ import enum import random -from typing import NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union import numpy as np import torch +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + class PlatformEnum(enum.Enum): CUDA = enum.auto() @@ -129,6 +134,19 @@ def seed_everything(cls, seed: int) -> None: np.random.seed(seed) torch.manual_seed(seed) + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + """ + Check and update the configuration for the current platform. + + It can raise an exception if the configuration is not compatible with + the current platform, or it can update the configuration to make it + compatible with the current platform. + + The config is passed by reference, so it can be modified in place. + """ + pass + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 8d0ce47df404..c2e22bfc09f2 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,18 +1,16 @@ import os +from typing import TYPE_CHECKING import torch -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.plugins import set_torch_compile_backend from .interface import Platform, PlatformEnum -if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE) - -assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\ - "TPU does not support Inductor." +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None set_torch_compile_backend("openxla") @@ -31,3 +29,12 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def inference_mode(cls): return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + from vllm.config import CompilationLevel + compilation_config = vllm_config.compilation_config + if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: + compilation_config.level = CompilationLevel.DYNAMO_ONCE + assert compilation_config.level < CompilationLevel.PIECEWISE,\ + "TPU does not support Inductor." diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 7b1bbb14c530..c20b9ec891d5 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,11 +1,11 @@ import logging +from contextlib import contextmanager from typing import TYPE_CHECKING, Callable, Optional, Union import vllm.envs as envs if TYPE_CHECKING: - from vllm.compilation.config import CompilationConfig - from vllm.config import VllmConfig + from vllm.config import CompilationConfig, VllmConfig else: CompilationConfig = None VllmConfig = None @@ -72,3 +72,29 @@ def set_compilation_config(config: Optional[CompilationConfig]): def get_compilation_config() -> Optional[CompilationConfig]: return _compilation_config + + +_current_vllm_config: Optional[VllmConfig] = None + + +@contextmanager +def set_current_vllm_config(vllm_config: VllmConfig): + """ + Temporarily set the current VLLM config. + Used during model initialization. + We save the current VLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the VLLM config to determine how to dispatch. + """ + global _current_vllm_config + old_vllm_config = _current_vllm_config + try: + _current_vllm_config = vllm_config + yield + finally: + _current_vllm_config = old_vllm_config + + +def get_current_vllm_config() -> VllmConfig: + assert _current_vllm_config is not None, "Current VLLM config is not set." + return _current_vllm_config diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eebd1de96537..d60f93a44f6d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,4 +1,3 @@ -import os import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple @@ -8,11 +7,8 @@ import torch.distributed import torch.nn as nn -from vllm import envs from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.config import CompilationConfig -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger @@ -99,7 +95,7 @@ def __init__( pin_memory=self.pin_memory, ) - self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL + self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. @@ -517,9 +513,9 @@ def load_model(self) -> None: # CUDA graphs do not work properly with the custom CUDA kernels. # FIXME(woosuk): Disable inductor to reduce the compilation time # and avoid any potential issues with the inductor. - os.environ["VLLM_CUSTOM_OPS"] = "none" set_compilation_config( CompilationConfig( + custom_ops=["none"], use_cudagraph=True, non_cudagraph_ops=["vllm.unified_v1_flash_attention"], use_inductor=True, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 22ee3f9f863e..fd89f9544556 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -19,8 +19,7 @@ from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture @@ -1142,8 +1141,8 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \ - and supports_dynamo(): + if self.vllm_config.compilation_config.level ==\ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): from vllm.plugins import get_torch_compile_backend backend = get_torch_compile_backend() or "eager" self.model = torch.compile( diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index a72118613732..d7a641857a61 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -140,7 +140,7 @@ def load_model(self) -> None: model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() - self.model = ModelWrapper(model) + self.model = ModelWrapper(model, self.vllm_config) def _dummy_run( self, @@ -669,13 +669,15 @@ def execute_model( class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module, vllm_config: VllmConfig): self.model = model compiled_callable = torch.compile(self.forward, backend="openxla", fullgraph=True, dynamic=False) - super().__init__(compiled_callable) + super().__init__( + compiled_callable, + compilation_level=vllm_config.compilation_config.level) def __call__(self, *args, is_prompt: bool, **kwargs): if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: From 7fa97cf556eb753b30f5d83b85e6f2b2d3619f62 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sat, 16 Nov 2024 21:18:46 -0800 Subject: [PATCH 05/23] [V1] Refactor model executable interface for all text-only language models (#10374) Signed-off-by: Roger Wang Signed-off-by: Linkun Chen --- vllm/model_executor/models/arctic.py | 16 ++++++++++++++-- vllm/model_executor/models/baichuan.py | 16 ++++++++++++++-- vllm/model_executor/models/bloom.py | 17 ++++++++++++++--- vllm/model_executor/models/commandr.py | 16 ++++++++++++++-- vllm/model_executor/models/dbrx.py | 16 ++++++++++++++-- vllm/model_executor/models/deepseek.py | 16 ++++++++++++++-- vllm/model_executor/models/deepseek_v2.py | 16 ++++++++++++++-- vllm/model_executor/models/eagle.py | 13 ++++++++++--- vllm/model_executor/models/exaone.py | 7 ++++++- vllm/model_executor/models/falcon.py | 16 ++++++++++++++-- vllm/model_executor/models/gemma.py | 7 ++++++- vllm/model_executor/models/gemma2.py | 12 ++++++++++-- vllm/model_executor/models/gpt2.py | 7 +++++-- vllm/model_executor/models/gpt_bigcode.py | 17 +++++++++++++---- vllm/model_executor/models/gpt_j.py | 16 ++++++++++++++-- vllm/model_executor/models/gpt_neox.py | 16 ++++++++++++++-- vllm/model_executor/models/granite.py | 7 ++++++- vllm/model_executor/models/granitemoe.py | 16 ++++++++++++++-- vllm/model_executor/models/internlm2.py | 9 +++++++-- vllm/model_executor/models/jais.py | 14 ++++++++++++-- vllm/model_executor/models/jamba.py | 16 ++++++++++++++-- vllm/model_executor/models/mamba.py | 15 +++++++++++++-- vllm/model_executor/models/minicpm.py | 7 ++++++- vllm/model_executor/models/mixtral.py | 16 ++++++++++++++-- vllm/model_executor/models/mixtral_quant.py | 16 ++++++++++++++-- vllm/model_executor/models/mpt.py | 16 ++++++++++++++-- vllm/model_executor/models/nemotron.py | 7 ++++++- vllm/model_executor/models/olmo.py | 19 +++++++++++++------ vllm/model_executor/models/olmoe.py | 16 ++++++++++++++-- vllm/model_executor/models/orion.py | 16 ++++++++++++++-- vllm/model_executor/models/persimmon.py | 8 +++++++- vllm/model_executor/models/phi.py | 16 ++++++++++++++-- vllm/model_executor/models/phi3_small.py | 19 +++++++++++-------- vllm/model_executor/models/phimoe.py | 16 ++++++++++++++-- vllm/model_executor/models/qwen.py | 16 ++++++++++++++-- vllm/model_executor/models/qwen2.py | 2 +- vllm/model_executor/models/qwen2_cls.py | 7 ++++++- vllm/model_executor/models/qwen2_moe.py | 16 ++++++++++++++-- vllm/model_executor/models/qwen2_rm.py | 7 ++++++- vllm/model_executor/models/solar.py | 4 +++- vllm/model_executor/models/stablelm.py | 16 ++++++++++++++-- vllm/model_executor/models/starcoder2.py | 16 ++++++++++++++-- vllm/model_executor/models/xverse.py | 16 ++++++++++++++-- 43 files changed, 483 insertions(+), 90 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 9ee2a2cc09a2..d52418ee0f6f 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -389,6 +389,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -396,9 +399,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -439,6 +446,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -446,9 +456,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index aabbd31192a4..01ce7c42cd39 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -284,6 +284,9 @@ def __init__( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -291,9 +294,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -363,6 +370,9 @@ def __init__( self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -370,9 +380,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 84adf574af5e..cf2eee817276 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -251,6 +251,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings_layernorm(self.word_embeddings(input_ids)) + def forward( self, input_ids: torch.Tensor, @@ -258,10 +261,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(hidden_states) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -301,6 +307,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -308,9 +317,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index cd5c1d684471..fbb09a64cde9 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -280,6 +280,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -287,9 +290,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -354,6 +361,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + @torch.no_grad() def forward( self, @@ -362,9 +372,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index fff8710f6b47..3952ff31e5ce 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -321,6 +321,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.d_model)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -328,9 +331,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors hidden_states = intermediate_tensors["hidden_states"] @@ -376,6 +383,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -383,9 +393,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index a9bf1440c4d6..36dfea5a6565 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -353,6 +353,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -360,9 +363,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: hidden_states = intermediate_tensors["hidden_states"] @@ -401,6 +408,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -408,9 +418,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 4fb1eed15a2e..1e32fe60c7a5 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -445,6 +445,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -452,9 +455,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -495,6 +502,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -502,9 +512,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 85c51e840458..f138d1363026 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -78,6 +78,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def sampler(self): return self.model.sampler + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -86,11 +89,14 @@ def forward( attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - tok_embeds = self.model.model.embed_tokens(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + inputs_embeds = self.fc( - torch.cat([tok_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) inputs_embeds[positions == 0] = 0 # masking inputs at position=0 @@ -100,7 +106,8 @@ def forward( positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors) + intermediate_tensors=intermediate_tensors, + ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index cd3e7da657e0..52dd603ca558 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -479,6 +479,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -486,9 +489,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index b3dbf063ac29..e97abe949ccd 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -367,6 +367,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -374,9 +377,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.word_embeddings(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -432,6 +439,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.LongTensor, @@ -439,9 +449,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 55baba809e58..ace13664c6ea 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -390,6 +390,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -397,9 +400,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index eeb3fd98a7ea..a60b4e73a76d 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -272,6 +272,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], @@ -285,7 +288,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = self.get_input_embeddings(input_ids) hidden_states *= self.normalizer residual = None else: @@ -414,6 +417,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -421,9 +427,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index cc85693f9952..fa0fdad28d16 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -209,6 +209,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -220,7 +223,7 @@ def forward( ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) + inputs_embeds = self.get_input_embeddings(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds else: @@ -262,7 +265,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.transformer.make_empty_intermediate_tensors) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.transformer.wte(input_ids) + return self.transformer.get_input_embeddings(input_ids) def forward( self, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index ab25c66c3a88..b2fc79d0d36d 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -218,6 +218,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -225,11 +228,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + hidden_states = inputs_embeds + self.wpe(position_ids) else: hidden_states = intermediate_tensors["hidden_states"] @@ -285,6 +289,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -292,9 +299,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index a83d03480dde..cec3fd12a67d 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -201,6 +201,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -208,9 +211,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -250,6 +257,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -257,9 +267,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 794b141bfa4a..11f286d6bcba 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -214,6 +214,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_in(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -221,9 +224,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_in(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -262,6 +269,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.gpt_neox.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.gpt_neox.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -269,9 +279,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index d1e6e31f2b8d..cb2583e69d88 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -409,6 +409,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -416,9 +419,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 2ed115c56af4..f437dd521a7d 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -277,6 +277,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -284,9 +287,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) hidden_states *= self.embedding_multiplier residual = None else: @@ -366,6 +373,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.sampler = get_sampler() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -373,9 +383,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 21fa6983063b..19bfe16e4d5f 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -290,7 +290,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.tok_embeddings(input_ids) + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -335,6 +335,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -342,9 +345,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 65800c44e5a9..ee49ffb3cd87 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -250,6 +250,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -257,9 +260,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: if get_pp_group().is_first_rank: - inputs_embeds = self.wte(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) if self.wpe is not None: position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds @@ -311,6 +316,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -318,9 +326,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 88fb8d5cf555..5612dd688638 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -292,6 +292,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -299,8 +302,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -381,12 +388,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size) self.sampler = get_sampler() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: max_batch_size = (_get_graph_batch_size( @@ -409,7 +420,8 @@ def forward(self, mamba_cache_tensors[1], state_indices_tensor) hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params) + attn_metadata, mamba_cache_params, + inputs_embeds) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 55c575e22a0f..ac0d265a961f 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -106,15 +106,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embeddings(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None for i in range(len(self.layers)): @@ -168,12 +175,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size) self.sampler = get_sampler() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.backbone.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: max_batch_size = (_get_graph_batch_size( @@ -194,7 +205,7 @@ def forward(self, state_indices_tensor) hidden_states = self.backbone(input_ids, positions, attn_metadata, - mamba_cache_params) + mamba_cache_params, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 2db953329fd9..6b67266c5336 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -504,6 +504,9 @@ def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = MiniCPMModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -511,9 +514,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 3eb2f60fd4fc..eebf5bab5a28 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -281,6 +281,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -288,9 +291,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -363,6 +370,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -370,9 +380,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 95cfb6f54dc1..af2e9586988d 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -318,6 +318,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -325,9 +328,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -368,6 +375,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -375,9 +385,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index e15c0fe8db06..3c74ef2448ab 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -237,6 +237,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.d_model)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -244,9 +247,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -283,6 +290,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -290,9 +300,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index e09d7088a69c..eb45beae7d21 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -440,6 +440,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -447,9 +450,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 3467ae589649..98d4e1ec320a 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -248,6 +248,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -255,17 +258,16 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: """ :param input_ids: A tensor of shape `(batch_size, seq_len)`. """ if get_pp_group().is_first_rank: - # Get embeddings of input. - # shape: (batch_size, seq_len, d_model) - inputs_embeds = self.embed_tokens(input_ids) - - # embed positions - hidden_states = inputs_embeds + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -315,6 +317,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -322,6 +327,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( input_ids=input_ids, @@ -329,6 +335,7 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, ) return hidden_states diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 3d31919edd86..f4eebab8c98d 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -269,6 +269,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -276,9 +279,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -326,6 +333,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -333,9 +343,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 38821c828834..39d659c49cbc 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -237,6 +237,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): "hidden_states", ], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -244,9 +247,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -286,6 +293,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -293,9 +303,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 2e34a7cc3087..62c509153a11 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -235,6 +235,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -248,7 +251,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -282,6 +285,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 262f6996fc37..a2ab0d74c48d 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -218,6 +218,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -225,9 +228,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -303,6 +310,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -310,9 +320,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 8a5fb6d303e6..2139cec44180 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -324,11 +324,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) def forward( self, @@ -337,9 +334,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor], ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) if (self.mup_embedding_multiplier is not None and self.mup_embedding_multiplier > 0.0): hidden_states = hidden_states * self.mup_embedding_multiplier @@ -397,8 +398,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.dummy_token_indices = None - def get_input_embeddings(self): - return self.model.embed_tokens + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def set_input_embeddings(self, value): self.model.embed_tokens = value @@ -433,6 +434,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: output_hidden_states = self.model( input_ids=input_ids, @@ -440,6 +442,7 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, ) output_hidden_states = output_hidden_states return output_hidden_states diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 6d71a8949111..b7e70f8fa2c6 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -465,6 +465,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -472,9 +475,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -560,6 +567,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -567,9 +577,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 3d26ede722dd..447632cefcd9 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -578,6 +578,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config) if hasattr( config, "visual") else None + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -586,6 +589,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], pixel_values: Optional[QwenImageInputs], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: img_pos = None # If pixel / visual embeddings are provided, this is a visual model @@ -606,6 +610,10 @@ def forward( ) if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.wte(input_ids) # Merge the image embeddings into the hidden states if actually have # visual features and the corresponding image tokens @@ -915,6 +923,9 @@ def _get_image_input_type( ) return None + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -922,7 +933,8 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - pixel_values: Optional[torch.Tensor] = None + pixel_values: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if intermediate_tensors is not None: input_ids = None @@ -932,7 +944,7 @@ def forward( hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, - pixel_values) + pixel_values, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 431e397e1e10..8f10df808c21 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -309,7 +309,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 120403e94868..07eb330620a4 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -72,6 +72,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): normalize=False, softmax=True) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -79,9 +82,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) logits, _ = self.score(hidden_states) return logits diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 51c0cd5664fd..249d94b5d95e 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -344,6 +344,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -351,9 +354,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -395,6 +402,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -402,9 +412,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 55843d832534..6db467af334f 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -85,6 +85,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -92,9 +95,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) logits, _ = self.score(hidden_states) return logits diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 4f03ca501fb6..affb2c975ce4 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -456,9 +456,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 1125f9e9f961..99acce596602 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -218,6 +218,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -225,9 +228,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -265,6 +272,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -272,9 +282,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index ce7a7957f52c..0ef940acebb9 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -221,6 +221,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -228,9 +231,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -273,6 +280,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -280,9 +290,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 153527da20d7..51172d8782a7 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -252,6 +252,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -259,9 +262,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: hidden_states = intermediate_tensors["hidden_states"] @@ -335,6 +342,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -342,9 +352,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( From 629f5120b70cb1426d18c5b8eb19f779702ef367 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Sun, 17 Nov 2024 00:58:22 -0600 Subject: [PATCH 06/23] [CI/Build] Fix IDC hpu [Device not found] issue (#10384) Signed-off-by: Chendi Xue Signed-off-by: Linkun Chen --- .buildkite/run-hpu-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/run-hpu-test.sh b/.buildkite/run-hpu-test.sh index 4505dc7a9373..fa4f74fca7a1 100644 --- a/.buildkite/run-hpu-test.sh +++ b/.buildkite/run-hpu-test.sh @@ -13,4 +13,4 @@ trap remove_docker_container EXIT remove_docker_container # Run the image and launch offline inference -docker run --runtime=habana --name=hpu-test --network=host -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py \ No newline at end of file +docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py \ No newline at end of file From 7539ab8d83dfe4357caf1deda4c29201b698a3d5 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 17 Nov 2024 15:12:04 +0800 Subject: [PATCH 07/23] [Bugfix][CPU] Fix CPU embedding runner with tensor parallel (#10394) Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Linkun Chen --- vllm/worker/cpu_embedding_model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/worker/cpu_embedding_model_runner.py b/vllm/worker/cpu_embedding_model_runner.py index 7053075bf4d8..d0b8fec48d74 100644 --- a/vllm/worker/cpu_embedding_model_runner.py +++ b/vllm/worker/cpu_embedding_model_runner.py @@ -66,6 +66,10 @@ def execute_model( hidden_states = model_executable(**execute_model_kwargs) + # Only perform pooling in the driver worker. + if not self.is_driver_worker: + return [] + return [ self.model.pooler(hidden_states=hidden_states, pooling_metadata=model_input.pooling_metadata) From bce660d722520d7e236e6bdd8e5ecbb102900972 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 16 Nov 2024 23:14:23 -0800 Subject: [PATCH 08/23] [platforms] refactor cpu code (#10402) Signed-off-by: youkaichao Signed-off-by: Linkun Chen --- vllm/executor/cpu_executor.py | 68 +---------------------------------- vllm/platforms/cpu.py | 60 +++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 67 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 4ceb5a837dd7..1542a2ae367e 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -2,9 +2,6 @@ from functools import partial from typing import Any, Awaitable, List, Optional, Set, Tuple, Union -import vllm.envs as envs -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ResultHandler, WorkerMonitor) @@ -13,7 +10,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest -from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port, +from vllm.utils import (get_distributed_init_method, get_open_port, get_vllm_instance_id, make_async) from vllm.worker.worker_base import WorkerWrapperBase @@ -57,13 +54,6 @@ def _init_executor(self) -> None: os.environ["LOCAL_WORLD_SIZE"] = str( self.parallel_config.tensor_parallel_size) - self.model_config = _verify_and_get_model_config(self.model_config) - self.cache_config = _verify_and_get_cache_config(self.cache_config) - self.scheduler_config = _verify_and_get_scheduler_config( - self.scheduler_config) - self.parallel_config = _verify_and_get_parallel_config( - self.parallel_config) - # Multiprocessing-based executor does not support multi-node setting. # Since it only works for single node, we can use the loopback address # 127.0.0.1 for communication. @@ -313,62 +303,6 @@ async def check_health_async(self) -> None: self.check_health() -def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - if not config.enforce_eager: - logger.warning( - "CUDA graph is not supported on CPU, fallback to the eager " - "mode.") - config.enforce_eager = True - return config - - -def _verify_and_get_scheduler_config( - config: SchedulerConfig) -> SchedulerConfig: - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - if config.chunked_prefill_enabled: - logger.warning("Chunked prefill is not supported on CPU, disable it.") - config.chunked_prefill_enabled = False - - return config - - -def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - if config.enable_prefix_caching: - logger.warning("Prefix caching is not supported on CPU, disable it.") - config.enable_prefix_caching = False - - kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE - - if kv_cache_space >= 0: - if kv_cache_space == 0: - config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore - logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " - "for CPU backend is not set, using 4 by default.") - else: - config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore - else: - raise RuntimeError( - "Invalid environment variable VLLM_CPU_KVCACHE_SPACE" - f" {kv_cache_space}, expect a positive integer value.") - - return config - - -def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig: - if (config.distributed_executor_backend is not None - and config.distributed_executor_backend != "mp"): - logger.warning( - "%s is not supported on CPU, fallback to mp distributed executor " - "backend.", config.distributed_executor_backend) - config.distributed_executor_backend = "mp" - return config - - def _driver_method_invoker(driver, method: str, *args, **kwargs): return getattr(driver, method)(*args, **kwargs) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 5243f59203af..42bee31dfb0e 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -1,8 +1,19 @@ +from typing import TYPE_CHECKING + import psutil import torch +from vllm.logger import init_logger + from .interface import Platform, PlatformEnum +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + +logger = init_logger(__name__) + class CpuPlatform(Platform): _enum = PlatformEnum.CPU @@ -18,3 +29,52 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def inference_mode(cls): return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + import vllm.envs as envs + from vllm.utils import GiB_bytes + model_config = vllm_config.model_config + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if not model_config.enforce_eager: + logger.warning( + "CUDA graph is not supported on CPU, fallback to the eager " + "mode.") + model_config.enforce_eager = True + + cache_config = vllm_config.cache_config + + if cache_config.enable_prefix_caching: + logger.warning( + "Prefix caching is not supported on CPU, disable it.") + cache_config.enable_prefix_caching = False + + kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE + + if kv_cache_space >= 0: + if kv_cache_space == 0: + cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore + logger.warning( + "Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " + "for CPU backend is not set, using 4 by default.") + else: + cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore # noqa + else: + raise RuntimeError( + "Invalid environment variable VLLM_CPU_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") + + scheduler_config = vllm_config.scheduler_config + if scheduler_config.chunked_prefill_enabled: + logger.warning( + "Chunked prefill is not supported on CPU, disable it.") + scheduler_config.chunked_prefill_enabled = False + + parallel_config = vllm_config.parallel_config + if (parallel_config.distributed_executor_backend is not None + and parallel_config.distributed_executor_backend != "mp"): + logger.warning(("%s is not supported on CPU, fallback to mp " + "distributed executor backend."), + parallel_config.distributed_executor_backend) + parallel_config.distributed_executor_backend = "mp" From 305708bd04dbe1146a349935ebe628873157879c Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Sun, 17 Nov 2024 16:44:44 +0800 Subject: [PATCH 09/23] [Hardware] [HPU]add `mark_step` for hpu (#10239) Signed-off-by: Kunshang Ji Signed-off-by: Linkun Chen --- vllm/worker/hpu_model_runner.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 1ff30d685c6b..99cf9a7e6725 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -272,6 +272,19 @@ def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt): return indices, offsets +def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"): + if module.__class__.__name__.endswith(suffix): + + def forward_hook(module, args, output): + htorch.core.mark_step() + return output + + module.register_forward_hook(forward_hook) + + for child_name, child_module in module.named_children(): + modify_decoder_layer(child_module) + + class HpuModelAdapter: def __init__(self, model, block_size, dtype, enforce_eager): @@ -636,6 +649,7 @@ def load_model(self) -> None: else: self.model = self.model.to("hpu") htcore.mark_step() + modify_decoder_layer(self.model) torch.hpu.synchronize() with HabanaMemoryProfiler() as m_wrap: From 871a773c9543f154cbc03a73ec901ab4ca039458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=94=B5=E8=84=91=E6=98=9F=E4=BA=BA?= Date: Sun, 17 Nov 2024 16:50:24 +0800 Subject: [PATCH 10/23] [Bugfix] Fix mrope_position_delta in non-last prefill chunk (#10403) Signed-off-by: imkero Signed-off-by: Linkun Chen --- vllm/model_executor/layers/rotary_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index b01e4c61fe10..117fe086e5e8 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -922,9 +922,9 @@ def get_input_positions( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] return llm_positions.tolist(), mrope_position_delta From 242bb53c234fa252f62c2667722f69c9fe362cf9 Mon Sep 17 00:00:00 2001 From: wchen61 Date: Sun, 17 Nov 2024 19:32:40 +0800 Subject: [PATCH 11/23] =?UTF-8?q?[Misc]=20Enhance=20offline=5Finference=20?= =?UTF-8?q?to=20support=20user-configurable=20paramet=E2=80=A6=20(#10392)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: wchen61 Signed-off-by: Linkun Chen --- examples/offline_inference.py | 98 ++++++++++++++++++++++++++++------- 1 file changed, 78 insertions(+), 20 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f..391ac6b9b6b0 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,22 +1,80 @@ +from dataclasses import asdict + from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def get_prompts(num_prompts: int): + # The default sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + if num_prompts != len(prompts): + prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts] + + return prompts + + +def main(args): + # Create prompts + prompts = get_prompts(args.num_prompts) + + # Create a sampling params object. + sampling_params = SamplingParams(n=args.n, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + max_tokens=args.max_tokens) + + # Create an LLM. + # The default model is 'facebook/opt-125m' + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**asdict(engine_args)) + + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == '__main__': + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + group = parser.add_argument_group("SamplingParams options") + group.add_argument("--num-prompts", + type=int, + default=4, + help="Number of prompts used for inference") + group.add_argument("--max-tokens", + type=int, + default=16, + help="Generated output length for sampling") + group.add_argument('--n', + type=int, + default=1, + help='Number of generated sequences per prompt') + group.add_argument('--temperature', + type=float, + default=0.8, + help='Temperature for text generation') + group.add_argument('--top-p', + type=float, + default=0.95, + help='top_p for text generation') + group.add_argument('--top-k', + type=int, + default=-1, + help='top_k for text generation') -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - -# Create an LLM. -llm = LLM(model="facebook/opt-125m") -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + args = parser.parse_args() + main(args) From f5312d38c44144c518a1b049dec6de1cd2b9c7ce Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 18 Nov 2024 09:11:01 +0800 Subject: [PATCH 12/23] Fix initialization Signed-off-by: Cyrus Leung Signed-off-by: Linkun Chen --- vllm/outputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index a02f1e97b5b5..32160a8c0432 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -144,7 +144,7 @@ def new( request_id=request_id, prompt=prompt, prompt_token_ids=prompt_token_ids, - multi_modal_placeholders=MultiModalPlaceholderDict(), + multi_modal_placeholders={}, prompt_logprobs=None, # TODO outputs=[completion_output], finished=finished, From 439e324269d343be3b9ba3b45e9e7d8a5f3a0d8b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 18 Nov 2024 09:13:07 +0800 Subject: [PATCH 13/23] Run isort Signed-off-by: Linkun Chen --- tests/models/decoder_only/vision_language/test_pixtral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index bbae8138d44b..c502575d4f92 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -15,8 +15,8 @@ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.multimodal import image_from_chunk -from vllm import (EngineArgs, LLMEngine, SamplingParams, TokensPrompt, - TextPrompt, RequestOutput) +from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams, + TokensPrompt, TextPrompt) from vllm.logger import init_logger from vllm.multimodal import MultiModalDataBuiltins from vllm.multimodal.inputs import PlaceholderRange From 60815f258a3983e3071387da7194b407ce6f52e4 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 18 Nov 2024 09:16:04 +0800 Subject: [PATCH 14/23] isort Signed-off-by: Linkun Chen --- tests/models/decoder_only/vision_language/test_pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index c502575d4f92..dad14f9fe3c8 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -6,7 +6,6 @@ import uuid from dataclasses import asdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from transformers import AutoProcessor import pytest from mistral_common.multimodal import download_image @@ -14,6 +13,7 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.multimodal import image_from_chunk +from transformers import AutoProcessor from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams, TokensPrompt, TextPrompt) From ec4675566c9ce859fc0988ea6f2dc76dcfa2a58a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 18 Nov 2024 09:18:22 +0800 Subject: [PATCH 15/23] isort Signed-off-by: Linkun Chen --- tests/models/decoder_only/vision_language/test_pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index dad14f9fe3c8..4edfd6862511 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -16,7 +16,7 @@ from transformers import AutoProcessor from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams, - TokensPrompt, TextPrompt) + TextPrompt, TokensPrompt) from vllm.logger import init_logger from vllm.multimodal import MultiModalDataBuiltins from vllm.multimodal.inputs import PlaceholderRange From ce3ae6fb0cdeee32b4742b4a5865e52dc4959c7e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 18 Nov 2024 09:07:46 +0800 Subject: [PATCH 16/23] [Misc] Add uninitialized params tracking for `AutoWeightsLoader` (#10327) Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Linkun Chen --- vllm/model_executor/model_loader/loader.py | 12 +++++++++++- vllm/model_executor/models/arctic.py | 8 ++++++-- vllm/model_executor/models/baichuan.py | 8 ++++++-- vllm/model_executor/models/bert.py | 8 ++++++-- vllm/model_executor/models/blip.py | 12 ++++++++---- vllm/model_executor/models/blip2.py | 7 ++++--- vllm/model_executor/models/bloom.py | 8 ++++++-- vllm/model_executor/models/chameleon.py | 8 ++++++-- vllm/model_executor/models/chatglm.py | 10 ++++++++-- vllm/model_executor/models/clip.py | 11 ++++++++--- vllm/model_executor/models/commandr.py | 4 +++- vllm/model_executor/models/dbrx.py | 8 ++++++-- vllm/model_executor/models/decilm.py | 8 ++++++-- vllm/model_executor/models/deepseek.py | 8 ++++++-- vllm/model_executor/models/deepseek_v2.py | 8 ++++++-- vllm/model_executor/models/exaone.py | 9 +++++++-- vllm/model_executor/models/falcon.py | 8 ++++++-- vllm/model_executor/models/florence2.py | 17 +++++++++++------ vllm/model_executor/models/fuyu.py | 8 +++++--- vllm/model_executor/models/gemma.py | 4 +++- vllm/model_executor/models/gemma2.py | 9 ++++++--- vllm/model_executor/models/gpt2.py | 8 ++++++-- vllm/model_executor/models/gpt_bigcode.py | 8 ++++++-- vllm/model_executor/models/gpt_j.py | 8 ++++++-- vllm/model_executor/models/gpt_neox.py | 8 ++++++-- vllm/model_executor/models/granite.py | 9 +++++++-- vllm/model_executor/models/granitemoe.py | 8 +++++--- .../models/idefics2_vision_model.py | 11 ++++++++--- vllm/model_executor/models/idefics3.py | 7 ++++--- vllm/model_executor/models/intern_vit.py | 8 ++++++-- vllm/model_executor/models/internlm2.py | 8 ++++++-- vllm/model_executor/models/internvl.py | 7 ++++--- vllm/model_executor/models/jais.py | 8 ++++++-- vllm/model_executor/models/jamba.py | 8 ++++++-- vllm/model_executor/models/llama.py | 15 ++++++++++----- vllm/model_executor/models/llava.py | 7 ++++--- vllm/model_executor/models/llava_next.py | 7 ++++--- vllm/model_executor/models/llava_next_video.py | 7 ++++--- vllm/model_executor/models/llava_onevision.py | 7 ++++--- vllm/model_executor/models/mamba.py | 8 ++++++-- vllm/model_executor/models/medusa.py | 9 +++++++-- vllm/model_executor/models/minicpm.py | 8 ++++++-- vllm/model_executor/models/minicpmv.py | 14 +++++++++----- vllm/model_executor/models/mixtral.py | 8 ++++++-- vllm/model_executor/models/mixtral_quant.py | 8 ++++++-- vllm/model_executor/models/mllama.py | 9 ++++++--- vllm/model_executor/models/mlp_speculator.py | 8 ++++++-- vllm/model_executor/models/mpt.py | 8 ++++++-- vllm/model_executor/models/nemotron.py | 8 ++++++-- vllm/model_executor/models/olmo.py | 8 ++++++-- vllm/model_executor/models/olmoe.py | 8 ++++++-- vllm/model_executor/models/opt.py | 8 ++++++-- vllm/model_executor/models/orion.py | 8 ++++++-- vllm/model_executor/models/paligemma.py | 7 ++++--- vllm/model_executor/models/persimmon.py | 8 ++++++-- vllm/model_executor/models/phi.py | 8 ++++++-- vllm/model_executor/models/phi3_small.py | 8 ++++++-- vllm/model_executor/models/phi3v.py | 9 ++++++--- vllm/model_executor/models/phimoe.py | 8 ++++++-- vllm/model_executor/models/pixtral.py | 12 ++++++++---- vllm/model_executor/models/qwen.py | 8 ++++++-- vllm/model_executor/models/qwen2.py | 18 ++++++++++++------ vllm/model_executor/models/qwen2_audio.py | 9 +++++++-- vllm/model_executor/models/qwen2_cls.py | 7 ++++--- vllm/model_executor/models/qwen2_moe.py | 8 ++++++-- vllm/model_executor/models/qwen2_rm.py | 7 ++++--- vllm/model_executor/models/qwen2_vl.py | 8 ++++++-- vllm/model_executor/models/siglip.py | 11 ++++++++--- vllm/model_executor/models/solar.py | 9 +++++++-- vllm/model_executor/models/stablelm.py | 8 ++++++-- vllm/model_executor/models/starcoder2.py | 8 ++++++-- vllm/model_executor/models/ultravox.py | 7 ++++--- vllm/model_executor/models/utils.py | 11 ++++++----- vllm/model_executor/models/xverse.py | 8 ++++++-- 74 files changed, 454 insertions(+), 185 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 0f8b81c3ef40..d9ce85949e4e 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -334,7 +334,17 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: with target_device: model = _initialize_model(vllm_config=vllm_config) - model.load_weights(self._get_all_weights(model_config, model)) + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights( + self._get_all_weights(model_config, model)) + # We only enable strict check for non-quantiized models + # that have loaded weights tracking currently. + if model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index d52418ee0f6f..e58ad19cab54 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -1,5 +1,5 @@ """Inference-only Snowflake Arctic model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -480,7 +480,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -518,6 +519,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("ws", f"experts.{expert_id}.w3.weight", expert_id)) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() logger.info( "It will take ~10 minutes loading from the 16-bit weights. " @@ -573,3 +575,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 01ce7c42cd39..3749a16a3899 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -18,7 +18,7 @@ # limitations under the License. """Inference-only BaiChuan model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -404,13 +404,15 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -449,6 +451,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class BaichuanForCausalLM(BaiChuanBaseForCausalLM): diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 42dd6119e76f..d8301a36acb0 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -337,7 +337,8 @@ def forward( return self.encoder(hidden_states, kv_caches, attn_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "query", "q"), @@ -346,6 +347,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "pooler" in name: continue @@ -368,6 +370,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class BertEmbeddingModel(nn.Module): diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index e61201067736..6db6462e97f3 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -1,6 +1,6 @@ """Minimal implementation of BlipVisionModel intended to be only used within a vision language model.""" -from typing import Iterable, Optional, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -415,7 +415,8 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return self.post_layernorm(hidden_states) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -423,6 +424,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] if self.shard_weight else [] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() layer_count = len(self.encoder.layers) for name, loaded_weight in weights: @@ -440,8 +442,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -450,3 +452,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 03dc1d15ab69..7d7639b4a92c 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -692,6 +692,7 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index cf2eee817276..1060d418474e 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -16,7 +16,7 @@ # limitations under the License. """Inference-only BLOOM model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -341,8 +341,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if name == "lm_head.weight": continue @@ -371,3 +373,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 7b59c818e0b6..8f91abffaea9 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -1034,7 +1034,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1044,6 +1045,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -1111,3 +1113,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 70e9b607b064..81e56381eabd 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -3,7 +3,8 @@ """Inference-only ChatGLM model compatible with THUDM weights.""" from argparse import Namespace from array import array -from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict +from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple, + TypedDict) import torch from PIL import Image @@ -645,7 +646,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: # Merge two ColumnParallelLinear into one MergedColumnParallelLinear merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = { "transformer.vision.linear_proj.merged_proj.weight": { @@ -655,6 +657,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): } params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: is_weight_to_be_merge = False for _, merged_weight_dict in merged_weights_dict.items(): @@ -677,6 +680,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) for combined_name, merged_weight_dict in merged_weights_dict.items(): if combined_name in params_dict: @@ -686,3 +690,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, combined_weight) + loaded_params.add(combined_name) + return loaded_params diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 2d81b9266826..184758f4a8a4 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,6 +1,6 @@ """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -483,7 +483,8 @@ def device(self): # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -491,6 +492,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] if self.shard_weight else [] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() layer_count = len(self.vision_model.encoder.layers) for name, loaded_weight in weights: @@ -508,8 +510,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue + name = name.replace(weight_name, param_name) - param = params_dict[name.replace(weight_name, param_name)] + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -518,3 +521,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index fbb09a64cde9..9fd083e5a02a 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -402,7 +402,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -447,3 +448,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 3952ff31e5ce..eab338800249 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -417,13 +417,15 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: expert_params_mapping = [( "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight", f"mlp.{weight_name}", ) for weight_name in ["w1", "v1", "w2"]] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: for param_name, weight_name in expert_params_mapping: if weight_name not in name: @@ -447,3 +449,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index b38fd9fa49c2..c551853956b9 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -22,7 +22,7 @@ # limitations under the License. """Inference-only DeciLM model compatible with HuggingFace weights.""" -from typing import Iterable, Tuple +from typing import Iterable, Set, Tuple import torch @@ -57,7 +57,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): delattr(config, "num_key_value_heads_per_layer") super().__init__(vllm_config=vllm_config) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -67,6 +68,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -97,6 +99,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor: hidden_size = self.config.hidden_size diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 36dfea5a6565..8c5ad9904e92 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -442,7 +442,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -453,6 +454,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -487,3 +489,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 1e32fe60c7a5..d2c4ca0bf85e 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only DeepseekV2 model.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -550,7 +550,8 @@ def make_empty_intermediate_tensors( device=device), }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -566,6 +567,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.n_routed_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -623,3 +625,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 52dd603ca558..9d739d047954 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -22,7 +22,7 @@ # limitations under the License. """Inference-only Exaone model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -513,7 +513,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -523,6 +524,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".c_fc_1", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -543,6 +545,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) + loaded_params.add(scale_name) continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -576,6 +579,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index e97abe949ccd..2aa4b67d9989 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -18,7 +18,7 @@ """PyTorch Falcon model.""" import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -473,7 +473,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: total_num_heads = self.config.num_attention_heads if self.config.new_decoder_architecture: total_num_kv_heads = self.config.num_kv_heads @@ -483,6 +484,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): total_num_kv_heads = total_num_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if name == "lm_head.weight" and self.tie_word_embeddings: # Falcon uses tied embeddings except Falcon-11b. @@ -519,3 +521,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 971a71180164..d3a9ff6915b8 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -1,5 +1,5 @@ import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch import torch.nn as nn @@ -156,7 +156,8 @@ def sample(self, logits: torch.Tensor, next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -165,12 +166,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -183,6 +185,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class Florence2ForConditionalGeneration(nn.Module): @@ -248,10 +252,11 @@ def sample( ) -> SamplerOutput: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: skip_prefixes = [ 'image_projection', "vision_tower", "image_proj_norm", "image_pos_embed", "visual_temporal_embed" ] loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 31fc098a8bb3..7b46907ac83a 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -16,7 +16,8 @@ """ PyTorch Fuyu model.""" import math from array import array -from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, + TypedDict) import torch import torch.nn as nn @@ -354,6 +355,7 @@ def sample( next_tokens = self.language_model.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index ace13664c6ea..64e03b30bf2f 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -424,7 +424,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -469,3 +470,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): logger.warning( "Some weights are not initialized from checkpoints: %s", unloaded_params) + return loaded_params diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index a60b4e73a76d..4ba39223cc07 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -312,7 +312,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -354,6 +355,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): logger.warning( "Some weights are not initialized from checkpoints: %s", unloaded_params) + return loaded_params class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): @@ -451,13 +453,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - loader.load_weights(weights) + return loader.load_weights(weights) class Gemma2EmbeddingModel(nn.Module, SupportsPP): diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index fa0fdad28d16..1c61408ae1dd 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -298,8 +298,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final @@ -328,3 +330,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index b2fc79d0d36d..50a143cb1b60 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPTBigCode model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -323,8 +323,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "lm_head.weight" in name: continue @@ -344,3 +346,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, 'v') else: weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index cec3fd12a67d..d5defc60764e 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-J model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -291,7 +291,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -301,6 +302,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "attn.bias" in name or "attn.masked_bias" in name: continue @@ -330,3 +332,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 11f286d6bcba..0bb5e2f9b95f 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-NeoX model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -303,8 +303,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if ("attention.bias" in name or "attention.masked_bias" in name or "rotary_emb.inv_freq" in name): @@ -337,3 +339,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index cb2583e69d88..c1e2e87f08ec 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only IBM Granite model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -455,7 +455,8 @@ def make_empty_intermediate_tensors( device=device), }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -465,6 +466,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -485,6 +487,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) + loaded_params.add(scale_name) continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -518,6 +521,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index f437dd521a7d..a91a18816995 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GraniteMoe model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -419,7 +419,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: new_weights = {} for n, p in weights: if n.endswith('.block_sparse_moe.input_linear.weight'): @@ -452,4 +453,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): pass else: new_weights[n] = p - mixtral.MixtralForCausalLM.load_weights(self, new_weights.items()) + return mixtral.MixtralForCausalLM.load_weights(self, + new_weights.items()) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index b21bc2a3f9ce..16192928beb1 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -15,7 +15,7 @@ # limitations under the License. """PyTorch Idefics2 model.""" -from typing import Iterable, Optional, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn @@ -331,7 +331,8 @@ def forward( last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -339,11 +340,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -352,3 +355,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 0cecc754e916..5d176b2a4e41 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -15,7 +15,7 @@ import math from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple, - Optional, Tuple, TypedDict, Union) + Optional, Set, Tuple, TypedDict, Union) import torch import torch.utils.checkpoint @@ -751,9 +751,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) def get_mm_mapping(self) -> MultiModelKeys: """ diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 9761635d2a6c..bd91a0806ae5 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -5,7 +5,7 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from functools import partial -from typing import Iterable, Optional, Tuple +from typing import Iterable, Optional, Set, Tuple import torch import torch.nn as nn @@ -469,10 +469,14 @@ def forward( return encoder_outputs - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 19bfe16e4d5f..94b819b5d936 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -369,13 +369,15 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w1", 0), ("gate_up_proj", "w3", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -402,3 +404,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 92579e3aae94..7ea2f9be2191 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -6,7 +6,7 @@ # -------------------------------------------------------- import re from functools import cached_property, partial -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -663,6 +663,7 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index ee49ffb3cd87..41db85b67845 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -19,7 +19,7 @@ """Inference-only Jais model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -350,8 +350,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final @@ -382,3 +384,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 5612dd688638..f83f0fce7275 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,5 +1,5 @@ """Inference-only Jamba model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -462,7 +462,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -479,6 +480,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -534,6 +536,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params def _is_moe_layer(name: str): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e53631ef19f3..2b40e9ec73fa 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -350,7 +350,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -360,6 +361,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -375,6 +377,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) + loaded_params.add(scale_name) continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -390,7 +393,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) - break else: # Skip loading extra bias for GPTQ models. @@ -408,6 +410,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should @@ -577,13 +581,14 @@ def sample(self, logits: torch.Tensor, next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - loader.load_weights( + return loader.load_weights( self.maybe_remap_mistral(name, loaded_weight) for name, loaded_weight in weights) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index b13bcfa67681..e7d3161a7cb2 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, +from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) import torch @@ -547,6 +547,7 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index dd2fa6cac969..37e2227a52dc 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -654,6 +654,7 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 5d5598d07bfd..e2880c76cf43 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -1,6 +1,6 @@ import math from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import numpy as np @@ -445,10 +445,11 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( self, # This model doesn't support images for now ignore_unexpected_prefixes=["image_newline"], ) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index a5b210817783..705ca1e4ab6e 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -1,6 +1,6 @@ import math from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import numpy as np @@ -887,6 +887,7 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index ac0d265a961f..405b8f7787ba 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,5 +1,5 @@ """PyTorch MAMBA model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -243,8 +243,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") @@ -256,3 +258,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index b05360b55466..b4ed6538bdda 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch import torch.nn as nn @@ -156,8 +156,10 @@ def generate_proposals( sampling_metadata=sampling_metadata, ) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() weights_map = {} @@ -181,9 +183,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) if self.token_map is not None: self.token_map.to(device=self.lm_heads[0].weight.device) assert (self.truncated_vocab_size == self.orig_vocab_size) or (self.token_map is not None) + + return loaded_params diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 6b67266c5336..b92bff4d7c28 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -21,7 +21,7 @@ # limitations under the License. """Inference-only MiniCPM model compatible with HuggingFace weights.""" import math -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -539,7 +539,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -556,6 +557,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -606,3 +608,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index fd8eda997f76..99bf1d42d035 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -24,7 +24,7 @@ import re from functools import partial from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, - Tuple, TypedDict, Union) + Set, Tuple, TypedDict, Union) import torch import torch.types @@ -602,7 +602,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -612,6 +613,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): if key_to_modify in name: @@ -630,10 +632,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - if is_pp_missing_parameter( - name.replace(weight_name, param_name), self): + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): continue - param = params_dict[name.replace(weight_name, param_name)] + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -646,6 +648,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params def get_mm_mapping(self) -> MultiModelKeys: """ diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index eebf5bab5a28..0faffb4f1b00 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -404,7 +404,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -421,6 +422,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.num_local_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -478,3 +480,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index af2e9586988d..ddd6afcf6a1b 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -409,7 +409,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -418,6 +419,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -448,3 +450,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index db7ee7b2d853..41f62b37f3bd 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -13,7 +13,7 @@ # limitations under the License. """PyTorch Mllama model.""" import math -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import numpy as np @@ -1427,7 +1427,8 @@ def forward( return outputs - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1437,7 +1438,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - updated_params = set() + updated_params: Set[str] = set() for name, loaded_weight in weights: if 'patch_embedding.weight' in name: name = name.replace('patch_embedding.weight', @@ -1457,6 +1458,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + updated_params.add(name) + return updated_params def skip_attention_mask(sparse_mask: List[List[int]]) -> bool: diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 4d7e82880041..f2aa2653c4f5 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -1,5 +1,5 @@ import math -from typing import Iterable, List, Tuple +from typing import Iterable, List, Set, Tuple import torch import torch.nn as nn @@ -188,11 +188,15 @@ def generate_proposals( return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: param = params_dict.get(name.replace("speculator.", "")) if param is not None: weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 3c74ef2448ab..8716e92b0f1c 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -1,6 +1,6 @@ # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -324,8 +324,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: @@ -336,3 +338,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index eb45beae7d21..ceab299a7950 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Nemotron model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -474,7 +474,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -482,6 +483,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".qkv_proj", ".v_proj", "v"), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -522,3 +524,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 98d4e1ec320a..dc138e2e636a 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -356,7 +356,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -366,6 +367,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -402,3 +404,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index f4eebab8c98d..ab87695d8e65 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OLMoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -364,7 +364,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -383,6 +384,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -455,3 +457,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 997fe642439e..db85a494980a 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OPT model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -394,7 +394,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -402,6 +403,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "lm_head.weight" in name and self.config.tie_word_embeddings: continue @@ -431,3 +433,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 39d659c49cbc..b01734af8ddd 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -3,7 +3,7 @@ # Copyright (c) OrionStar Inc. # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -327,7 +327,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -337,6 +338,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -368,3 +370,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index eea229359255..dd5256eb87ab 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,4 +1,4 @@ -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -295,6 +295,7 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 62c509153a11..3b8199f4f166 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -19,7 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only persimmon model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -324,8 +324,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -358,3 +360,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index a2ab0d74c48d..0a117bf16c9b 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -34,7 +34,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -345,7 +345,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -353,6 +354,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v") ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -383,3 +385,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 2139cec44180..a78e4d355a31 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -1,5 +1,5 @@ import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -457,9 +457,11 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -471,3 +473,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4db65edc174f..2e583bb08e87 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -15,7 +15,7 @@ import itertools import re from functools import cached_property, lru_cache -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import numpy as np @@ -744,7 +744,8 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_embed_tokens.wte": "embed_tokens", @@ -759,5 +760,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # The HF config doesn't specify whether these are tied, # so we detect it this way - if "embed_tokens" not in autoloaded_weights: + if "embed_tokens.weight" not in autoloaded_weights: self.embed_tokens = self.language_model.model.embed_tokens + autoloaded_weights.add("embed_tokens.weight") + return autoloaded_weights diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index b7e70f8fa2c6..e475d286bd7e 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only PhiMoE model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -598,7 +598,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -613,6 +614,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.num_local_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -666,3 +668,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 790a260d43ec..d44a538d56b8 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, fields from functools import cached_property from itertools import tee -from typing import Iterable, List, Mapping, Optional, Tuple, Union +from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union import numpy import torch @@ -1067,7 +1067,8 @@ def forward( # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1077,6 +1078,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() layer_count = len(self.transformer.layers) for name, loaded_weight in weights: @@ -1089,8 +1091,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -1099,3 +1101,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 447632cefcd9..3978c176a214 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -8,7 +8,7 @@ import re from functools import partial from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Tuple, TypedDict, Union) + Optional, Set, Tuple, TypedDict, Union) import numpy as np import torch @@ -964,13 +964,15 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w2", 0), ("gate_up_proj", "w1", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -999,6 +1001,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class QWenLLM(QWenBaseModel): diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 8f10df808c21..370cff5fa153 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -332,7 +332,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -342,6 +343,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -372,6 +374,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): @@ -494,13 +498,14 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - loader.load_weights(weights) + return loader.load_weights(weights) class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): @@ -564,7 +569,8 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."]) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index d30950361ad8..a4965f34b1ca 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -20,7 +20,8 @@ # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from functools import lru_cache -from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union) import librosa import numpy as np @@ -420,7 +421,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -430,6 +432,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -463,3 +466,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 07eb330620a4..dc5dabf6fc38 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -4,7 +4,7 @@ # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. """Inference-only Qwen2-Classification model compatible with HF weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -97,7 +97,8 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."]) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 249d94b5d95e..96a9bc451f4d 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn.functional as F @@ -436,7 +436,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -455,6 +456,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -532,3 +534,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 6db467af334f..988d682d36be 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -3,7 +3,7 @@ # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. """Inference-only Qwen2-RM model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -110,7 +110,8 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."]) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 2335baf45977..ef6b52db6e17 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -23,7 +23,7 @@ """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import partial from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Tuple, Type, TypedDict, Union) + Optional, Set, Tuple, Type, TypedDict, Union) import torch import torch.nn as nn @@ -1333,7 +1333,8 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -1343,6 +1344,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "gate_proj", 0), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -1392,3 +1394,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index acaf4afdecfe..c9e09b879843 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -2,7 +2,7 @@ within a vision language model.""" import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -594,7 +594,8 @@ def forward( interpolate_pos_encoding=interpolate_pos_encoding, ) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -602,6 +603,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] if self.shard_weight else [] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() layer_count = len(self.vision_model.encoder.layers) for name, loaded_weight in weights: @@ -619,8 +621,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue + name = name.replace(weight_name, param_name) - param = params_dict[name.replace(weight_name, param_name)] + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -629,3 +632,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index affb2c975ce4..6d6fafc5ab0e 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -21,7 +21,7 @@ # limitations under the License. """Inference-only Solar model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -477,7 +477,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -487,6 +488,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -502,6 +504,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) + loaded_params.add(scale_name) continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -535,6 +538,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 99acce596602..e11d2e916730 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -18,7 +18,7 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -306,7 +306,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -316,6 +317,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -347,3 +349,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 0ef940acebb9..74c66042226d 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Starcoder2 model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -314,7 +314,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -323,6 +324,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -346,3 +348,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 9fde22c016de..512adbc7db35 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,7 +3,7 @@ import math from functools import cached_property, lru_cache -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union, cast) import numpy as np @@ -504,10 +504,11 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}) loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."]) - loader.load_weights(weights, mapper=hf_to_vllm_mapper) + return loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 1d51885f9094..7a4fcce95603 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,7 @@ import itertools from dataclasses import dataclass, field from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Protocol, Tuple, Union, overload) + Optional, Protocol, Set, Tuple, Union, overload) import torch import torch.nn as nn @@ -172,8 +172,9 @@ def _load_module( if module != self.module: module_load_weights = getattr(module, "load_weights", None) if callable(module_load_weights): - module_load_weights(weights) - return + loaded_params = module_load_weights(weights) + yield from map(lambda x: self._get_qualname(base_prefix, x), + loaded_params) child_modules = dict(module.named_children()) child_params = dict(module.named_parameters(recurse=False)) @@ -222,11 +223,11 @@ def load_weights( weights: Iterable[Tuple[str, torch.Tensor]], *, mapper: Optional[WeightsMapper] = None, - ) -> List[str]: + ) -> Set[str]: if mapper is not None: weights = mapper.apply(weights) - autoloaded_weights = list(self._load_module("", self.module, weights)) + autoloaded_weights = set(self._load_module("", self.module, weights)) return autoloaded_weights diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 51172d8782a7..bc37a997eabb 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -19,7 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Xverse model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -376,7 +376,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -385,6 +386,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if ("rotary_emb.inv_freq" in name or "rotary_emb.cos_cached" in name @@ -413,3 +415,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params From 466b2cfb7468436882556d5bff77672534b4d480 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=84=8D=F0=9D=95=A0=F0=9D=95=9D=F0=9D=95=9D=F0=9D=95=A0?= =?UTF-8?q?=F0=9D=95=A8=20=F0=9D=95=84=F0=9D=95=92=F0=9D=95=9F?= Date: Mon, 18 Nov 2024 05:29:26 +0200 Subject: [PATCH 17/23] [Bugfix] Ignore ray reinit error when current platform is ROCm or XPU (#10375) Signed-off-by: Hollow Man Signed-off-by: Linkun Chen --- vllm/executor/ray_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 41dd59bc65ec..4f28efd63908 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -234,7 +234,7 @@ def initialize_ray_cluster( if current_platform.is_rocm() or current_platform.is_xpu(): # Try to connect existing ray instance and create a new one if not found try: - ray.init("auto") + ray.init("auto", ignore_reinit_error=True) except ConnectionError: logger.warning( "No existing RAY instance detected. " From 3f092cef27d42ce6f3c8e4c5db9b125ebbfd8638 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 05:27:48 +0000 Subject: [PATCH 18/23] update RequestOutput.__init__() to take `multi_modal_placeholders` as optional argument also require it to be passed as kwargs, to avoid breaking existing code. Signed-off-by: Linkun Chen --- vllm/outputs.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 32160a8c0432..25d585d65285 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -96,7 +96,6 @@ def __init__( request_id: str, prompt: Optional[str], prompt_token_ids: Optional[List[int]], - multi_modal_placeholders: MultiModalPlaceholderDict, prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, @@ -105,11 +104,13 @@ def __init__( encoder_prompt: Optional[str] = None, encoder_prompt_token_ids: Optional[List[int]] = None, num_cached_tokens: Optional[int] = None, + *, + multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, ) -> None: self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids - self.multi_modal_placeholders = multi_modal_placeholders + self.multi_modal_placeholders = multi_modal_placeholders or {} self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished @@ -144,7 +145,6 @@ def new( request_id=request_id, prompt=prompt, prompt_token_ids=prompt_token_ids, - multi_modal_placeholders={}, prompt_logprobs=None, # TODO outputs=[completion_output], finished=finished, @@ -158,8 +158,7 @@ def from_seq_group( finished = seq_group.is_finished() if seq_group.request_id in seq_id_to_seq_group: - group: SequenceGroupBase = seq_id_to_seq_group[ - seq_group.request_id] + group: SequenceGroupBase = seq_id_to_seq_group[seq_group.request_id] if finished: group.finish_seq(seq_group) assembled_seq_group = group.maybe_assemble_group(seq_group) @@ -202,8 +201,8 @@ def from_seq_group( # num_cached_tokens should be the same for all the sequences num_cached_tokens = None for i, seq in enumerate(top_n_seqs): - output_text = seq.get_output_text_to_return( - text_buffer_length, delta) + output_text = seq.get_output_text_to_return(text_buffer_length, + delta) output_token_ids = seq.get_output_token_ids_to_return(delta) num_output_tokens = 1 if isinstance(output_token_ids, @@ -280,17 +279,17 @@ def from_seq_group( seq_group.set_finished_time(finished_time) init_args = (seq_group.request_id, prompt, prompt_token_ids, - seq_group.multi_modal_placeholders, prompt_logprobs, - outputs, finished, seq_group.metrics, + prompt_logprobs, outputs, finished, seq_group.metrics, seq_group.lora_request, encoder_prompt, encoder_prompt_token_ids, num_cached_tokens) + init_kwargs = {"multi_modal_placeholders": seq_group.multi_modal_placeholders} if use_cache: request_output = seq_group.cached_request_output - request_output.__init__(*init_args) # type: ignore + request_output.__init__(*init_args, **init_kwargs) # type: ignore else: - request_output = cls(*init_args) + request_output = cls(*init_args, **init_kwargs) return request_output @@ -298,7 +297,6 @@ def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " - f"multi_modal_placeholders={self.multi_modal_placeholders}, " f"encoder_prompt={self.encoder_prompt!r}, " f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " @@ -306,7 +304,8 @@ def __repr__(self) -> str: f"finished={self.finished}, " f"metrics={self.metrics}, " f"lora_request={self.lora_request}, " - f"num_cached_tokens={self.num_cached_tokens})") + f"num_cached_tokens={self.num_cached_tokens}, " + f"multi_modal_placeholders={self.multi_modal_placeholders})") class EmbeddingRequestOutput: From dd8427ef342dfa52d6c252db014b414e9e2f5e04 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 05:27:48 +0000 Subject: [PATCH 19/23] update RequestOutput.__init__() to take `multi_modal_placeholders` as optional argument also require it to be passed as kwargs, to avoid breaking existing code. Signed-off-by: Linkun Chen --- vllm/outputs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 25d585d65285..0e307d337bf1 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -282,7 +282,9 @@ def from_seq_group( prompt_logprobs, outputs, finished, seq_group.metrics, seq_group.lora_request, encoder_prompt, encoder_prompt_token_ids, num_cached_tokens) - init_kwargs = {"multi_modal_placeholders": seq_group.multi_modal_placeholders} + init_kwargs = { + "multi_modal_placeholders": seq_group.multi_modal_placeholders + } if use_cache: request_output = seq_group.cached_request_output From 76ac8b0157de48857602fcbe7406e57724e814dd Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 05:27:48 +0000 Subject: [PATCH 20/23] update RequestOutput.__init__() to take `multi_modal_placeholders` as optional argument also require it to be passed as kwargs, to avoid breaking existing code. Signed-off-by: Linkun Chen --- vllm/outputs.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 0e307d337bf1..c56aa0478925 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -278,20 +278,26 @@ def from_seq_group( finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) - init_args = (seq_group.request_id, prompt, prompt_token_ids, - prompt_logprobs, outputs, finished, seq_group.metrics, - seq_group.lora_request, encoder_prompt, - encoder_prompt_token_ids, num_cached_tokens) init_kwargs = { + "request_id": seq_group.request_id, + "prompt": prompt, + "prompt_token_ids": prompt_token_ids, + "prompt_logprobs": prompt_logprobs, + "outputs": outputs, + "finished": finished, + "metrics": seq_group.metrics, + "lora_request": seq_group.lora_request, + "encoder_prompt": encoder_prompt, + "encoder_prompt_token_ids": encoder_prompt_token_ids, + "num_cached_tokens": num_cached_tokens, "multi_modal_placeholders": seq_group.multi_modal_placeholders } if use_cache: request_output = seq_group.cached_request_output - request_output.__init__(*init_args, **init_kwargs) # type: ignore - + request_output.__init__(**init_kwargs) # type: ignore else: - request_output = cls(*init_args, **init_kwargs) + request_output = cls(**init_kwargs) return request_output From c963a256f43603194310c27aacccaa555cb350e6 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 05:59:58 +0000 Subject: [PATCH 21/23] disable mypy type check mypy is not smart enough to validate kwargs Signed-off-by: Linkun Chen --- vllm/outputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index c56aa0478925..80cf4702934f 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -297,7 +297,7 @@ def from_seq_group( request_output = seq_group.cached_request_output request_output.__init__(**init_kwargs) # type: ignore else: - request_output = cls(**init_kwargs) + request_output = cls(**init_kwargs) # type: ignore return request_output From 470fbd3921d22a53cac1662f26ca207f1251ca8e Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 05:59:58 +0000 Subject: [PATCH 22/23] disable mypy type check mypy is not smart enough to validate kwargs Signed-off-by: Linkun Chen --- vllm/outputs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 80cf4702934f..4ae9b377ae69 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -158,7 +158,8 @@ def from_seq_group( finished = seq_group.is_finished() if seq_group.request_id in seq_id_to_seq_group: - group: SequenceGroupBase = seq_id_to_seq_group[seq_group.request_id] + group: SequenceGroupBase = seq_id_to_seq_group[ + seq_group.request_id] if finished: group.finish_seq(seq_group) assembled_seq_group = group.maybe_assemble_group(seq_group) @@ -201,8 +202,8 @@ def from_seq_group( # num_cached_tokens should be the same for all the sequences num_cached_tokens = None for i, seq in enumerate(top_n_seqs): - output_text = seq.get_output_text_to_return(text_buffer_length, - delta) + output_text = seq.get_output_text_to_return( + text_buffer_length, delta) output_token_ids = seq.get_output_token_ids_to_return(delta) num_output_tokens = 1 if isinstance(output_token_ids, From 550be230dfa4ba7b05f450a053bb8256287885e5 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 06:17:27 +0000 Subject: [PATCH 23/23] remove unnecessary debug code Signed-off-by: Linkun Chen --- tests/models/decoder_only/vision_language/test_pixtral.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index 4edfd6862511..6233860747b9 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -17,7 +17,6 @@ from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams, TextPrompt, TokensPrompt) -from vllm.logger import init_logger from vllm.multimodal import MultiModalDataBuiltins from vllm.multimodal.inputs import PlaceholderRange from vllm.sequence import Logprob, SampleLogprobs @@ -28,8 +27,6 @@ if TYPE_CHECKING: from _typeshed import StrPath -logger = init_logger(__name__) - MODELS = ["mistralai/Pixtral-12B-2409"] IMG_URLS = [ "https://picsum.photos/id/237/400/300",