@@ -414,7 +414,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             setattr(self.config, key, value)
 
 
-class TransformersModel(nn.Module):
+class TransformersModel:
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -454,9 +454,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # method after v4.54.0 is released
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"), config_override:
-            # FIXME(Isotr0py): We need to refactor this part in the future to
-            # avoid registering an extra model layer, otherwise we will need a
-            # weights mapper to rename weights.
             self.model: PreTrainedModel = AutoModel.from_config(
                 config,
                 torch_dtype=model_config.dtype,
@@ -620,9 +617,6 @@ def init_parameters(self, module: nn.Module):
         for child in module.children():
             self.init_parameters(child)
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.model.get_input_embeddings()
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -694,7 +688,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.config = config
 
-        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
+        self.transformers_model = TransformersModel(vllm_config=vllm_config,
+                                                    prefix=prefix)
+        self.model = self.transformers_model.model
 
         if get_pp_group().is_last_rank:
             self.unpadded_vocab_size = config.vocab_size
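
Note: the composition above (a plain wrapper object whose inner HF model is re-registered as self.model) is what makes the weight-name mapper removed later in this diff unnecessary. A minimal standalone sketch of the pattern, using placeholder classes rather than the real vLLM/Transformers ones:

import torch.nn as nn


class Wrapper:  # stands in for TransformersModel, which is no longer an nn.Module
    def __init__(self):
        # placeholder for AutoModel.from_config(...)
        self.model = nn.Linear(4, 4)


class CausalLM(nn.Module):  # stands in for TransformersForCausalLM
    def __init__(self):
        super().__init__()
        self.transformers_model = Wrapper()          # plain attribute, not a submodule
        self.model = self.transformers_model.model   # HF model registered directly as "model"


print([name for name, _ in CausalLM().named_parameters()])
# -> ['model.weight', 'model.bias']  (no nested "model.model." prefix to remap)
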
@@ -716,22 +712,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = PPMissingLayer()
 
         self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
-    # this makes thing complicated. We need to remove this mapper after refactor
-    # `TransformersModel` in the future.
-    # NOTE: `SupportsQuant` can be updated after property decorator is removed
-    @property
-    def hf_to_vllm_mapper(self):
-        prefix_mapper = {
-            name: "model." + name
-            for name, _ in self.model.model.named_children()
-        }
-        return WeightsMapper(
-            orig_to_new_substr={"model.": "model.model."},
-            orig_to_new_prefix=prefix_mapper,
-        )
+            self.transformers_model.make_empty_intermediate_tensors)
 
     def forward(
         self,
@@ -740,8 +721,9 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
+        model_output = self.transformers_model.forward(input_ids, positions,
+                                                       intermediate_tensors,
+                                                       inputs_embeds)
         return model_output
 
     def compute_logits(
@@ -755,12 +737,10 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+        skip_prefixes = ["lm_head."
+                         ] if self.config.tie_word_embeddings else None
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        return loader.load_weights(weights)
 
 
 @MULTIMODAL_REGISTRY.register_processor(
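
Note: the simplified load_weights keeps only the skip_prefixes behaviour. A hedged sketch of what that skip amounts to, using a hypothetical filter_weights helper rather than the actual AutoWeightsLoader internals:

from collections.abc import Iterable

import torch


def filter_weights(weights: Iterable[tuple[str, torch.Tensor]],
                   tie_word_embeddings: bool):
    # With tied embeddings, lm_head reuses embed_tokens, so checkpoint
    # tensors under "lm_head." are skipped rather than loaded.
    skip_prefixes = ["lm_head."] if tie_word_embeddings else None
    for name, tensor in weights:
        if skip_prefixes and any(name.startswith(p) for p in skip_prefixes):
            continue
        yield name, tensor
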
@@ -772,6 +752,29 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"]
 
+    # Backwards compatibility for prev released models. State dicts back then
+    # had different formats and cannot be loaded with `AutoModel` mapping as is
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "language_model.model": "model.language_model",
+            "text_model.model": "model.text_model",
+            "vision_tower": "model.vision_tower",
+            "vqmodel": "model.vqmodel",
+            "visual": "model.visual",
+            "vision_model": "model.vision_model",
+            "vision_embed_tokens": "model.vision_embed_tokens",
+            "image_newline": "model.image_newline",
+            "multi_modal_projector": "model.multi_modal_projector",
+            "text_model.lm_head": "lm_head",
+            "language_model.lm_head": "lm_head",
+            # Qwen models used "model" as the name for the language model.
+            # Therefore, we must map each of submodule explicitly to avoid
+            # conflicts with newer models that use "model.language_model".
+            "model.embed_tokens": "model.language_model.embed_tokens",
+            "model.layers": "model.language_model.layers",
+            "model.norm": "model.language_model.norm",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: PretrainedConfig = vllm_config.model_config.hf_config
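
Note: to see what the class-level hf_to_vllm_mapper above is expected to do to older checkpoints, here is a hypothetical standalone rename helper. It assumes prefix rewriting on the first matching entry, and the example keys are illustrative only:

ORIG_TO_NEW_PREFIX = {
    "language_model.model": "model.language_model",
    "visual": "model.visual",
    "model.embed_tokens": "model.language_model.embed_tokens",
    "model.layers": "model.language_model.layers",
    "model.norm": "model.language_model.norm",
}


def rename(key: str) -> str:
    for old, new in ORIG_TO_NEW_PREFIX.items():
        if key.startswith(old):
            return new + key[len(old):]
    return key


# Old-style keys land on the new "model.language_model" / "model.visual" layout.
assert (rename("model.layers.0.self_attn.q_proj.weight") ==
        "model.language_model.layers.0.self_attn.q_proj.weight")
assert (rename("visual.patch_embed.proj.weight") ==
        "model.visual.patch_embed.proj.weight")
assert rename("lm_head.weight") == "lm_head.weight"  # unmapped keys pass through
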
@@ -780,7 +783,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.dtype = vllm_config.model_config.dtype
 
-        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
+        self.transformers_model = TransformersModel(vllm_config=vllm_config,
+                                                    prefix=prefix)
+        self.model = self.transformers_model.model
         text_config = config.get_text_config()
 
         if get_pp_group().is_last_rank:
@@ -803,32 +808,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = PPMissingLayer()
 
         self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    @property
-    def hf_to_vllm_mapper(self):
-        # Backwards compatibility for prev released models
-        # State dicts back then had different formats
-        # and cannot be loaded with `AutoModel` mapping
-        # as is
-        prefix_mapper = {
-            "language_model.model": "model.language_model",
-            "text_model.model": "model.text_model",
-            "vision_tower": "model.vision_tower",
-            "vqmodel": "model.vqmodel",
-            "vision_model": "model.vision_model",
-            "vision_embed_tokens": "model.vision_embed_tokens",
-            "image_newline": "model.image_newline",
-            "multi_modal_projector": "model.multi_modal_projector",
-            "text_model.lm_head": "lm_head",
-            "language_model.lm_head": "lm_head",
-        }
-        # Don't change the order for QwenVL
-        if 'Qwen2' in self.config.__class__.__name__:
-            prefix_mapper["model"] = "model.language_model"
-            prefix_mapper["visual"] = "model.visual"
-
-        return WeightsMapper(orig_to_new_prefix=prefix_mapper, )
+            self.transformers_model.make_empty_intermediate_tensors)
 
     def forward(
         self,
@@ -848,8 +828,9 @@ def forward(
                 input_ids, multimodal_embeds)
             input_ids = None
 
-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
+        model_output = self.transformers_model.forward(input_ids, positions,
+                                                       intermediate_tensors,
+                                                       inputs_embeds)
         return model_output
 
     def compute_logits(
@@ -898,7 +879,7 @@ def get_multimodal_embeddings(self, **kwargs):
             if isinstance(num_image_patches, list):
                 num_image_patches = torch.cat(num_image_patches)
 
-            vision_embeddings = self.model.model.get_image_features(
+            vision_embeddings = self.model.get_image_features(
                 pixel_values,
                 **{
                     k: v.flatten(0, 1)
@@ -928,7 +909,7 @@ def get_input_embeddings(
         input_ids: torch.Tensor,
         multimodal_embeddings=None,
     ) -> torch.Tensor:
-        inputs_embeds = self.model.model.get_input_embeddings()(input_ids)
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
         if (multimodal_embeddings is not None
                 and len(multimodal_embeddings) != 0):
             mask = (input_ids == self.config.image_token_id)
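
Note: the mask built at the end of this hunk drives the usual placeholder-replacement step (the statement that consumes it lies outside the diff). A hedged sketch of that merge, with made-up sizes and token id:

import torch

image_token_id, hidden = 32000, 8  # illustrative values only
input_ids = torch.tensor([1, image_token_id, image_token_id, 2])
inputs_embeds = torch.zeros(4, hidden)          # text embeddings
multimodal_embeddings = torch.randn(2, hidden)  # one row per image placeholder

mask = input_ids == image_token_id
inputs_embeds[mask] = multimodal_embeddings.to(inputs_embeds.dtype)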