src/transformers/models/openai_moe: 1 file changed, +5 -6 lines

@@ -252,17 +252,16 @@ def write_model(
             "lm_head",
         ],
     }
+    # required as we don't save the model with save_pretrained
+    config.architectures = ["OpenAIMoeForCausalLM"]
     config.save_pretrained(model_path)
     save_sharded_model(state_dict, model_path)
     del state_dict

-    # Safety check: reload the converted model
     gc.collect()
-    # TODO: remove when mxfp4 pr is merged
-    if not mxfp4:
-        print("Reloading the model to check if it's saved correctly.")
-        OpenAIMoeForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
-        print("Model reloaded successfully.")
+    print("Reloading the model to check if it's saved correctly.")
+    OpenAIMoeForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
+    print("Model reloaded successfully.")

     # generation config
     if instruct:
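Context for the first hunk: `PreTrainedModel.save_pretrained` normally records the model class in `config.json` under `architectures`, but this script only calls `config.save_pretrained` and writes the weight shards with its own `save_sharded_model` helper, so the field has to be set by hand. A minimal sketch of the effect, using `PretrainedConfig` as a stand-in for the model's real config class and a made-up output directory:

```python
# Minimal sketch (not the conversion script itself) of why `architectures`
# is set manually: model.save_pretrained() would normally record the model
# class in config.json, but here only the config is saved that way.
from transformers import PretrainedConfig

config = PretrainedConfig()  # stand-in for the model's real config class
config.architectures = ["OpenAIMoeForCausalLM"]
config.save_pretrained("/tmp/openai_moe_converted")  # hypothetical output dir

# The saved config.json now contains "architectures": ["OpenAIMoeForCausalLM"],
# which tools that only inspect the config (the Hub model page, pipeline task
# inference, external runtimes) read to know which model class the checkpoint
# is intended for.
```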