
Commit 683c627

suiyoubi authored and guyueh1 committed

Llama4 Export: Remove outdated MLP weight transform (NVIDIA-NeMo#14297)

* Update HFLlamaExporter to remove outdated MLP weight transform for Llama4 model

  Signed-off-by: Ao Tang <[email protected]>

* Apply isort and black reformatting

  Signed-off-by: suiyoubi <[email protected]>

---------

Signed-off-by: Ao Tang <[email protected]>
Signed-off-by: suiyoubi <[email protected]>
Co-authored-by: suiyoubi <[email protected]>
Signed-off-by: Guyue Huang <[email protected]>

1 parent 8755691 commit 683c627

File tree

1 file changed: +7, -0 lines changed

nemo/collections/llm/gpt/model/llama.py

Lines changed: 7 additions & 0 deletions
@@ -903,6 +903,13 @@ def convert_state(self, source, target, source_config=None):
                 "decoder.layers.*.mlp.experts.linear_fc1.weight": "model.layers.*.feed_forward.experts.gate_up_proj",
             }
         )
+
+        # Remove the transform with source_key "decoder.layers.*.mlp.linear_fc1.weight" from transforms:
+        # Llama4's HF model has a different mapping for the MLP weights (feed_forward instead of mlp)
+        transforms = [
+            t for t in transforms if getattr(t, "source_key", None) != "decoder.layers.*.mlp.linear_fc1.weight"
+        ]
+
         transforms.extend(
             [
                 io.state_transform(
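The change above boils down to a list-comprehension filter that drops any transform whose `source_key` matches the outdated MLP mapping, while keeping everything else (including the experts mapping added just before it). A minimal runnable sketch of that pattern, using a hypothetical `Transform` dataclass as a stand-in for NeMo's `io.state_transform` objects:

```python
from dataclasses import dataclass


# Hypothetical stand-in for NeMo's io.state_transform objects, which
# carry a source_key (Megatron-side name) and target_key (HF-side name).
@dataclass
class Transform:
    source_key: str
    target_key: str


OUTDATED_KEY = "decoder.layers.*.mlp.linear_fc1.weight"

transforms = [
    # Outdated dense-MLP mapping that no longer applies to Llama4's HF layout
    Transform(OUTDATED_KEY, "model.layers.*.mlp.gate_up_proj"),
    # Llama4 expert mapping from the commit, which must be kept
    Transform(
        "decoder.layers.*.mlp.experts.linear_fc1.weight",
        "model.layers.*.feed_forward.experts.gate_up_proj",
    ),
]

# getattr with a None default tolerates entries that have no source_key
# attribute at all, exactly as in the committed filter.
transforms = [
    t for t in transforms if getattr(t, "source_key", None) != OUTDATED_KEY
]

print([t.source_key for t in transforms])
```

The `getattr(t, "source_key", None)` guard matters because the transform list can mix plain mapping entries with callable transform objects, and only some of them expose a `source_key`.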
