Add llama 4 scaling #42045
Changes from all commits
```diff
@@ -20,6 +20,7 @@
 from safetensors.torch import load_file

 from transformers import (
+    GenerationConfig,
     Mistral3Config,
     Mistral3ForConditionalGeneration,
     MistralConfig,
```
```diff
@@ -33,29 +34,29 @@
 # fmt: off
 STATE_DICT_MAPPING = {
     # Text model keys
-    r"^output.weight": r"language_model.lm_head.weight",
-    r"^norm.weight": r"language_model.model.norm.weight",
-    r"^tok_embeddings.weight": r"language_model.model.embed_tokens.weight",
-    r"^layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
-    r"^layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
-    r"^layers.(\d+).attention.w(q|k|v|o).weight": r"language_model.model.layers.\1.self_attn.\2_proj.weight",
-    r"^layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight",
-    r"^layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight",
-    r"^layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight",
+    r"^output.weight": r"lm_head.weight",
+    r"^norm.weight": r"model.language_model.norm.weight",
+    r"^tok_embeddings.weight": r"model.language_model.embed_tokens.weight",
+    r"^layers.(\d+).attention_norm.weight": r"model.language_model.layers.\1.input_layernorm.weight",
+    r"^layers.(\d+).ffn_norm.weight": r"model.language_model.layers.\1.post_attention_layernorm.weight",
+    r"^layers.(\d+).attention.w(q|k|v|o).weight": r"model.language_model.layers.\1.self_attn.\2_proj.weight",
+    r"^layers.(\d+).feed_forward.w1.weight": r"model.language_model.layers.\1.mlp.gate_proj.weight",
+    r"^layers.(\d+).feed_forward.w2.weight": r"model.language_model.layers.\1.mlp.down_proj.weight",
+    r"^layers.(\d+).feed_forward.w3.weight": r"model.language_model.layers.\1.mlp.up_proj.weight",

     # Vision model keys
-    r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight",
-    r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight",
-    r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"vision_tower.transformer.layers.\1.attention.\2_proj.weight",
-    r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
-    r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
-    r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
-    r"^vision_language_adapter.w_in": r"multi_modal_projector.linear_1",
-    r"^vision_language_adapter.w_out": r"multi_modal_projector.linear_2",
-    r"^vision_encoder.ln_pre.weight": r"vision_tower.ln_pre.weight",
-    r"^vision_encoder.patch_conv.weight": r"vision_tower.patch_conv.weight",
-    r"^patch_merger.merging_layer.weight": r"multi_modal_projector.patch_merger.merging_layer.weight",
-    r"^pre_mm_projector_norm.weight": r"multi_modal_projector.norm.weight",
+    r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"model.vision_tower.transformer.layers.\1.attention_norm.weight",
+    r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"model.vision_tower.transformer.layers.\1.ffn_norm.weight",
+    r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"model.vision_tower.transformer.layers.\1.attention.\2_proj.weight",
+    r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
+    r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
+    r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
+    r"^vision_language_adapter.w_in": r"model.multi_modal_projector.linear_1",
+    r"^vision_language_adapter.w_out": r"model.multi_modal_projector.linear_2",
+    r"^vision_encoder.ln_pre.weight": r"model.vision_tower.ln_pre.weight",
+    r"^vision_encoder.patch_conv.weight": r"model.vision_tower.patch_conv.weight",
+    r"^patch_merger.merging_layer.weight": r"model.multi_modal_projector.patch_merger.merging_layer.weight",
+    r"^pre_mm_projector_norm.weight": r"model.multi_modal_projector.norm.weight",
 }
 # fmt: on
```
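For reference, a minimal sketch of how a regex mapping of this shape is typically applied to checkpoint keys. The PR's actual `convert_state_dict` helper is not shown in this hunk; `rename_key` below is illustrative only:

```python
import re

def rename_key(key: str, mapping: dict[str, str]) -> str:
    for pattern, replacement in mapping.items():
        # Backreferences such as \1 and \2 carry layer indices and projection
        # names (q/k/v/o) over from the original key.
        new_key, n_subs = re.subn(pattern, replacement, key)
        if n_subs:
            return new_key
    return key  # keys with no matching pattern pass through unchanged

# Example: "layers.7.attention.wq.weight"
# becomes  "model.language_model.layers.7.self_attn.q_proj.weight"
print(rename_key("layers.7.attention.wq.weight", {
    r"^layers.(\d+).attention.w(q|k|v|o).weight":
        r"model.language_model.layers.\1.self_attn.\2_proj.weight",
}))
```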
```diff
@@ -131,20 +132,48 @@ def convert_config(original_config: dict, max_position_embeddings: int = 131072)
     similar_text_keys_to_keep = [
         "head_dim",
         "vocab_size",
-        "rope_theta",
     ]

     new_text_config_kwargs = {k: original_text_config[v] for k, v in text_key_mapping.items()}
     new_text_config_kwargs.update({k: v for k, v in original_text_config.items() if k in similar_text_keys_to_keep})
+    tie_word_embeddings = original_text_config.get("tied_embeddings", False)
+    new_text_config_kwargs["tie_word_embeddings"] = tie_word_embeddings
+    if original_config.get("yarn"):
+        new_text_config_kwargs["rope_parameters"] = {
+            "type": "yarn",
+            "rope_theta": original_config.get("rope_theta", 10000.0),
+            "factor": float(original_config["yarn"]["factor"]),
+            "original_max_position_embeddings": original_config["yarn"]["original_max_position_embeddings"],
+            "beta_fast": float(original_config["yarn"]["beta"]),
+            "beta_slow": float(original_config["yarn"]["alpha"]),
+            "mscale_all_dim": 1,
+        }
+        if original_config.get("llama_4_scaling"):
+            new_text_config_kwargs["mscale"] = 1
+    else:
+        new_text_config_kwargs["rope_theta"] = original_config.get("rope_theta", 10000.0)

     # These are not always defined depending on `params.json`
     new_text_config_kwargs["sliding_window"] = original_text_config.get("sliding_window", None)
     new_text_config_kwargs["max_position_embeddings"] = original_text_config.get(
-        "max_seq_len", max_position_embeddings
+        "max_position_embeddings", original_text_config.get("max_seq_len", max_position_embeddings)
     )
     # This may sometimes be a string in `params.json`
     if new_text_config_kwargs["sliding_window"] is not None:
         new_text_config_kwargs["sliding_window"] = int(new_text_config_kwargs["sliding_window"])

+    if original_config.get("llama_4_scaling"):
+        assert original_config.get("yarn") is not None, "llama_4_scaling is only supported with yarn"
+        new_text_config_kwargs["rope_parameters"]["llama_4_scaling_beta"] = original_config["llama_4_scaling"]["beta"]
+
     new_text_config = MistralConfig(**new_text_config_kwargs)

     # Vision config
     if original_vision_config is None:
         new_config = Mistral3Config(text_config=new_text_config, vision_feature_layer=-1)
         return new_config

     # Vision config
     new_vision_config = original_vision_config
     adapter_bias = new_vision_config.pop("adapter_bias", False)
```
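To make the new branch concrete, here is a hypothetical `params.json` fragment and the `rope_parameters` dict the code above would derive from it. All values are invented for illustration; the field names follow the diff:

```python
# Hypothetical original params.json content (values invented):
original_config = {
    "rope_theta": 1000000.0,
    "yarn": {
        "factor": 16.0,
        "original_max_position_embeddings": 32768,
        "beta": 32.0,   # mapped to beta_fast
        "alpha": 1.0,   # mapped to beta_slow
    },
    "llama_4_scaling": {"beta": 0.5},
}

# convert_config would then produce:
rope_parameters = {
    "type": "yarn",
    "rope_theta": 1000000.0,
    "factor": 16.0,
    "original_max_position_embeddings": 32768,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
    "mscale_all_dim": 1,
    "llama_4_scaling_beta": 0.5,  # added only when "llama_4_scaling" is present
}
```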
```diff
@@ -155,7 +184,7 @@ def convert_config(original_config: dict, max_position_embeddings: int = 131072)
     _ = new_vision_config.pop("image_break_token_id", 12)
     _ = new_vision_config.pop("image_end_token_id", 13)
     _ = new_vision_config.pop("max_image_size")
-    new_vision_config = PixtralVisionConfig(**new_vision_config)
+    new_vision_config = PixtralVisionConfig(hidden_act="silu", **new_vision_config)

     new_config = Mistral3Config(
         vision_config=new_vision_config,
```
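The `hidden_act="silu"` override matters because the vision feed-forward weights are mapped to `gate_proj`/`up_proj`/`down_proj` in the state-dict mapping above, i.e. a gated (SwiGLU-style) MLP. A minimal sketch of that block, assuming standard bias-free linear layers, shows why a GELU default would silently change the math:

```python
import torch
import torch.nn.functional as F
from torch import nn

class GatedMLP(nn.Module):
    """SwiGLU-style feed-forward matching the w1/w2/w3 -> gate/down/up mapping above."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias=False)  # w1
        self.up_proj = nn.Linear(dim, hidden, bias=False)    # w3
        self.down_proj = nn.Linear(hidden, dim, bias=False)  # w2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # hidden_act is applied to the gate branch only; the checkpoints were
        # trained with silu, so a GELU default produces mismatched activations.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```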
```diff
@@ -181,19 +210,25 @@ def convert_and_write_model(input_dir: str, output_dir: str, max_position_embeddings
         new_dict = convert_state_dict(original_state_dict, config)
         full_state_dict.update(new_dict)

+    if config.text_config.tie_word_embeddings:
+        full_state_dict["lm_head.weight"] = full_state_dict["model.language_model.embed_tokens.weight"]
+
     # Load weights into model and resave them
     with torch.device("meta"):
         model = Mistral3ForConditionalGeneration(config)
     model.load_state_dict(full_state_dict, strict=True, assign=True)
     model.save_pretrained(output_dir)
+    return config


-def convert_and_write_processor(input_dir: str, output_dir: str):
+def convert_and_write_processor(input_dir: str, output_dir: str, model_config: Mistral3ForConditionalGeneration):
     """Convert the tokenizer and save it."""
+    from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+
     tokenizer_file = os.path.join(input_dir, "tekken.json")
     tokenizer = convert_tekken_tokenizer(tokenizer_file)
     tokenizer.add_special_tokens({"pad_token": "<pad>"})
-    chat_template = '{%- if messages[0]["role"] == "system" %}{%- set system_message = messages[0]["content"] %}{%- set loop_messages = messages[1:] %}\n{%- else %}{%- set loop_messages = messages %}{%- endif %}{{- bos_token }}{%- for message in loop_messages %}{%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}{{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}{%- endif %}{%- if message["role"] == "user" %}{%- if loop.last and system_message is defined %}{{- "[INST]" + system_message + "\n\n" }}{%- else %}{{ "[INST]" }}{%- endif %}{%- endif %}{%- if message["content"] is not string %}{%- for chunk in message["content"] %}{%- if chunk["type"] == "text" %}{%- if "content" in chunk %}{{- chunk["content"] }}{%- elif "text" in chunk %}{{- chunk["text"] }}{%- endif %}{%- elif chunk["type"] == "image" %}{{- "[IMG]" }}{%- else %}{{- raise_exception("Unrecognized content type!") }}{%- endif %}{%- endfor %}{%- else %}{{- message["content"] }}{%- endif %}{%- if message["role"] == "user" %}{{- "[/INST]" }}{%- elif message["role"] == "assistant" %}{{- eos_token}}{%- else %}{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}{%- endif %}{%- endfor %}'
+    tekkenizer = Tekkenizer.from_file(tokenizer_file)

     config = read_json(os.path.join(input_dir, "params.json"))
     patch_size = config["vision_encoder"]["patch_size"]
```
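A side note on the load pattern kept above: constructing the model under `torch.device("meta")` allocates no real storage, and `load_state_dict(..., assign=True)` then attaches the checkpoint tensors directly instead of copying them into pre-allocated parameters. A generic, self-contained illustration of the pattern (not the PR's code):

```python
import torch
from torch import nn

# Build the module skeleton on the meta device: shapes only, no memory.
with torch.device("meta"):
    layer = nn.Linear(4, 4)

state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
# assign=True swaps the checkpoint tensors in, materializing the module
# without a second allocation.
layer.load_state_dict(state, assign=True)
print(layer.weight.device)  # cpu
```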
```diff
@@ -206,13 +241,21 @@ def convert_and_write_processor(input_dir: str, output_dir: str):
         image_processor=image_processor,
         image_token="[IMG]",
         patch_size=patch_size,
-        chat_template=chat_template,
         spatial_merge_size=spatial_merge_size,
     )

     # Finally save it
     processor.save_pretrained(output_dir)

+    generation_config = GenerationConfig(
+        eos_token_id=tekkenizer.eos_id,
+        bos_token_id=tekkenizer.bos_id,
+        pad_token_id=tekkenizer.pad_id,
+        max_length=model_config.text_config.max_position_embeddings,
+    )
+
+    generation_config.save_pretrained(output_dir)


 def main():
     parser = argparse.ArgumentParser()
```

Review discussion on the removed `chat_template` line:

Member
model has no chat template anymore? 🥲

Contributor (Author)
It makes no sense to put one here anymore, unfortunately: depending on the model, the chat template looks very different (tokenizer version, thinking or not, ...). So having a default one is arguably worse than none at all, imo.
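Once saved next to the model, these defaults are what `generate()` picks up automatically. A quick check, where `"converted-model"` is a placeholder for the conversion `output_dir`:

```python
from transformers import GenerationConfig

# Load the generation defaults written by the conversion script.
gen_config = GenerationConfig.from_pretrained("converted-model")
print(gen_config.bos_token_id, gen_config.eos_token_id, gen_config.max_length)
```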
```diff
@@ -233,8 +276,8 @@ def main():
     args = parser.parse_args()

-    convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings)
-    convert_and_write_processor(args.input_dir, args.output_dir)
+    config = convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings)
+    convert_and_write_processor(args.input_dir, args.output_dir, config)


 if __name__ == "__main__":
```
Review discussion on the `hidden_act="silu"` change:

Review comment
might be a bit breaking for configs that were saved without `hidden_act` and therefore defaulted to GeLU previously

Author reply
I think it was broken before actually, it is a fix here.