Skip to content

Commit cf1cbd1

Browse files
jiminharegisss
authored and committed
Fix load INC load weights compile error due to Transformer 4.45 upgrade. (#1421)
1 parent 11f020d commit cf1cbd1

1 file changed

Lines changed: 48 additions & 1 deletion

File tree

examples/text-generation/utils.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,12 @@ def setup_model(args, model_dtype, model_kwargs, logger):
245245
args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs
246246
)
247247
elif args.load_quantized_model_with_inc:
248-
from neural_compressor.torch.quantization import load
248+
#TODO: This will be removed in v1.19 Synapse release
249+
#Override neural_compressor _load_remaining_pretrained_weight for the Transformer 4.45 release.
250+
import neural_compressor.torch.algorithms.weight_only.save_load as nc_sl
251+
nc_sl.WOQModelLoader._load_remaining_pretrained_weight = local_load_remaining_pretrained_weight
249252

253+
from neural_compressor.torch.quantization import load
250254
model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs)
251255
elif args.local_quantized_inc_model_path:
252256
org_model = AutoModelForCausalLM.from_pretrained(
@@ -662,3 +666,46 @@ def initialize_model(args, logger):
662666
logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}")
663667
logger.info(f"Model initialization took {(init_end - init_start):.3f}s")
664668
return model, assistant_model, tokenizer, generation_config
669+
670+
# TODO: Remove once Synapse v1.19 ships.
# Override for WOQModelLoader._load_remaining_pretrained_weight, needed after the
# Transformers 4.45 upgrade; monkey-patched onto neural_compressor at load time.
def local_load_remaining_pretrained_weight(self, model):
    """Load the remaining (non-quantized) pretrained weights into `model`.

    Pops the loader options out of ``self.kwargs`` (resolved_archive_file,
    torch_dtype, dtype_orig, offload_folder, offload_state_dict), streams each
    checkpoint shard's state dict into the meta model, re-ties embeddings, and
    returns the model in eval mode.
    """
    from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict

    archive_files = self.kwargs.pop("resolved_archive_file", None)
    dtype = self.kwargs.pop("torch_dtype", torch.float32)
    default_dtype = self.kwargs.pop("dtype_orig", None)
    offload_dir = self.kwargs.pop("offload_folder", None)
    use_offload_state_dict = self.kwargs.pop("offload_state_dict", False)

    # Restore the default dtype that was in effect before loading started.
    if default_dtype is not None:
        torch.set_default_dtype(default_dtype)

    # A single-file checkpoint arrives as a bare path; normalize to a shard list.
    if not isinstance(archive_files, list):
        archive_files = [archive_files]

    for shard in archive_files:
        shard_state = load_state_dict(shard)
        _load_state_dict_into_meta_model(
            model=model,
            state_dict=shard_state,
            start_prefix="",
            expected_keys=list(shard_state.keys()),
            device_map={"": self.device},
            offload_folder=offload_dir,
            state_dict_folder=tempfile.mkdtemp() if use_offload_state_dict else None,
            state_dict_index={} if use_offload_state_dict else None,
            dtype=dtype,
            keep_in_fp32_modules=[],
        )

    # Make sure token embedding weights are still tied if the config requires it.
    model.tie_weights()

    # Eval mode by default so DropOut modules are deactivated.
    model.eval()

    return model

0 commit comments

Comments
 (0)