Simplify weight loading in Transformers backend #21382
```diff
@@ -412,7 +412,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             setattr(self.config, key, value)


-class TransformersModel(nn.Module):
+class TransformersModel:

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
```
```diff
@@ -452,9 +452,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # method after v4.54.0 is released
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"), config_override:
-            # FIXME(Isotr0py): We need to refactor this part in the future to
-            # avoid registering an extra model layer, otherwise we will need a
-            # weights mapper to rename weights.
             self.model: PreTrainedModel = AutoModel.from_config(
                 config,
                 torch_dtype=model_config.dtype,
```
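Only the FIXME disappears here; the model skeleton is still built under `init_on_device_without_buffers("meta")`, so no real memory is allocated before checkpoint weights arrive. A rough illustration of the underlying idea, using plain PyTorch stand-ins rather than vLLM's helper:

```python
# Minimal sketch of meta-device construction (assumption: plain PyTorch,
# not vLLM's init_on_device_without_buffers).
import torch
import torch.nn as nn

with torch.device("meta"):
    skeleton = nn.Linear(4096, 4096)      # parameters exist but own no storage

print(skeleton.weight.device)             # meta
model = skeleton.to_empty(device="cpu")   # allocate real, uninitialized storage
with torch.no_grad():
    # stand-in for copying in real checkpoint weights
    model.weight.copy_(torch.zeros(4096, 4096))
    model.bias.copy_(torch.zeros(4096))
```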
```diff
@@ -618,9 +615,6 @@ def init_parameters(self, module: nn.Module):
         for child in module.children():
             self.init_parameters(child)

-    def get_input_embeddings(self) -> nn.Module:
-        return self.model.get_input_embeddings()
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
```
```diff
@@ -692,7 +686,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config

-        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
+        self.transformers_model = TransformersModel(vllm_config=vllm_config,
+                                                    prefix=prefix)
+        self.model = self.transformers_model.model

         if get_pp_group().is_last_rank:
             self.unpadded_vocab_size = config.vocab_size
```
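Because `TransformersModel` no longer subclasses `nn.Module` (first hunk), assigning it to `self.transformers_model` does not register it as a submodule; only the inner Hugging Face model, re-exposed as `self.model`, contributes to the parameter namespace. That removes the extra `model.` level that previously forced a rename mapper. A toy sketch of the PyTorch naming behaviour this relies on, with simplified stand-in classes rather than the real ones:

```python
# Toy illustration of how submodule attribute names become state-dict prefixes.
import torch.nn as nn

class HFModel(nn.Module):                 # stand-in for the AutoModel instance
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8)])

class OldWrapper(nn.Module):              # before: the wrapper was itself an nn.Module
    def __init__(self):
        super().__init__()
        self.model = HFModel()

class OldCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = OldWrapper()         # parameters named "model.model.layers..."

class NewCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = HFModel()            # parameters named "model.layers..."

print(next(OldCausalLM().named_parameters())[0])  # model.model.layers.0.weight
print(next(NewCausalLM().named_parameters())[0])  # model.layers.0.weight
```

With the flat layout, checkpoint keys such as `model.layers.0.weight` already match the vLLM parameter names, so no mapper is needed for text-only models.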
```diff
@@ -714,22 +710,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = PPMissingLayer()

         self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
-    # this makes thing complicated. We need to remove this mapper after refactor
-    # `TransformersModel` in the future.
-    # NOTE: `SupportsQuant` can be updated after property decorator is removed
-    @property
-    def hf_to_vllm_mapper(self):
-        prefix_mapper = {
-            name: "model." + name
-            for name, _ in self.model.model.named_children()
-        }
-        return WeightsMapper(
-            orig_to_new_substr={"model.": "model.model."},
-            orig_to_new_prefix=prefix_mapper,
-        )
+            self.transformers_model.make_empty_intermediate_tensors)

     def forward(
         self,
```
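The deleted property rebuilt a `WeightsMapper` on every access, prefixing each child of the inner HF model and substituting `model.` with `model.model.` so checkpoint keys lined up with the doubly nested layout; with the flattened layout it has no purpose. A rough sketch of the net renaming it used to perform, written in plain Python with hypothetical keys rather than the real `WeightsMapper`:

```python
# Approximation of the key renaming the removed property effectively did.
def old_style_rename(key: str) -> str:
    # checkpoints store "model.layers...."; the doubly nested module expected
    # "model.model.layers...."
    if key.startswith("model."):
        return "model.model." + key[len("model."):]
    return key

print(old_style_rename("model.layers.0.self_attn.q_proj.weight"))
# model.model.layers.0.self_attn.q_proj.weight
print(old_style_rename("lm_head.weight"))
# lm_head.weight (unchanged)
```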
```diff
@@ -738,8 +719,9 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
+        model_output = self.transformers_model.forward(input_ids, positions,
+                                                       intermediate_tensors,
+                                                       inputs_embeds)
         return model_output

     def compute_logits(
```
```diff
@@ -753,12 +735,10 @@ def compute_logits(
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+        skip_prefixes = ["lm_head."
+                         ] if self.config.tie_word_embeddings else None
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        return loader.load_weights(weights)


 @MULTIMODAL_REGISTRY.register_processor(
```
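The rewritten `load_weights` keeps the tied-embedding behaviour: when `tie_word_embeddings` is set, checkpoint entries under `lm_head.` are skipped because the head reuses the input-embedding weight. A minimal sketch of that prefix filtering, using a hypothetical `filter_skipped` helper rather than vLLM's `AutoWeightsLoader`:

```python
# Hypothetical helper showing the skip_prefixes idea.
from collections.abc import Iterable
from typing import Optional
import torch

def filter_skipped(
    weights: Iterable[tuple[str, torch.Tensor]],
    skip_prefixes: Optional[list[str]],
) -> Iterable[tuple[str, torch.Tensor]]:
    """Yield only checkpoint entries whose names don't start with a skipped prefix."""
    for name, tensor in weights:
        if skip_prefixes and any(name.startswith(p) for p in skip_prefixes):
            continue
        yield name, tensor

# With tie_word_embeddings=True the "lm_head." entry is dropped and the head
# later reuses the embedding weight instead.
ckpt = [("model.embed_tokens.weight", torch.zeros(4, 2)),
        ("lm_head.weight", torch.zeros(4, 2))]
print([n for n, _ in filter_skipped(ckpt, ["lm_head."])])
# ['model.embed_tokens.weight']
```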
```diff
@@ -770,6 +750,20 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"]

+    # Backwards compatibility for prev released models
+    # State dicts back then had different formats
+    # and cannot be loaded with `AutoModel` mapping
+    # as is
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "language_model.model": "language_model",
+            "text_model.model": "text_model",
+            "text_model.lm_head": "lm_head",
+            "language_model.lm_head": "lm_head",
+            # deal with Qwen2-VL mapping
+            "model.layers": "language_model.layers",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: PretrainedConfig = vllm_config.model_config.hf_config
```

Contributor (commenting on the added `hf_to_vllm_mapper`): The updated mapper doesn't match the new module layout. The current mappings, such as `"language_model.model": "language_model"`, should map onto the `model.`-prefixed names. For example:

```python
hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "language_model.model": "model.language_model",
        "text_model.model": "model.text_model",
        "text_model.lm_head": "model.lm_head",
        "language_model.lm_head": "model.lm_head",
        # deal with Qwen2-VL mapping
        "model.layers": "model.language_model.layers",
    })
```
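To make the prefix question concrete: an `orig_to_new_prefix` mapping rewrites the leading segment of each checkpoint key before loading, so the choice of target prefix decides which parameter a weight lands on. Below is a rough stand-alone sketch of that renaming, using a plain dict-based helper and hypothetical keys (the real logic lives in vLLM's `WeightsMapper`), with the reviewer's `model.`-prefixed targets:

```python
# Hypothetical prefix renaming, mimicking orig_to_new_prefix semantics:
# the first matching prefix (in insertion order) is replaced.
def remap(key: str, orig_to_new_prefix: dict[str, str]) -> str:
    for old, new in orig_to_new_prefix.items():
        if key.startswith(old):
            return new + key[len(old):]
    return key

suggested = {
    "language_model.model": "model.language_model",
    "model.layers": "model.language_model.layers",  # legacy Qwen2-VL style keys
}
print(remap("language_model.model.layers.0.mlp.up_proj.weight", suggested))
# model.language_model.layers.0.mlp.up_proj.weight
print(remap("model.layers.0.mlp.up_proj.weight", suggested))
# model.language_model.layers.0.mlp.up_proj.weight
```

Under this sketch, both legacy key formats converge on the same `model.language_model.*` parameter name.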
```diff
@@ -778,7 +772,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.dtype = vllm_config.model_config.dtype

-        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
+        self.transformers_model = TransformersModel(vllm_config=vllm_config,
+                                                    prefix=prefix)
+        self.model = self.transformers_model.model
         text_config = config.get_text_config()

         if get_pp_group().is_last_rank:
```
```diff
@@ -801,32 +797,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = PPMissingLayer()

         self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    @property
-    def hf_to_vllm_mapper(self):
-        # Backwards compatibility for prev released models
-        # State dicts back then had different formats
-        # and cannot be loaded with `AutoModel` mapping
-        # as is
-        prefix_mapper = {
-            "language_model.model": "model.language_model",
-            "text_model.model": "model.text_model",
-            "vision_tower": "model.vision_tower",
-            "vqmodel": "model.vqmodel",
-            "vision_model": "model.vision_model",
-            "vision_embed_tokens": "model.vision_embed_tokens",
-            "image_newline": "model.image_newline",
-            "multi_modal_projector": "model.multi_modal_projector",
-            "text_model.lm_head": "lm_head",
-            "language_model.lm_head": "lm_head",
-        }
-        # Don't change the order for QwenVL
-        if 'Qwen2' in self.config.__class__.__name__:
-            prefix_mapper["model"] = "model.language_model"
-            prefix_mapper["visual"] = "model.visual"
-
-        return WeightsMapper(orig_to_new_prefix=prefix_mapper, )
+            self.transformers_model.make_empty_intermediate_tensors)

     def forward(
         self,
```
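The property removed above also carried a "Don't change the order for QwenVL" warning, which points at a general pitfall of prefix maps: when one mapped prefix is a leading substring of another, whichever entry is tried first wins. A toy illustration of that ordering sensitivity, assuming first-match semantics and a hypothetical broad `language_model` entry that is not in the real mapper:

```python
# Toy example: ordering matters when prefixes overlap.
def remap(key: str, prefix_map: dict[str, str]) -> str:
    for old, new in prefix_map.items():
        if key.startswith(old):
            return new + key[len(old):]
    return key

good = {"language_model.lm_head": "lm_head",              # specific entry first
        "language_model.model": "language_model"}
bad = {"language_model": "model.language_model",          # hypothetical broad entry first
       "language_model.lm_head": "lm_head"}

key = "language_model.lm_head.weight"
print(remap(key, good))  # lm_head.weight
print(remap(key, bad))   # model.language_model.lm_head.weight (wrong target)
```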
```diff
@@ -846,8 +817,9 @@ def forward(
                 input_ids, multimodal_embeds)
             input_ids = None

-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
+        model_output = self.transformers_model.forward(input_ids, positions,
+                                                       intermediate_tensors,
+                                                       inputs_embeds)
         return model_output

     def compute_logits(
```
```diff
@@ -896,7 +868,7 @@ def get_multimodal_embeddings(self, **kwargs):
         if isinstance(num_image_patches, list):
             num_image_patches = torch.cat(num_image_patches)

-        vision_embeddings = self.model.model.get_image_features(
+        vision_embeddings = self.model.get_image_features(
             pixel_values,
             **{
                 k: v.flatten(0, 1)
```
```diff
@@ -926,7 +898,7 @@ def get_input_embeddings(
         input_ids: torch.Tensor,
         multimodal_embeddings=None,
     ) -> torch.Tensor:
-        inputs_embeds = self.model.model.get_input_embeddings()(input_ids)
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
         if (multimodal_embeddings is not None
                 and len(multimodal_embeddings) != 0):
             mask = (input_ids == self.config.image_token_id)
```