Merged
Changes from 2 commits
4 changes: 2 additions & 2 deletions tests/distributed/test_pipeline_parallel.py
@@ -177,7 +177,7 @@ def iter_params(self, model_id: str):
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
# Tests TransformersForCausalLM
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
"hmellor/Ilama-3.2-1B": PPTestSettings.fast(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(),
# Uses Llama
@@ -249,7 +249,7 @@ def iter_params(self, model_id: str):
# [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct",
"meta-llama/Llama-3.2-1B-Instruct",
"ArthurZ/Ilama-3.2-1B",
"hmellor/Ilama-3.2-1B",
"ibm/PowerLM-3b",
"deepseek-ai/DeepSeek-V2-Lite-Chat",
# [LANGUAGE EMBEDDING]
2 changes: 1 addition & 1 deletion tests/lora/test_transformers_model.py
@@ -9,7 +9,7 @@

from ..utils import create_new_process_for_each_test, multi_gpu_test

MODEL_PATH = "ArthurZ/ilama-3.2-1B"
MODEL_PATH = "hmellor/Ilama-3.2-1B"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501

2 changes: 1 addition & 1 deletion tests/models/registry.py
@@ -500,7 +500,7 @@ def check_available_online(
}

_TRANSFORMERS_MODELS = {
"TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
}

2 changes: 1 addition & 1 deletion tests/models/test_transformers.py
@@ -56,7 +56,7 @@ def check_implementation(
"model,model_impl",
[
("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE
("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE
]) # trust_remote_code=True by default
def test_models(
hf_runner: type[HfRunner],
10 changes: 3 additions & 7 deletions vllm/model_executor/models/interfaces.py
@@ -624,13 +624,9 @@ def __new__(cls, *args, **kwargs) -> Self:
instance.quant_config = quant_config

# apply model mappings to config for proper config-model matching
-        # NOTE: `TransformersForCausalLM` is not supported due to how this
-        # class defines `hf_to_vllm_mapper` as a post-init `@property`.
-        # After this is fixed, get `instance.hf_to_vllm_mapper` directly
-        if getattr(instance, "hf_to_vllm_mapper", None) is not None:
-            instance.quant_config.apply_vllm_mapper(
-                instance.hf_to_vllm_mapper)
-        if getattr(instance, "packed_modules_mapping", None) is not None:
+        if (hf_to_vllm_mapper := instance.hf_to_vllm_mapper) is not None:
+            instance.quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
+        if instance.packed_modules_mapping is not None:
instance.quant_config.packed_modules_mapping.update(
instance.packed_modules_mapping)

98 changes: 35 additions & 63 deletions vllm/model_executor/models/transformers.py
@@ -412,7 +412,7 @@ def __exit__(self, exc_type, exc_value, traceback):
setattr(self.config, key, value)


-class TransformersModel(nn.Module):
+class TransformersModel:

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -452,9 +452,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# method after v4.54.0 is released
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"), config_override:
-            # FIXME(Isotr0py): We need to refactor this part in the future to
-            # avoid registering an extra model layer, otherwise we will need a
-            # weights mapper to rename weights.
self.model: PreTrainedModel = AutoModel.from_config(
config,
torch_dtype=model_config.dtype,
@@ -618,9 +615,6 @@ def init_parameters(self, module: nn.Module):
for child in module.children():
self.init_parameters(child)

-    def get_input_embeddings(self) -> nn.Module:
-        return self.model.get_input_embeddings()

def forward(
self,
input_ids: Optional[torch.Tensor],
@@ -692,7 +686,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.config = config

-        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
+        self.transformers_model = TransformersModel(vllm_config=vllm_config,
+                                                    prefix=prefix)
+        self.model = self.transformers_model.model

if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
@@ -714,22 +710,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.lm_head = PPMissingLayer()

self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
-    # this makes thing complicated. We need to remove this mapper after refactor
-    # `TransformersModel` in the future.
-    # NOTE: `SupportsQuant` can be updated after property decorator is removed
-    @property
-    def hf_to_vllm_mapper(self):
-        prefix_mapper = {
-            name: "model." + name
-            for name, _ in self.model.model.named_children()
-        }
-        return WeightsMapper(
-            orig_to_new_substr={"model.": "model.model."},
-            orig_to_new_prefix=prefix_mapper,
-        )
+            self.transformers_model.make_empty_intermediate_tensors)

def forward(
self,
@@ -738,8 +719,9 @@ def forward(
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
+        model_output = self.transformers_model.forward(input_ids, positions,
+                                                       intermediate_tensors,
+                                                       inputs_embeds)
return model_output

def compute_logits(
@@ -753,12 +735,10 @@ def compute_logits(

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+        skip_prefixes = ["lm_head."
+                         ] if self.config.tie_word_embeddings else None
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        return loader.load_weights(weights)


@MULTIMODAL_REGISTRY.register_processor(
@@ -770,6 +750,20 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"]

+    # Backwards compatibility for prev released models
+    # State dicts back then had different formats
+    # and cannot be loaded with `AutoModel` mapping
+    # as is
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "language_model.model": "language_model",
+            "text_model.model": "text_model",
+            "text_model.lm_head": "lm_head",
+            "language_model.lm_head": "lm_head",
+            # deal with Qwen2-VL mapping
+            "model.layers": "language_model.layers",
+        })
Contributor review comment (critical):
The updated hf_to_vllm_mapper for TransformersForMultimodalLM appears to be incorrect. The AutoWeightsLoader is initialized with self (the TransformersForMultimodalLM instance), which has self.model as an attribute containing the PreTrainedModel. Therefore, parameter names within the model are expected to be prefixed with model. (e.g., model.language_model...).

The current mappings, such as "language_model.model": "language_model", will cause the loader to look for a top-level language_model attribute on TransformersForMultimodalLM, which doesn't exist. The target prefixes should include model. to correctly map to the nested structure.

For example, language_model.model from the checkpoint should map to model.language_model in the vLLM model.

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "language_model.model": "model.language_model",
        "text_model.model": "model.text_model",
        "text_model.lm_head": "model.lm_head",
        "language_model.lm_head": "model.lm_head",
        # deal with Qwen2-VL mapping
        "model.layers": "model.language_model.layers",
    })

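To illustrate the effect of the corrected targets, here is a minimal, self-contained sketch of prefix remapping. The map_prefixes helper is hypothetical (written for this comment, not vLLM's actual WeightsMapper implementation) and assumes each orig_to_new_prefix rule is applied as a first-match prefix rewrite on every checkpoint key:

from typing import Iterable

def map_prefixes(names: Iterable[str],
                 orig_to_new_prefix: dict[str, str]) -> list[str]:
    """Rename each checkpoint key whose start matches a known prefix."""
    renamed = []
    for name in names:
        for orig, new in orig_to_new_prefix.items():
            if name.startswith(orig):
                name = new + name[len(orig):]
                break  # apply at most one prefix rule per key
        renamed.append(name)
    return renamed

# A legacy checkpoint key must land on the nested `self.model`
# attribute of TransformersForMultimodalLM:
print(map_prefixes(
    ["language_model.model.layers.0.self_attn.q_proj.weight"],
    {"language_model.model": "model.language_model"}))
# -> ['model.language_model.layers.0.self_attn.q_proj.weight']

Under the mapping as merged, the same key would be rewritten to "language_model.layers...", and the loader would then look for a top-level language_model attribute that does not exist on the wrapper class.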

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
@@ -778,7 +772,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
self.dtype = vllm_config.model_config.dtype

-        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
+        self.transformers_model = TransformersModel(vllm_config=vllm_config,
+                                                    prefix=prefix)
+        self.model = self.transformers_model.model
text_config = config.get_text_config()

if get_pp_group().is_last_rank:
@@ -801,32 +797,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.lm_head = PPMissingLayer()

self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    @property
-    def hf_to_vllm_mapper(self):
-        # Backwards compatibility for prev released models
-        # State dicts back then had different formats
-        # and cannot be loaded with `AutoModel` mapping
-        # as is
-        prefix_mapper = {
-            "language_model.model": "model.language_model",
-            "text_model.model": "model.text_model",
-            "vision_tower": "model.vision_tower",
-            "vqmodel": "model.vqmodel",
-            "vision_model": "model.vision_model",
-            "vision_embed_tokens": "model.vision_embed_tokens",
-            "image_newline": "model.image_newline",
-            "multi_modal_projector": "model.multi_modal_projector",
-            "text_model.lm_head": "lm_head",
-            "language_model.lm_head": "lm_head",
-        }
-        # Don't change the order for QwenVL
-        if 'Qwen2' in self.config.__class__.__name__:
-            prefix_mapper["model"] = "model.language_model"
-            prefix_mapper["visual"] = "model.visual"
-
-        return WeightsMapper(orig_to_new_prefix=prefix_mapper, )
+            self.transformers_model.make_empty_intermediate_tensors)

def forward(
self,
@@ -846,8 +817,9 @@ def forward(
input_ids, multimodal_embeds)
input_ids = None

-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
+        model_output = self.transformers_model.forward(input_ids, positions,
+                                                       intermediate_tensors,
+                                                       inputs_embeds)
return model_output

def compute_logits(
@@ -896,7 +868,7 @@ def get_multimodal_embeddings(self, **kwargs):
if isinstance(num_image_patches, list):
num_image_patches = torch.cat(num_image_patches)

-        vision_embeddings = self.model.model.get_image_features(
+        vision_embeddings = self.model.get_image_features(
pixel_values,
**{
k: v.flatten(0, 1)
@@ -926,7 +898,7 @@ def get_input_embeddings(
input_ids: torch.Tensor,
multimodal_embeddings=None,
) -> torch.Tensor:
-        inputs_embeds = self.model.model.get_input_embeddings()(input_ids)
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0):
mask = (input_ids == self.config.image_token_id)
2 changes: 1 addition & 1 deletion vllm/test_utils.py
@@ -10,7 +10,7 @@
"allenai/OLMoE-1B-7B-0924-Instruct",
"amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test",
"AMead10/Llama-3.2-1B-Instruct-AWQ",
"ArthurZ/Ilama-3.2-1B",
"hmellor/Ilama-3.2-1B",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-multilingual-gemma2",
"BAAI/bge-reranker-v2-m3",