Commit d2432e5

hmellor authored and Pradyun Ramadorai committed
Simplify weight loading in Transformers backend (vllm-project#21382)
Signed-off-by: Harry Mellor <[email protected]>
1 parent aebeb75 commit d2432e5

7 files changed: +53 −76 lines changed

tests/distributed/test_pipeline_parallel.py
Lines changed: 2 additions & 2 deletions

@@ -177,7 +177,7 @@ def iter_params(self, model_id: str):
     "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
     "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
     # Tests TransformersForCausalLM
-    "ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
+    "hmellor/Ilama-3.2-1B": PPTestSettings.fast(),
     "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
     "openbmb/MiniCPM3-4B": PPTestSettings.fast(),
     # Uses Llama
@@ -249,7 +249,7 @@ def iter_params(self, model_id: str):
     # [LANGUAGE GENERATION]
     "microsoft/Phi-3.5-MoE-instruct",
     "meta-llama/Llama-3.2-1B-Instruct",
-    "ArthurZ/Ilama-3.2-1B",
+    "hmellor/Ilama-3.2-1B",
     "ibm/PowerLM-3b",
     "deepseek-ai/DeepSeek-V2-Lite-Chat",
     # [LANGUAGE EMBEDDING]

tests/lora/test_transformers_model.py
Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 
 from ..utils import create_new_process_for_each_test, multi_gpu_test
 
-MODEL_PATH = "ArthurZ/ilama-3.2-1B"
+MODEL_PATH = "hmellor/Ilama-3.2-1B"
 
 PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
 

tests/models/registry.py
Lines changed: 1 addition & 1 deletion

@@ -500,7 +500,7 @@ def check_available_online(
 }
 
 _TRANSFORMERS_MODELS = {
-    "TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
+    "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
     "TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
 }
 

tests/models/test_transformers.py
Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def check_implementation(
     "model,model_impl",
     [
         ("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
-        ("ArthurZ/Ilama-3.2-1B", "auto"),  # CUSTOM CODE
+        ("hmellor/Ilama-3.2-1B", "auto"),  # CUSTOM CODE
     ])  # trust_remote_code=True by default
 def test_models(
     hf_runner: type[HfRunner],
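
For context, a minimal smoke test mirroring the parametrization above (a sketch, not part of this commit): it loads the renamed custom-code checkpoint through vLLM and lets model_impl="auto" fall back to the Transformers backend.

from vllm import LLM, SamplingParams

# "auto" selects the Transformers backend when no native vLLM
# implementation exists; model_impl="transformers" would force it.
llm = LLM(model="hmellor/Ilama-3.2-1B",
          model_impl="auto",
          trust_remote_code=True)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)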

vllm/model_executor/models/interfaces.py
Lines changed: 3 additions & 7 deletions

@@ -624,13 +624,9 @@ def __new__(cls, *args, **kwargs) -> Self:
         instance.quant_config = quant_config
 
         # apply model mappings to config for proper config-model matching
-        # NOTE: `TransformersForCausalLM` is not supported due to how this
-        # class defines `hf_to_vllm_mapper` as a post-init `@property`.
-        # After this is fixed, get `instance.hf_to_vllm_mapper` directly
-        if getattr(instance, "hf_to_vllm_mapper", None) is not None:
-            instance.quant_config.apply_vllm_mapper(
-                instance.hf_to_vllm_mapper)
-        if getattr(instance, "packed_modules_mapping", None) is not None:
+        if (hf_to_vllm_mapper := instance.hf_to_vllm_mapper) is not None:
+            instance.quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
+        if instance.packed_modules_mapping is not None:
             instance.quant_config.packed_modules_mapping.update(
                 instance.packed_modules_mapping)
 
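
Why the getattr() guards could go (an illustrative sketch with stand-in names, not vLLM's real classes): hf_to_vllm_mapper is now a plain class attribute rather than a post-init @property, so every instance has it and it can be read and bound in one expression.

class SupportsQuantSketch:
    # assumed class-level default; subclasses override with a WeightsMapper
    hf_to_vllm_mapper = None

    def apply_mappings(self, quant_config):
        # the walrus operator binds the attribute and tests it in one step
        if (mapper := self.hf_to_vllm_mapper) is not None:
            quant_config.apply_vllm_mapper(mapper)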

vllm/model_executor/models/transformers.py
Lines changed: 44 additions & 63 deletions

@@ -414,7 +414,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             setattr(self.config, key, value)
 
 
-class TransformersModel(nn.Module):
+class TransformersModel:
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -454,9 +454,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # method after v4.54.0 is released
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"), config_override:
-            # FIXME(Isotr0py): We need to refactor this part in the future to
-            # avoid registering an extra model layer, otherwise we will need a
-            # weights mapper to rename weights.
             self.model: PreTrainedModel = AutoModel.from_config(
                 config,
                 torch_dtype=model_config.dtype,
@@ -620,9 +617,6 @@ def init_parameters(self, module: nn.Module):
         for child in module.children():
             self.init_parameters(child)
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.model.get_input_embeddings()
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -694,7 +688,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.config = config
 
-        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
+        self.transformers_model = TransformersModel(vllm_config=vllm_config,
+                                                    prefix=prefix)
+        self.model = self.transformers_model.model
 
         if get_pp_group().is_last_rank:
             self.unpadded_vocab_size = config.vocab_size
@@ -716,22 +712,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = PPMissingLayer()
 
         self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
-    # this makes thing complicated. We need to remove this mapper after refactor
-    # `TransformersModel` in the future.
-    # NOTE: `SupportsQuant` can be updated after property decorator is removed
-    @property
-    def hf_to_vllm_mapper(self):
-        prefix_mapper = {
-            name: "model." + name
-            for name, _ in self.model.model.named_children()
-        }
-        return WeightsMapper(
-            orig_to_new_substr={"model.": "model.model."},
-            orig_to_new_prefix=prefix_mapper,
-        )
+            self.transformers_model.make_empty_intermediate_tensors)
 
     def forward(
         self,
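
The key structural move, sketched below with stand-in classes (not the real vLLM ones): the wrapper is kept as a plain helper object while self.model points at the underlying HF module, so registered parameter names become "model.<hf name>" and match checkpoint keys with no mapper and no extra "model.model." nesting.

import torch.nn as nn


class WrapperSketch:
    """Plain object, mirroring the new non-nn.Module TransformersModel."""

    def __init__(self):
        # stand-in for the real AutoModel.from_config(...) result
        self.model = nn.Linear(4, 4)


class CausalLMSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # plain attribute: does NOT register a duplicate submodule tree
        self.transformers_model = WrapperSketch()
        # nn.Module attribute: registers the HF model under the name "model"
        self.model = self.transformers_model.model


print([name for name, _ in CausalLMSketch().named_parameters()])
# -> ['model.weight', 'model.bias']
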
@@ -740,8 +721,9 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
+        model_output = self.transformers_model.forward(input_ids, positions,
+                                                       intermediate_tensors,
+                                                       inputs_embeds)
         return model_output
 
     def compute_logits(
@@ -755,12 +737,10 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+        skip_prefixes = ["lm_head."
+                         ] if self.config.tie_word_embeddings else None
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        return loader.load_weights(weights)
 
 
 @MULTIMODAL_REGISTRY.register_processor(
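
With the mapper gone, checkpoint names line up with module names one-to-one. A minimal standalone sketch of what prefix-skipping weight loading amounts to (illustrative only, not vLLM's actual AutoWeightsLoader):

from collections.abc import Iterable
from typing import Optional

import torch


def load_weights_sketch(
    params: dict[str, torch.nn.Parameter],
    weights: Iterable[tuple[str, torch.Tensor]],
    skip_prefixes: Optional[list[str]] = None,
) -> set[str]:
    """Copy checkpoint tensors into same-named parameters."""
    loaded: set[str] = set()
    for name, tensor in weights:
        # e.g. skip "lm_head." when word embeddings are tied
        if skip_prefixes and any(name.startswith(p) for p in skip_prefixes):
            continue
        params[name].data.copy_(tensor)
        loaded.add(name)
    return loaded
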
@@ -772,6 +752,29 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"]
 
+    # Backwards compatibility for prev released models. State dicts back then
+    # had different formats and cannot be loaded with `AutoModel` mapping as is
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "language_model.model": "model.language_model",
+            "text_model.model": "model.text_model",
+            "vision_tower": "model.vision_tower",
+            "vqmodel": "model.vqmodel",
+            "visual": "model.visual",
+            "vision_model": "model.vision_model",
+            "vision_embed_tokens": "model.vision_embed_tokens",
+            "image_newline": "model.image_newline",
+            "multi_modal_projector": "model.multi_modal_projector",
+            "text_model.lm_head": "lm_head",
+            "language_model.lm_head": "lm_head",
+            # Qwen models used "model" as the name for the language model.
+            # Therefore, we must map each of submodule explicitly to avoid
+            # conflicts with newer models that use "model.language_model".
+            "model.embed_tokens": "model.language_model.embed_tokens",
+            "model.layers": "model.language_model.layers",
+            "model.norm": "model.language_model.norm",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: PretrainedConfig = vllm_config.model_config.hf_config
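
To see what the static mapper above buys, here is an illustrative mimic of prefix remapping (a sketch of the behavior, not vLLM's WeightsMapper; this version resolves overlaps by longest prefix first, whereas the real table relies on explicit per-submodule entries):

def remap_key(name: str, prefix_map: dict[str, str]) -> str:
    """Rewrite a legacy checkpoint key to the new module layout."""
    for old in sorted(prefix_map, key=len, reverse=True):
        if name.startswith(old):
            return prefix_map[old] + name[len(old):]
    return name


legacy_map = {
    "language_model.model": "model.language_model",
    "model.layers": "model.language_model.layers",  # old Qwen-style name
}
print(remap_key("language_model.model.embed_tokens.weight", legacy_map))
# -> model.language_model.embed_tokens.weight
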
@@ -780,7 +783,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.dtype = vllm_config.model_config.dtype
 
-        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
+        self.transformers_model = TransformersModel(vllm_config=vllm_config,
+                                                    prefix=prefix)
+        self.model = self.transformers_model.model
         text_config = config.get_text_config()
 
         if get_pp_group().is_last_rank:
@@ -803,32 +808,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = PPMissingLayer()
 
         self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    @property
-    def hf_to_vllm_mapper(self):
-        # Backwards compatibility for prev released models
-        # State dicts back then had different formats
-        # and cannot be loaded with `AutoModel` mapping
-        # as is
-        prefix_mapper = {
-            "language_model.model": "model.language_model",
-            "text_model.model": "model.text_model",
-            "vision_tower": "model.vision_tower",
-            "vqmodel": "model.vqmodel",
-            "vision_model": "model.vision_model",
-            "vision_embed_tokens": "model.vision_embed_tokens",
-            "image_newline": "model.image_newline",
-            "multi_modal_projector": "model.multi_modal_projector",
-            "text_model.lm_head": "lm_head",
-            "language_model.lm_head": "lm_head",
-        }
-        # Don't change the order for QwenVL
-        if 'Qwen2' in self.config.__class__.__name__:
-            prefix_mapper["model"] = "model.language_model"
-            prefix_mapper["visual"] = "model.visual"
-
-        return WeightsMapper(orig_to_new_prefix=prefix_mapper, )
+            self.transformers_model.make_empty_intermediate_tensors)
 
     def forward(
         self,
@@ -848,8 +828,9 @@ def forward(
                 input_ids, multimodal_embeds)
             input_ids = None
 
-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
+        model_output = self.transformers_model.forward(input_ids, positions,
+                                                       intermediate_tensors,
+                                                       inputs_embeds)
         return model_output
 
     def compute_logits(
@@ -898,7 +879,7 @@ def get_multimodal_embeddings(self, **kwargs):
         if isinstance(num_image_patches, list):
             num_image_patches = torch.cat(num_image_patches)
 
-        vision_embeddings = self.model.model.get_image_features(
+        vision_embeddings = self.model.get_image_features(
             pixel_values,
             **{
                 k: v.flatten(0, 1)
@@ -928,7 +909,7 @@ def get_input_embeddings(
         input_ids: torch.Tensor,
         multimodal_embeddings=None,
     ) -> torch.Tensor:
-        inputs_embeds = self.model.model.get_input_embeddings()(input_ids)
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
         if (multimodal_embeddings is not None
                 and len(multimodal_embeddings) != 0):
             mask = (input_ids == self.config.image_token_id)

vllm/test_utils.py
Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
     "allenai/OLMoE-1B-7B-0924-Instruct",
     "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test",
    "AMead10/Llama-3.2-1B-Instruct-AWQ",
-    "ArthurZ/Ilama-3.2-1B",
+    "hmellor/Ilama-3.2-1B",
     "BAAI/bge-base-en-v1.5",
     "BAAI/bge-multilingual-gemma2",
     "BAAI/bge-reranker-v2-m3",
