From 4f43efb8f5ead22c9cf7fcaaa1c519d4d7da3c99 Mon Sep 17 00:00:00 2001 From: Julien Denize Date: Fri, 31 Oct 2025 15:56:56 +0100 Subject: [PATCH 1/4] Add llama 4 scaling --- src/transformers/integrations/mistral.py | 12 +-- .../models/mistral/configuration_mistral.py | 21 +++- .../mistral/convert_mistral_weights_to_hf.py | 34 +++++- .../models/mistral/modular_mistral.py | 18 ++++ .../models/mistral3/configuration_mistral3.py | 2 +- .../convert_mistral3_weights_to_hf.py | 100 +++++++++++++----- 6 files changed, 146 insertions(+), 41 deletions(-) diff --git a/src/transformers/integrations/mistral.py b/src/transformers/integrations/mistral.py index cdf237645fc1..0d9c1cb5f69a 100644 --- a/src/transformers/integrations/mistral.py +++ b/src/transformers/integrations/mistral.py @@ -84,20 +84,18 @@ def convert_tekken_tokenizer(tokenizer_file: str): # Extract vocab and special tokens vocab = mistral_tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial - all_special = [ - token.value if hasattr(token, "value") else token - for token in mistral_tokenizer.instruct_tokenizer.tokenizer._all_special_tokens - ] - specials_tokens = {token: all_special.index(token) for token in all_special} + sorted_tokens = sorted(mistral_tokenizer.instruct_tokenizer.tokenizer._all_special_tokens, key=lambda x: x["rank"]) + all_specials = [token["token_str"] for token in sorted_tokens] + specials_tokens = {token: idx for idx, token in enumerate(all_specials)} specials_tokens.update(vocab) vocab = specials_tokens # Convert tokenizer = LlamaTokenizerFast( - tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), + tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_specials).converted(), ) # Post-process - tokenizer.add_special_tokens({"additional_special_tokens": all_special}) + tokenizer.add_special_tokens({"additional_special_tokens": all_specials}) return tokenizer diff --git a/src/transformers/models/mistral/configuration_mistral.py b/src/transformers/models/mistral/configuration_mistral.py index 0fac55d26e2a..5e15619996e6 100644 --- a/src/transformers/models/mistral/configuration_mistral.py +++ b/src/transformers/models/mistral/configuration_mistral.py @@ -14,7 +14,7 @@ # limitations under the License. """Mistral model configuration""" -from typing import Optional +from typing import Optional, TypedDict from ...configuration_utils import PreTrainedConfig from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params @@ -24,6 +24,21 @@ logger = logging.get_logger(__name__) +class LLama4Scaling(TypedDict): + r""" + The scaling parameters to apply LLama 4 scaling to the rope embeddings. + + Args: + original_max_position_embeddings (`int`): + The original max position embeddings used during pretraining. + scaling_beta (`float`): + The scaling beta parameter. + """ + + original_max_position_embeddings: int + scaling_beta: float + + class MistralConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an @@ -82,6 +97,8 @@ class MistralConfig(PreTrainedConfig): Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE with longer `max_position_embeddings`. + llama_4_scaling (`LLama4Scaling`, *optional*): + Dictionary containing the scaling parameters for LLama 4 scaling. 
sliding_window (`int`, *optional*, defaults to 4096): Sliding window attention window size. If not specified, will default to `4096`. attention_dropout (`float`, *optional*, defaults to 0.0): @@ -137,6 +154,7 @@ def __init__( eos_token_id: Optional[int] = 2, tie_word_embeddings: Optional[bool] = False, rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, + llama_4_scaling: Optional[LLama4Scaling] = None, sliding_window: Optional[int] = 4096, attention_dropout: Optional[float] = 0.0, **kwargs, @@ -160,6 +178,7 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.attention_dropout = attention_dropout + self.llama_4_scaling = llama_4_scaling if "layer_types" in kwargs: logger.warning_once( diff --git a/src/transformers/models/mistral/convert_mistral_weights_to_hf.py b/src/transformers/models/mistral/convert_mistral_weights_to_hf.py index a790fed81d1b..6cc2cfdac15a 100644 --- a/src/transformers/models/mistral/convert_mistral_weights_to_hf.py +++ b/src/transformers/models/mistral/convert_mistral_weights_to_hf.py @@ -166,6 +166,7 @@ def convert_config(original_config: dict, max_position_embeddings: int = 32768): "intermediate_size": "hidden_dim", "num_attention_heads": "n_heads", "rms_norm_eps": "norm_eps", + "tie_word_embeddings": "tie_embeddings", } similar_keys_to_keep = [ "head_dim", @@ -180,13 +181,33 @@ def convert_config(original_config: dict, max_position_embeddings: int = 32768): new_config_kwargs["num_key_value_heads"] = original_config.get( "n_kv_heads", new_config_kwargs["num_attention_heads"] ) - new_config_kwargs["rope_theta"] = original_config.get("rope_theta", 10000.0) - new_config_kwargs["max_position_embeddings"] = original_config.get("max_seq_len", max_position_embeddings) + new_config_kwargs["max_position_embeddings"] = original_config.get( + "max_position_embeddings", max_position_embeddings + ) + + if original_config.get("yarn"): + new_config_kwargs["rope_parameters"] = { + "type": "yarn", + "rope_theta": original_config.get("rope_theta", 10000.0), + "factor": original_config["yarn"]["factor"], + "original_max_position_embeddings": original_config["yarn"]["original_max_position_embeddings"], + "beta_fast": original_config["yarn"]["beta"], + "beta_slow": original_config["yarn"]["alpha"], + "mscale_all_dim": 1, + } + + if original_config.get("llama_4_scaling"): + new_config_kwargs["mscale"] = 1 + else: + new_config_kwargs["rope_theta"] = original_config.get("rope_theta", 10000.0) # This may sometimes be a string in `params.json` if new_config_kwargs["sliding_window"] is not None: new_config_kwargs["sliding_window"] = int(new_config_kwargs["sliding_window"]) + if original_config.get("llama_4_scaling"): + new_config_kwargs["llama_4_scaling"] = original_config["llama_4_scaling"] + new_config = MistralConfig(**new_config_kwargs) return new_config @@ -226,6 +247,7 @@ def convert_and_write_tokenizer(input_dir: str, output_dir: str, tokenizer_templ if "tekken.json" in os.listdir(input_dir): tokenizer_file = os.path.join(input_dir, "tekken.json") tokenizer = convert_tekken_tokenizer(tokenizer_file) + tokenizer.add_special_tokens({"pad_token": "<pad>"}) else: # May have .v3 or .v7 at the end tokenizer_file = [file for file in os.listdir(input_dir) if "tokenizer.model" in file][0] @@ -272,12 +294,18 @@ def main(): action="store_true", help="If passed, will only convert the tokenizer.", ) + parser.add_argument( + "--model_only", + action="store_true", + help="If passed, will only convert the model.", + ) args = parser.parse_args() if
not args.tokenizer_only: convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings, args.modules_are_split) - convert_and_write_tokenizer(args.input_dir, args.output_dir, args.template_name) + if not args.model_only: + convert_and_write_tokenizer(args.input_dir, args.output_dir, args.template_name) if __name__ == "__main__": diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 709ff855c399..889a8c638b81 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -46,11 +46,26 @@ class MistralAttention(LlamaAttention): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__(config, layer_idx) self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + if self.config.llama_4_scaling is not None: + self.llama_4_scaling_beta = self.config.llama_4_scaling.scaling_beta + self.llama_4_scaling_original_max_position_embeddings = ( + self.config.llama_4_scaling.original_max_position_embeddings + ) + else: + self.llama_4_scaling_beta = None + self.llama_4_scaling_original_max_position_embeddings = None + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + def _get_llama4_attn_scale(self, positions_ids: torch.Tensor) -> torch.Tensor: + scaling = 1 + self.llama_4_scaling_beta * torch.log( + 1 + torch.floor(positions_ids / self.llama_4_scaling_original_max_position_embeddings) + ) + return scaling.unsqueeze(-1) + def forward( self, hidden_states: torch.Tensor, @@ -70,6 +85,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if self.llama_4_scaling_beta is not None: + query_states = query_states * self._get_llama4_attn_scale(cache_position).to(query_states.dtype) + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/mistral3/configuration_mistral3.py b/src/transformers/models/mistral3/configuration_mistral3.py index 59851c135987..03c5f51fd821 100644 --- a/src/transformers/models/mistral3/configuration_mistral3.py +++ b/src/transformers/models/mistral3/configuration_mistral3.py @@ -103,7 +103,7 @@ def __init__( num_attention_heads=16, vocab_size=32000, head_dim=64, - hidden_act="gelu", + hidden_act="silu", ) self.vision_config = vision_config diff --git a/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py b/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py index c8f9b64ab1f6..1013a92b2a60 100644 --- a/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py +++ b/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py @@ -20,6 +20,7 @@ from safetensors.torch import load_file from transformers import ( + GenerationConfig, Mistral3Config, Mistral3ForConditionalGeneration, MistralConfig, @@ -33,29 +34,29 @@ # fmt: off STATE_DICT_MAPPING = { # Text model keys - r"^output.weight": r"language_model.lm_head.weight", - r"^norm.weight": 
r"language_model.model.norm.weight", - r"^tok_embeddings.weight": r"language_model.model.embed_tokens.weight", - r"^layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight", - r"^layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight", - r"^layers.(\d+).attention.w(q|k|v|o).weight": r"language_model.model.layers.\1.self_attn.\2_proj.weight", - r"^layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight", - r"^layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight", - r"^layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight", + r"^output.weight": r"lm_head.weight", + r"^norm.weight": r"model.language_model.norm.weight", + r"^tok_embeddings.weight": r"model.language_model.embed_tokens.weight", + r"^layers.(\d+).attention_norm.weight": r"model.language_model.layers.\1.input_layernorm.weight", + r"^layers.(\d+).ffn_norm.weight": r"model.language_model.layers.\1.post_attention_layernorm.weight", + r"^layers.(\d+).attention.w(q|k|v|o).weight": r"model.language_model.layers.\1.self_attn.\2_proj.weight", + r"^layers.(\d+).feed_forward.w1.weight": r"model.language_model.layers.\1.mlp.gate_proj.weight", + r"^layers.(\d+).feed_forward.w2.weight": r"model.language_model.layers.\1.mlp.down_proj.weight", + r"^layers.(\d+).feed_forward.w3.weight": r"model.language_model.layers.\1.mlp.up_proj.weight", # Vision model keys - r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight", - r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight", - r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"vision_tower.transformer.layers.\1.attention.\2_proj.weight", - r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight", - r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight", - r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight", - r"^vision_language_adapter.w_in": r"multi_modal_projector.linear_1", - r"^vision_language_adapter.w_out": r"multi_modal_projector.linear_2", - r"^vision_encoder.ln_pre.weight": r"vision_tower.ln_pre.weight", - r"^vision_encoder.patch_conv.weight": r"vision_tower.patch_conv.weight", - r"^patch_merger.merging_layer.weight": r"multi_modal_projector.patch_merger.merging_layer.weight", - r"^pre_mm_projector_norm.weight": r"multi_modal_projector.norm.weight", + r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"model.vision_tower.transformer.layers.\1.attention_norm.weight", + r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"model.vision_tower.transformer.layers.\1.ffn_norm.weight", + r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"model.vision_tower.transformer.layers.\1.attention.\2_proj.weight", + r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight", + r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.down_proj.weight", + r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": 
r"model.vision_tower.transformer.layers.\1.feed_forward.up_proj.weight", + r"^vision_language_adapter.w_in": r"model.multi_modal_projector.linear_1", + r"^vision_language_adapter.w_out": r"model.multi_modal_projector.linear_2", + r"^vision_encoder.ln_pre.weight": r"model.vision_tower.ln_pre.weight", + r"^vision_encoder.patch_conv.weight": r"model.vision_tower.patch_conv.weight", + r"^patch_merger.merging_layer.weight": r"model.multi_modal_projector.patch_merger.merging_layer.weight", + r"^pre_mm_projector_norm.weight": r"model.multi_modal_projector.norm.weight", } # fmt: on @@ -131,20 +132,47 @@ def convert_config(original_config: dict, max_position_embeddings: int = 131072) similar_text_keys_to_keep = [ "head_dim", "vocab_size", - "rope_theta", ] + new_text_config_kwargs = {k: original_text_config[v] for k, v in text_key_mapping.items()} new_text_config_kwargs.update({k: v for k, v in original_text_config.items() if k in similar_text_keys_to_keep}) + tie_word_embeddings = original_text_config.get("tied_embeddings", False) + new_text_config_kwargs["tie_word_embeddings"] = tie_word_embeddings + if original_config.get("yarn"): + new_text_config_kwargs["rope_parameters"] = { + "type": "yarn", + "rope_theta": original_config.get("rope_theta", 10000.0), + "factor": float(original_config["yarn"]["factor"]), + "original_max_position_embeddings": original_config["yarn"]["original_max_position_embeddings"], + "beta_fast": float(original_config["yarn"]["beta"]), + "beta_slow": float(original_config["yarn"]["alpha"]), + "mscale_all_dim": 1, + } + + if original_config.get("llama_4_scaling"): + new_text_config_kwargs["mscale"] = 1 + else: + new_text_config_kwargs["rope_theta"] = original_config.get("rope_theta", 10000.0) + # These are not always defined depending on `params.json` new_text_config_kwargs["sliding_window"] = original_text_config.get("sliding_window", None) new_text_config_kwargs["max_position_embeddings"] = original_text_config.get( - "max_seq_len", max_position_embeddings + "max_position_embeddings", original_text_config.get("max_seq_len", max_position_embeddings) ) # This may sometimes be a string in `params.json` if new_text_config_kwargs["sliding_window"] is not None: new_text_config_kwargs["sliding_window"] = int(new_text_config_kwargs["sliding_window"]) + + if original_config.get("llama_4_scaling"): + new_text_config_kwargs["llama_4_scaling"] = original_config["llama_4_scaling"] + new_text_config = MistralConfig(**new_text_config_kwargs) + # Vision config + if original_vision_config is None: + new_config = Mistral3Config(text_config=new_text_config, vision_feature_layer=-1) + return new_config + # Vision config new_vision_config = original_vision_config adapter_bias = new_vision_config.pop("adapter_bias", False) @@ -155,7 +183,7 @@ def convert_config(original_config: dict, max_position_embeddings: int = 131072) _ = new_vision_config.pop("image_break_token_id", 12) _ = new_vision_config.pop("image_end_token_id", 13) _ = new_vision_config.pop("max_image_size") - new_vision_config = PixtralVisionConfig(**new_vision_config) + new_vision_config = PixtralVisionConfig(hidden_act="silu", **new_vision_config) new_config = Mistral3Config( vision_config=new_vision_config, @@ -181,19 +209,25 @@ def convert_and_write_model(input_dir: str, output_dir: str, max_position_embedd new_dict = convert_state_dict(original_state_dict, config) full_state_dict.update(new_dict) + if config.text_config.tie_word_embeddings: + full_state_dict["lm_head.weight"] = 
full_state_dict["model.language_model.embed_tokens.weight"] + # Load weights into model and resave them with torch.device("meta"): model = Mistral3ForConditionalGeneration(config) model.load_state_dict(full_state_dict, strict=True, assign=True) model.save_pretrained(output_dir) + return config -def convert_and_write_processor(input_dir: str, output_dir: str): +def convert_and_write_processor(input_dir: str, output_dir: str, model_config: Mistral3ForConditionalGeneration): """Convert the tokenizer and save it.""" + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + tokenizer_file = os.path.join(input_dir, "tekken.json") tokenizer = convert_tekken_tokenizer(tokenizer_file) tokenizer.add_special_tokens({"pad_token": ""}) - chat_template = '{%- if messages[0]["role"] == "system" %}{%- set system_message = messages[0]["content"] %}{%- set loop_messages = messages[1:] %}\n{%- else %}{%- set loop_messages = messages %}{%- endif %}{{- bos_token }}{%- for message in loop_messages %}{%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}{{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}{%- endif %}{%- if message["role"] == "user" %}{%- if loop.last and system_message is defined %}{{- "[INST]" + system_message + "\n\n" }}{%- else %}{{ "[INST]" }}{%- endif %}{%- endif %}{%- if message["content"] is not string %}{%- for chunk in message["content"] %}{%- if chunk["type"] == "text" %}{%- if "content" in chunk %}{{- chunk["content"] }}{%- elif "text" in chunk %}{{- chunk["text"] }}{%- endif %}{%- elif chunk["type"] == "image" %}{{- "[IMG]" }}{%- else %}{{- raise_exception("Unrecognized content type!") }}{%- endif %}{%- endfor %}{%- else %}{{- message["content"] }}{%- endif %}{%- if message["role"] == "user" %}{{- "[/INST]" }}{%- elif message["role"] == "assistant" %}{{- eos_token}}{%- else %}{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}{%- endif %}{%- endfor %}' + tekkenizer = Tekkenizer.from_file(tokenizer_file) config = read_json(os.path.join(input_dir, "params.json")) patch_size = config["vision_encoder"]["patch_size"] @@ -206,13 +240,21 @@ def convert_and_write_processor(input_dir: str, output_dir: str): image_processor=image_processor, image_token="[IMG]", patch_size=patch_size, - chat_template=chat_template, spatial_merge_size=spatial_merge_size, ) # Finally save it processor.save_pretrained(output_dir) + generation_config = GenerationConfig( + eos_token_id=tekkenizer.eos_id, + bos_token_id=tekkenizer.bos_id, + pad_token_id=tekkenizer.pad_id, + max_length=model_config.text_config.max_position_embeddings, + ) + + generation_config.save_pretrained(output_dir) + def main(): parser = argparse.ArgumentParser() @@ -233,8 +275,8 @@ def main(): args = parser.parse_args() - convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings) - convert_and_write_processor(args.input_dir, args.output_dir) + config = convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings) + convert_and_write_processor(args.input_dir, args.output_dir, config) if __name__ == "__main__": From e4c780d156db3cd0ad795da5f1ae8b81a3be17c5 Mon Sep 17 00:00:00 2001 From: Julien Denize Date: Fri, 7 Nov 2025 08:22:43 +0100 Subject: [PATCH 2/4] Refactor llama_4_scaling inside rope_parameters --- .../models/mistral/configuration_mistral.py | 4 +--- .../mistral/convert_mistral_weights_to_hf.py | 3 ++- 
src/transformers/models/mistral/modular_mistral.py | 14 +++----------- .../mistral3/convert_mistral3_weights_to_hf.py | 3 ++- 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/mistral/configuration_mistral.py b/src/transformers/models/mistral/configuration_mistral.py index 5e15619996e6..b51a5e5004ed 100644 --- a/src/transformers/models/mistral/configuration_mistral.py +++ b/src/transformers/models/mistral/configuration_mistral.py @@ -154,7 +154,6 @@ def __init__( eos_token_id: Optional[int] = 2, tie_word_embeddings: Optional[bool] = False, rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, - llama_4_scaling: Optional[LLama4Scaling] = None, sliding_window: Optional[int] = 4096, attention_dropout: Optional[float] = 0.0, **kwargs, @@ -178,7 +177,6 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.attention_dropout = attention_dropout - self.llama_4_scaling = llama_4_scaling if "layer_types" in kwargs: logger.warning_once( @@ -192,7 +190,7 @@ def __init__( # Validate the correctness of rotary position embeddings parameters rope_theta = kwargs.get("rope_theta", 10000.0) standardize_rope_params(self, rope_theta=rope_theta) - rope_config_validation(self) + rope_config_validation(self, ignore_keys={"llama_4_scaling_beta"}) super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/mistral/convert_mistral_weights_to_hf.py b/src/transformers/models/mistral/convert_mistral_weights_to_hf.py index 6cc2cfdac15a..538409415914 100644 --- a/src/transformers/models/mistral/convert_mistral_weights_to_hf.py +++ b/src/transformers/models/mistral/convert_mistral_weights_to_hf.py @@ -206,7 +206,8 @@ def convert_config(original_config: dict, max_position_embeddings: int = 32768): new_config_kwargs["sliding_window"] = int(new_config_kwargs["sliding_window"]) if original_config.get("llama_4_scaling"): - new_config_kwargs["llama_4_scaling"] = original_config["llama_4_scaling"] + assert original_config.get("yarn") is not None, "llama_4_scaling is only supported with yarn" + new_config_kwargs["rope_parameters"]["llama_4_scaling_beta"] = original_config["llama_4_scaling"]["beta"] new_config = MistralConfig(**new_config_kwargs) return new_config diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 889a8c638b81..5697839c2b6d 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -46,14 +46,6 @@ class MistralAttention(LlamaAttention): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__(config, layer_idx) self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - if self.config.llama_4_scaling is not None: - self.llama_4_scaling_beta = self.config.llama_4_scaling.scaling_beta - self.llama_4_scaling_original_max_position_embeddings = ( - self.config.llama_4_scaling.original_max_position_embeddings - ) - else: - self.llama_4_scaling_beta = None - self.llama_4_scaling_original_max_position_embeddings = None self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) @@ -61,8 +53,8 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def 
_get_llama4_attn_scale(self, positions_ids: torch.Tensor) -> torch.Tensor: - scaling = 1 + self.llama_4_scaling_beta * torch.log( - 1 + torch.floor(positions_ids / self.llama_4_scaling_original_max_position_embeddings) + scaling = 1 + self.config.rope_parameters.llama_4_scaling_beta * torch.log( + 1 + torch.floor(positions_ids / self.config.rope_parameters.original_max_position_embeddings) ) return scaling.unsqueeze(-1) @@ -85,7 +77,7 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if self.llama_4_scaling_beta is not None: + if self.config.rope_parameters.llama_4_scaling_beta is not None: query_states = query_states * self._get_llama4_attn_scale(cache_position).to(query_states.dtype) if past_key_values is not None: diff --git a/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py b/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py index 1013a92b2a60..f438bf2c299a 100644 --- a/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py +++ b/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py @@ -164,7 +164,8 @@ def convert_config(original_config: dict, max_position_embeddings: int = 131072) new_text_config_kwargs["sliding_window"] = int(new_text_config_kwargs["sliding_window"]) if original_config.get("llama_4_scaling"): - new_text_config_kwargs["llama_4_scaling"] = original_config["llama_4_scaling"] + assert original_config.get("yarn") is not None, "llama_4_scaling is only supported with yarn" + new_text_config_kwargs["rope_parameters"]["llama_4_scaling_beta"] = original_config["llama_4_scaling"]["beta"] new_text_config = MistralConfig(**new_text_config_kwargs) From 573f0a160bd7f6e332858b1b3eb3baf5d57dd60a Mon Sep 17 00:00:00 2001 From: Julien Denize Date: Fri, 7 Nov 2025 14:32:27 +0000 Subject: [PATCH 3/4] Remove dead code --- .../models/mistral/configuration_mistral.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/src/transformers/models/mistral/configuration_mistral.py b/src/transformers/models/mistral/configuration_mistral.py index b51a5e5004ed..b873e77a1132 100644 --- a/src/transformers/models/mistral/configuration_mistral.py +++ b/src/transformers/models/mistral/configuration_mistral.py @@ -14,7 +14,7 @@ # limitations under the License. """Mistral model configuration""" -from typing import Optional, TypedDict +from typing import Optional from ...configuration_utils import PreTrainedConfig from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params @@ -24,21 +24,6 @@ logger = logging.get_logger(__name__) -class LLama4Scaling(TypedDict): - r""" - The scaling parameters to apply LLama 4 scaling to the rope embeddings. - - Args: - original_max_position_embeddings (`int`): - The original max position embeddings used during pretraining. - scaling_beta (`float`): - The scaling beta parameter. - """ - - original_max_position_embeddings: int - scaling_beta: float - - class MistralConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an @@ -97,8 +82,6 @@ class MistralConfig(PreTrainedConfig): Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE with longer `max_position_embeddings`. 
- llama_4_scaling (`LLama4Scaling`, *optional*): - Dictionary containing the scaling parameters for LLama 4 scaling. sliding_window (`int`, *optional*, defaults to 4096): Sliding window attention window size. If not specified, will default to `4096`. attention_dropout (`float`, *optional*, defaults to 0.0): From 3a6de8b7ebb248950b6329acea1302cea5d08ea5 Mon Sep 17 00:00:00 2001 From: Julien Denize Date: Fri, 7 Nov 2025 14:58:54 +0000 Subject: [PATCH 4/4] Replace dot access by get --- src/transformers/models/mistral/modular_mistral.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 5697839c2b6d..be225f84b6ce 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -53,7 +53,7 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def _get_llama4_attn_scale(self, positions_ids: torch.Tensor) -> torch.Tensor: - scaling = 1 + self.config.rope_parameters.llama_4_scaling_beta * torch.log( - 1 + torch.floor(positions_ids / self.config.rope_parameters.original_max_position_embeddings) + scaling = 1 + self.config.rope_parameters.get("llama_4_scaling_beta") * torch.log( + 1 + torch.floor(positions_ids / self.config.rope_parameters.get("original_max_position_embeddings")) ) return scaling.unsqueeze(-1) @@ -77,7 +77,7 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if self.config.rope_parameters.llama_4_scaling_beta is not None: + if self.config.rope_parameters.get("llama_4_scaling_beta") is not None: query_states = query_states * self._get_llama4_attn_scale(cache_position).to(query_states.dtype) if past_key_values is not None:
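To make the conversion path in this series concrete: below is a hypothetical `params.json` fragment and the `rope_parameters` dict that `convert_config` would build from it after PATCH 2. All values are placeholders chosen for illustration; only the key names come from the code above.

# Hypothetical `params.json` fragment (placeholder values):
original_config = {
    "rope_theta": 1000000.0,
    "yarn": {
        "factor": 16.0,
        "original_max_position_embeddings": 16384,
        "beta": 32.0,
        "alpha": 1.0,
    },
    "llama_4_scaling": {"beta": 0.1},
}

# `rope_parameters` produced by `convert_config` for this input:
rope_parameters = {
    "type": "yarn",
    "rope_theta": 1000000.0,
    "factor": 16.0,
    "original_max_position_embeddings": 16384,
    "beta_fast": 32.0,  # yarn "beta"
    "beta_slow": 1.0,   # yarn "alpha"
    "mscale_all_dim": 1,
    "llama_4_scaling_beta": 0.1,  # folded in by PATCH 2 instead of a separate llama_4_scaling config entry
}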
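The attention scaling itself is small enough to sanity-check in isolation. Here is a minimal standalone sketch of the formula implemented by `_get_llama4_attn_scale` (`beta` and `original_max` are placeholder values, not taken from any released checkpoint); in the model, the result is unsqueezed to broadcast over the head dimension and multiplied into `query_states`:

import torch

def llama4_attn_scale(position_ids: torch.Tensor, beta: float, original_max: int) -> torch.Tensor:
    # 1 + beta * log(1 + floor(p / original_max)): identity inside the
    # pretraining window, then a slow logarithmic gain at longer positions.
    return 1 + beta * torch.log(1 + torch.floor(position_ids / original_max))

print(llama4_attn_scale(torch.tensor([0, 16383, 16384, 65536]), beta=0.1, original_max=16384))
# tensor([1.0000, 1.0000, 1.0693, 1.1609])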