     r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.feed_forward.up_proj.weight",  # might need to be fused for efficiency?
     # r"layers.(\d+).feed_forward.mlp.fc1_weight": r"language_model.model.layers.\1.feed_forward.gate_up_proj.weight",
     r"layers.(\d+).feed_forward.mlp.fc2_weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
+    r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
     r"layers.(\d+).feed_forward.mlp.layer_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",

     # Vision encoder mapping
@@ -166,8 +167,8 @@ def get_concat_dim(key):
     return 0


-def compute_intermediate_size(hidden_dim, multiple_of=1024, ffn_dim_multiplier=1.3):
-    hidden_dim = 4 * int(2 * hidden_dim / 3)
+def compute_intermediate_size(hidden_dim, ffn_exp=4, multiple_of=1024, ffn_dim_multiplier=1.2):
+    hidden_dim = ffn_exp * int(2 * hidden_dim / 3)
     hidden_dim = int(ffn_dim_multiplier * hidden_dim)
     hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
     return hidden_dim
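
For intuition, here is the arithmetic the new signature performs, worked through for a hypothetical hidden_dim of 5120 (the input value is an assumption, not taken from this hunk):

# Worked example of compute_intermediate_size; hidden_dim=5120 is an assumed input.
hidden_dim = 5120
ffn_exp, multiple_of, ffn_dim_multiplier = 4, 1024, 1.2

step1 = ffn_exp * int(2 * hidden_dim / 3)                         # 4 * 3413 = 13652
step2 = int(ffn_dim_multiplier * step1)                           # int(1.2 * 13652) = 16382
step3 = multiple_of * ((step2 + multiple_of - 1) // multiple_of)  # round up to a multiple of 1024
print(step3)  # 16384, matching the literal intermediate_size_mlp that is replaced further down
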
@@ -203,6 +204,8 @@ def max_context_length(model_path, instruct=False):
     with open(os.path.join(model_path, "params.json"), "r") as f:
         params = json.load(f)
     params = params.get("model", params)
+    if params.get("moe_args") is None:
+        return 8192
     num_experts = params["moe_args"]["num_experts"]
     return 10485760 if num_experts == 16 else 1048576

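
A minimal sketch of the three branches the guarded lookup now covers, using hypothetical params dicts (the dict contents are made up; only the key names follow the code above):

# Hypothetical params contents, only to illustrate the branches of max_context_length.
dense_params  = {"dim": 5120}                          # no "moe_args"      -> 8192
moe_16_params = {"moe_args": {"num_experts": 16}}      # num_experts == 16  -> 10485760
moe_other     = {"moe_args": {"num_experts": 128}}     # any other count    -> 1048576
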
@@ -242,24 +245,40 @@ def write_model(
     # some constants from original code
     rope_scaling = {
         "rope_type": "llama3",
-        "factor": 8.0,
+        "factor": params.get("rope_scaling_factor", 8.0),
         "low_freq_factor": 1.0,
-        "high_freq_factor": 4.0,
+        "high_freq_factor": params.get("rope_high_freq_factor", 4.0),
         "original_max_position_embeddings": 8192,
     }
     config_kwargs.update({"rope_scaling": rope_scaling})

+    if attention_chunk_size is None:
+        config_kwargs.update({"cache_implementation": "static"})
+
     # compute additional params for weight conversion
     num_heads_per_shard = num_heads // num_shards
     dim_per_head = dim // num_heads
-    # intermediate_size = compute_intermediate_size(dim, multiple_of=params["multiple_of"])
+    intermediate_size_mlp = compute_intermediate_size(
+        dim,
+        ffn_exp=params["ffn_exp"],
+        multiple_of=params["multiple_of"],
+        ffn_dim_multiplier=params["ffn_dim_multiplier"],
+    )

     num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA

-    num_experts = params["moe_args"]["num_experts"]
-    interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)
+    if params.get("moe_args") is not None:
+        num_experts = params["moe_args"]["num_experts"]
+        interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)
+    else:
+        # Dense model (possibly Llama Guard) - disable all MoE layers
+        num_experts = 0
+        interleave_moe_layer_step = 0
+        config_kwargs.update({"moe_layers": []})

+    # Ensure all layers use RoPE when `nope_layer_interval` is None
     no_rope_layer_interval = params["nope_layer_interval"]
+    no_rope_layer_interval = num_heads * 2 if no_rope_layer_interval is None else no_rope_layer_interval

     bos_token_id = 200000
     eos_token_id = [200001, 200007, 200008] if instruct else 200001
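
Pulling the reads together, a hypothetical dense-checkpoint params.json only needs the keys below for this part of the conversion; the optional keys fall back to the defaults shown. All values are illustrative, not taken from a real checkpoint:

# Illustrative only: the params.json keys this hunk reads, with made-up values.
params = {
    "n_kv_heads": 8,
    "ffn_exp": 4,
    "multiple_of": 1024,
    "ffn_dim_multiplier": 1.2,
    "nope_layer_interval": None,   # None -> the all-RoPE fallback above kicks in
    # no "moe_args": num_experts = 0, interleave_moe_layer_step = 0, moe_layers = []
    # no "rope_scaling_factor" / "rope_high_freq_factor": defaults 8.0 / 4.0 are used
    # (dim, head counts, layer counts, etc. come from other keys not shown in this hunk)
}
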
@@ -273,7 +292,7 @@ def write_model(
         rope_theta=rope_theta,
         num_hidden_layers=num_layers,
         intermediate_size=8192,
-        intermediate_size_mlp=16384,
+        intermediate_size_mlp=intermediate_size_mlp,
         max_position_embeddings=max_context_length(input_base_path, instruct),
         num_local_experts=num_experts,
         interleave_moe_layer_step=interleave_moe_layer_step,
@@ -336,7 +355,7 @@ def write_model(
     sharded_keys = []
     for _key in all_keys_raw:
         try:
-            if (loaded[0][_key] == loaded[1][_key]).all():
+            if num_shards == 1 or (loaded[0][_key] == loaded[1][_key]).all():
                 repeated_keys.append(_key)
             else:
                 sharded_keys.append(_key)
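
The single-shard guard relies on Python's short-circuiting `or`: when num_shards == 1, the equality check (and the loaded[1] index that would otherwise raise) is never evaluated. A minimal sketch of that behavior with a stand-in loaded list:

# Stand-in single-shard run: loaded[1] does not exist, but is never touched.
num_shards = 1
loaded = [{"tok_embeddings.weight": 0}]
_key = "tok_embeddings.weight"

same = num_shards == 1 or (loaded[0][_key] == loaded[1][_key]).all()
print(same)  # True; the right-hand side was skipped entirely
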
@@ -354,7 +373,7 @@ def write_model(
     for key in tqdm(all_keys, desc="Renaming and processing all keys", unit="key"):
         new_key = new_keys[key]
         print(key, new_key)
-        if not is_param_same_across_shards(new_key):
+        if num_shards > 1 and not is_param_same_across_shards(new_key):
             current_parameter = [chunk.pop(key) for chunk in loaded if not isinstance(chunk[key], io.BytesIO)]
         else:
             print(f"{key} (now {new_key}) is the same across all shards.")
@@ -565,8 +584,8 @@ def get_reserved_special_tokens(name, count, start_index=0):
     "<|python_end|>",
     "<|finetune_right_pad|>",
 ] + get_reserved_special_tokens(
-    "text_post_train", 61, 6
-)  # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>
+    "text_post_train", 61, 8
+)  # <|text_post_train_reserved_special_token_8|>, ..., <|text_post_train_reserved_special_token_68|>

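
Given the signature in the hunk header, get_reserved_special_tokens presumably expands to the names quoted in the comment; a sketch of that assumption:

# Assumed behavior of get_reserved_special_tokens, inferred from the comment above.
def get_reserved_special_tokens(name, count, start_index=0):
    return [f"<|{name}_reserved_special_token_{i}|>" for i in range(start_index, start_index + count)]

tokens = get_reserved_special_tokens("text_post_train", 61, 8)
print(tokens[0], tokens[-1])  # <|text_post_train_reserved_special_token_8|> ... _68|>
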
 # 200080, ..., 201133
 LLAMA4_VISION_SPECIAL_TOKENS = [
@@ -621,15 +640,6 @@ def __init__(
             **kwargs,
         )

-        # to check
-        # import tiktoken
-        # model = tiktoken.Encoding(
-        #     name=Path(model_path).name,
-        #     pat_str=self.O200K_PATTERN,
-        #     mergeable_ranks=mergeable_ranks,
-        #     special_tokens=self.special_tokens,
-        # )
-
         instruct = chat_template is not None
         self.update_post_processor(self.converted_tokenizer)
         # finer special_tokens_map.json
@@ -687,12 +697,10 @@ def write_tokenizer(args):
     parser.add_argument(
         "--input_dir",
         type=str,
-        default="/fsx/arthur/Llama-4-17B-Omni-Instruct-Original",
         help="Location of the local folder copied from the Hub.",
     )
     parser.add_argument(
         "--output_dir",
-        default="llama4_hf_vision",
         type=str,
         help="Location to write HF model and tokenizer",
     )