10 changes: 3 additions & 7 deletions vllm/model_executor/models/interfaces.py
@@ -624,13 +624,9 @@ def __new__(cls, *args, **kwargs) -> Self:
instance.quant_config = quant_config

# apply model mappings to config for proper config-model matching
# NOTE: `TransformersForCausalLM` is not supported due to how this
# class defines `hf_to_vllm_mapper` as a post-init `@property`.
# After this is fixed, get `instance.hf_to_vllm_mapper` directly
if getattr(instance, "hf_to_vllm_mapper", None) is not None:
instance.quant_config.apply_vllm_mapper(
instance.hf_to_vllm_mapper)
if getattr(instance, "packed_modules_mapping", None) is not None:
if (hf_to_vllm_mapper := instance.hf_to_vllm_mapper) is not None:
instance.quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
if instance.packed_modules_mapping is not None:
instance.quant_config.packed_modules_mapping.update(
instance.packed_modules_mapping)

98 changes: 35 additions & 63 deletions vllm/model_executor/models/transformers.py
@@ -412,7 +412,7 @@ def __exit__(self, exc_type, exc_value, traceback):
setattr(self.config, key, value)


class TransformersModel(nn.Module):
class TransformersModel:

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -452,9 +452,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# method after v4.54.0 is released
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"), config_override:
# FIXME(Isotr0py): We need to refactor this part in the future to
# avoid registering an extra model layer, otherwise we will need a
# weights mapper to rename weights.
self.model: PreTrainedModel = AutoModel.from_config(
config,
torch_dtype=model_config.dtype,
@@ -618,9 +615,6 @@ def init_parameters(self, module: nn.Module):
for child in module.children():
self.init_parameters(child)

def get_input_embeddings(self) -> nn.Module:
return self.model.get_input_embeddings()

def forward(
self,
input_ids: Optional[torch.Tensor],
@@ -692,7 +686,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.config = config

self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
self.transformers_model = TransformersModel(vllm_config=vllm_config,
prefix=prefix)
self.model = self.transformers_model.model

if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
@@ -714,22 +710,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.lm_head = PPMissingLayer()

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

# FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
# this makes thing complicated. We need to remove this mapper after refactor
# `TransformersModel` in the future.
# NOTE: `SupportsQuant` can be updated after property decorator is removed
@property
def hf_to_vllm_mapper(self):
prefix_mapper = {
name: "model." + name
for name, _ in self.model.model.named_children()
}
return WeightsMapper(
orig_to_new_substr={"model.": "model.model."},
orig_to_new_prefix=prefix_mapper,
)
self.transformers_model.make_empty_intermediate_tensors)

def forward(
self,
@@ -738,8 +719,9 @@ def forward(
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
model_output = self.transformers_model.forward(input_ids, positions,
intermediate_tensors,
inputs_embeds)
return model_output

def compute_logits(
@@ -753,12 +735,10 @@ def compute_logits(

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
skip_prefixes = ["lm_head."
] if self.config.tie_word_embeddings else None
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights)


@MULTIMODAL_REGISTRY.register_processor(
@@ -770,6 +750,20 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"]

# Backwards compatibility for prev released models
# State dicts back then had different formats
# and cannot be loaded with `AutoModel` mapping
# as is
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "language_model",
"text_model.model": "text_model",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
# deal with Qwen2-VL mapping
"model.layers": "language_model.layers",
})
Contributor

critical

The updated `hf_to_vllm_mapper` for `TransformersForMultimodalLM` appears to be incorrect. The `AutoWeightsLoader` is initialized with `self` (the `TransformersForMultimodalLM` instance), which has `self.model` as an attribute containing the `PreTrainedModel`. Therefore, parameter names within the model are expected to be prefixed with `model.` (e.g., `model.language_model...`).

The current mappings, such as `"language_model.model": "language_model"`, will cause the loader to look for a top-level `language_model` attribute on `TransformersForMultimodalLM`, which doesn't exist. The target prefixes should include `model.` to correctly map to the nested structure.

For example, `language_model.model` from the checkpoint should map to `model.language_model` in the vLLM model.

hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "model.language_model",
            "text_model.model": "model.text_model",
            "text_model.lm_head": "model.lm_head",
            "language_model.lm_head": "model.lm_head",
            # deal with Qwen2-VL mapping
            "model.layers": "model.language_model.layers",
        })
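
To make the prefix handling concrete, here is a small standalone sketch. It only mimics the prefix-substitution behaviour of `WeightsMapper(orig_to_new_prefix=...)` and uses made-up checkpoint key names; it is not vLLM's actual implementation. It shows how the mapping proposed above would rewrite checkpoint keys so that they line up with the `model.`-prefixed module tree that `AutoWeightsLoader` walks.

```python
# Illustrative sketch only: mimics the prefix substitution done by
# `WeightsMapper(orig_to_new_prefix=...)` so the proposed table can be
# sanity-checked outside vLLM. The sample checkpoint keys are made up.
ORIG_TO_NEW_PREFIX = {
    "language_model.model": "model.language_model",
    "text_model.model": "model.text_model",
    "text_model.lm_head": "model.lm_head",
    "language_model.lm_head": "model.lm_head",
    # deal with Qwen2-VL mapping
    "model.layers": "model.language_model.layers",
}


def map_name(name: str) -> str:
    """Rewrite the first matching prefix; entry order decides ties."""
    for old, new in ORIG_TO_NEW_PREFIX.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name


if __name__ == "__main__":
    samples = [
        "language_model.model.layers.0.self_attn.q_proj.weight",
        "model.layers.3.mlp.gate_proj.weight",  # Qwen2-VL style key
    ]
    for key in samples:
        print(f"{key} -> {map_name(key)}")
    # language_model.model.layers.0... -> model.language_model.layers.0...
    # model.layers.3...                -> model.language_model.layers.3...
```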


def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
@@ -778,7 +772,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
self.dtype = vllm_config.model_config.dtype

self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
self.transformers_model = TransformersModel(vllm_config=vllm_config,
prefix=prefix)
self.model = self.transformers_model.model
text_config = config.get_text_config()

if get_pp_group().is_last_rank:
@@ -801,32 +797,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.lm_head = PPMissingLayer()

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

@property
def hf_to_vllm_mapper(self):
# Backwards compatibility for prev released models
# State dicts back then had different formats
# and cannot be loaded with `AutoModel` mapping
# as is
prefix_mapper = {
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
}
# Don't change the order for QwenVL
if 'Qwen2' in self.config.__class__.__name__:
prefix_mapper["model"] = "model.language_model"
prefix_mapper["visual"] = "model.visual"

return WeightsMapper(orig_to_new_prefix=prefix_mapper, )
self.transformers_model.make_empty_intermediate_tensors)

def forward(
self,
@@ -846,8 +817,9 @@ def forward(
input_ids, multimodal_embeds)
input_ids = None

model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
model_output = self.transformers_model.forward(input_ids, positions,
intermediate_tensors,
inputs_embeds)
return model_output

def compute_logits(
@@ -896,7 +868,7 @@ def get_multimodal_embeddings(self, **kwargs):
if isinstance(num_image_patches, list):
num_image_patches = torch.cat(num_image_patches)

vision_embeddings = self.model.model.get_image_features(
vision_embeddings = self.model.get_image_features(
pixel_values,
**{
k: v.flatten(0, 1)
@@ -926,7 +898,7 @@ def get_input_embeddings(
input_ids: torch.Tensor,
multimodal_embeddings=None,
) -> torch.Tensor:
inputs_embeds = self.model.model.get_input_embeddings()(input_ids)
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0):
mask = (input_ids == self.config.image_token_id)