@@ -15,7 +15,7 @@
 from transformers.utils import logging
 
 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.inputs.data import TokenInputs, token_inputs
@@ -34,7 +34,7 @@
 
 from .interfaces import SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 from .vision_siglip_navit import get_siglip_vision_model
 
 # <|endoftext10|> (see vocab.json in hf model)
@@ -352,12 +352,6 @@ def __init__(self,
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
-        if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
-            embd_drop = config.embd_pdrop if hasattr(
-                config, 'embd_pdrop') else config.embed_pdrop
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None
 
         # layer_idx to output the img features
         if isinstance(config.img_processor, dict):
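
The deleted dropout branch was dead code in a serving engine: vLLM runs modules in eval mode, where `nn.Dropout` is the identity, so `self.drop` only added config sniffing and an unused submodule. A quick standalone sanity check of that property (not vLLM code):

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.1)
drop.eval()  # inference engines run modules in eval mode

x = torch.ones(3)
# In eval mode, Dropout neither zeroes nor rescales activations.
assert torch.equal(drop(x), x)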
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         ],
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            "base_layer.": "",
+        },
+        orig_to_new_prefix={
+            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
+            "embed_tokens_extend.audio_projection_for_vision.",
+            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
+            "embed_tokens_extend.audio_projection.",
+            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
+            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
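
The new `hf_to_vllm_mapper` expresses as data what the removed `load_weights` (further down) did with chained `str.replace` calls. Below is a toy sketch of the rewrite such a table drives; it is an assumed model of `WeightsMapper`'s observable behavior (substring replacement, then first-matching prefix), not its real implementation, and the checkpoint key is hypothetical:

def map_name(name, substr, prefix):
    # Substring rules strip, e.g., the LoRA "base_layer." segment anywhere.
    for old, new in substr.items():
        name = name.replace(old, new)
    # Prefix rules relocate whole subtrees of the parameter tree.
    for old, new in prefix.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

print(map_name(
    "model.embed_tokens_extend.image_embed.img_projection.0.base_layer.weight",
    substr={"base_layer.": ""},
    prefix={"model.embed_tokens_extend.image_embed.": "vision_encoder."},
))
# -> vision_encoder.img_projection.0.weight

Listing the two `audio_projection.*` prefixes ahead of the bare `model.embed_tokens_extend.audio_embed.` prefix preserves the precedence the old code enforced with an explicit `startswith` check.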
@@ -1445,8 +1453,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.lora_config = lora_config
 
-        # Tensor/Pipeline parallel not supported for now.
-        assert get_tensor_model_parallel_world_size(
-        ) == 1, "tensor parallel is not supported"
+        # Pipeline parallel not supported for now.
         assert get_pp_group(
         ).world_size == 1, "pipeline parallel is not supported"
 
@@ -1686,44 +1692,6 @@ def merge_image_features_to_inputs_embeds(
         )
         return merged_embeds
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
-        weights = {name: weight for name, weight in weights}
-        adjusted_weights = {}
-
-        for name, weight in weights.items():
-            # NOTE vision-speech tasks use a separate projection layer
-            audio_proj_4v = \
-                "model.embed_tokens_extend.audio_embed.audio_projection.vision"
-            if name.startswith(audio_proj_4v):
-                name = name.replace(
-                    audio_proj_4v,
-                    "embed_tokens_extend.audio_projection_for_vision")
-
-            name = (name.replace(
-                "model.embed_tokens_extend.audio_embed." \
-                "audio_projection.speech.",
-                "embed_tokens_extend.audio_projection.",
-            ).replace(
-                "model.embed_tokens_extend.audio_embed.",
-                "embed_tokens_extend.",
-            ).replace("model.embed_tokens_extend.image_embed.",
-                      "vision_encoder."))
-            # NOTE: this is deal with LoRA injection, where `base_layer`
-            # remains as the original layer in the model
-            if name.endswith(".base_layer.weight"):
-                name = name.replace(".base_layer.weight", ".weight")
-            adjusted_weights[name] = weight
-
-        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
-                                                             strict=False)
-        logger.debug("*** missing keys:")
-        for key in missing_keys:
-            logger.debug(key)
-        logger.debug("**** unexpected keys:")
-        for key in unexpected_keys:
-            logger.debug(key)
-
     def forward(
         self,
         input_ids: torch.Tensor,
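
Besides the duplication, the removed method had two practical weaknesses: it buffered the whole checkpoint in a dict (an extra host-side copy of every tensor), and `load_state_dict(strict=False)` demoted any mapping mistake to a debug log. A minimal sketch of that silent-failure mode with a plain module:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
# A mis-mapped key is quietly reported, never raised, with strict=False:
missing, unexpected = model.load_state_dict(
    {"wieght": torch.zeros(4, 4)}, strict=False)
print(missing)     # ['weight', 'bias']
print(unexpected)  # ['wieght']

The `AutoWeightsLoader` path added below resolves each mapped name against the module tree instead, so a broken mapping surfaces as a load-time error rather than as silently missing weights.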
@@ -1796,6 +1764,13 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> None:
+        weights = ((name, data) for name, data in weights
+                   if "lora" not in name)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
     def get_mm_mapping(self) -> MultiModelKeys:
         """
         Get the module prefix in multimodal models
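
The generator in the new `load_weights` filters LoRA adapter tensors out of the stream lazily; vLLM applies LoRA through its own adapter machinery, so only base weights should reach the loader. A standalone illustration with hypothetical checkpoint names:

import torch

# Hypothetical (name, tensor) pairs, as a weight iterator would yield them.
ckpt = [
    ("model.layers.0.mlp.gate_up_proj.base_layer.weight", torch.zeros(1)),
    ("model.layers.0.mlp.gate_up_proj.lora_A.speech.weight", torch.zeros(1)),
    ("model.embed_tokens_extend.image_embed.glb_GN", torch.zeros(1)),
]

# Same lazy filter as load_weights: LoRA tensors are dropped without
# ever being materialized into an intermediate dict.
filtered = ((name, data) for name, data in ckpt if "lora" not in name)
print([name for name, _ in filtered])
# ['model.layers.0.mlp.gate_up_proj.base_layer.weight',
#  'model.embed_tokens_extend.image_embed.glb_GN']

Note the filter matches the substring "lora" anywhere in the name, which is safe only as long as no base parameter's path contains it, a condition this checkpoint layout appears to satisfy.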
@@ -1804,4 +1779,4 @@ def get_mm_mapping(self) -> MultiModelKeys:
             language_model="model.",
             connector=["audio_projection_for_vision", "audio_projection"],
             tower_model=["vision_encoder", "embed_tokens_extend"],
-        )
+        )
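
`get_mm_mapping` feeds vLLM's multimodal LoRA plumbing, which needs to know which modules form the language model, the connector, and the vision/audio towers. A rough sketch of name classification in that spirit, assumed logic for illustration rather than vLLM's actual routine:

def classify(name, groups):
    # First group whose pattern list matches wins, so connector patterns
    # are listed before the tower patterns that would also match them.
    for group, patterns in groups.items():
        if any(p in name for p in patterns):
            return group
    return "unclassified"

groups = {
    "connector": ["audio_projection_for_vision", "audio_projection"],
    "tower_model": ["vision_encoder", "embed_tokens_extend"],
    "language_model": ["model."],
}
print(classify("vision_encoder.blocks.0.attn.qkv.weight", groups))
# tower_model
print(classify("embed_tokens_extend.audio_projection.linear.weight", groups))
# connector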