@@ -245,8 +245,12 @@ def setup_model(args, model_dtype, model_kwargs, logger):
245245 args .model_name_or_path , torch_dtype = model_dtype , quantization_config = quantization_config , ** model_kwargs
246246 )
247247 elif args .load_quantized_model_with_inc :
248- from neural_compressor .torch .quantization import load
248+ #TODO: This will be removed in v1.19 Synapse release
249+ #Override neural_compressor _load_remaining_pretrained_weight for the Transformer 4.45 release.
250+ import neural_compressor .torch .algorithms .weight_only .save_load as nc_sl
251+ nc_sl .WOQModelLoader ._load_remaining_pretrained_weight = local_load_remaining_pretrained_weight
249252
253+ from neural_compressor .torch .quantization import load
250254 model = load (model_name_or_path = args .model_name_or_path , format = "huggingface" , device = "hpu" , ** model_kwargs )
251255 elif args .local_quantized_inc_model_path :
252256 org_model = AutoModelForCausalLM .from_pretrained (
@@ -662,3 +666,46 @@ def initialize_model(args, logger):
662666 logger .info (f"device: { args .device } , n_hpu: { args .world_size } , bf16: { model_dtype == torch .bfloat16 } " )
663667 logger .info (f"Model initialization took { (init_end - init_start ):.3f} s" )
664668 return model , assistant_model , tokenizer , generation_config
669+
# TODO: This will be removed from the Synapse v1.19 release.
# Override of _load_remaining_pretrained_weight for the Transformers 4.45 release.
def local_load_remaining_pretrained_weight(self, model):
    """Load the remaining pretrained weight shard(s) onto *model*.

    Replacement for neural_compressor's
    ``WOQModelLoader._load_remaining_pretrained_weight``: pops the loading
    context stashed in ``self.kwargs``, streams each checkpoint shard into the
    (meta-device) model, re-ties embeddings, and returns the model in eval mode.

    NOTE(review): assumes ``self`` carries ``kwargs`` (dict populated by the
    loader) and ``device`` attributes — confirm against the loader's call site.
    """
    from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict

    archive_files = self.kwargs.pop("resolved_archive_file", None)
    target_dtype = self.kwargs.pop("torch_dtype", torch.float32)
    original_dtype = self.kwargs.pop("dtype_orig", None)
    offload_dir = self.kwargs.pop("offload_folder", None)
    offload_to_disk = self.kwargs.pop("offload_state_dict", False)

    # Restore the default dtype if the loader changed it earlier.
    if original_dtype is not None:
        torch.set_default_dtype(original_dtype)

    # A single-shard checkpoint arrives as a bare path; normalize to a list.
    if not isinstance(archive_files, list):
        archive_files = [archive_files]

    for shard in archive_files:
        shard_state = load_state_dict(shard)
        # Materialize this shard's tensors onto self.device; every key in the
        # shard is expected, and nothing is kept in fp32.
        _load_state_dict_into_meta_model(
            model=model,
            state_dict=shard_state,
            start_prefix="",
            expected_keys=list(shard_state.keys()),
            device_map={"": self.device},
            offload_folder=offload_dir,
            state_dict_folder=tempfile.mkdtemp() if offload_to_disk else None,
            state_dict_index={} if offload_to_disk else None,
            dtype=target_dtype,
            keep_in_fp32_modules=[],
        )

    # Make sure token embedding weights are still tied if needed.
    model.tie_weights()

    # Deactivate dropout modules by default.
    model.eval()

    return model
0 commit comments