diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py
index daa3dfa3..e83ae68b 100644
--- a/unsloth_zoo/saving_utils.py
+++ b/unsloth_zoo/saving_utils.py
@@ -49,6 +49,7 @@
 """
 
 import torch
+import bitsandbytes as bnb
 try:
     from huggingface_hub import get_token
 except:
@@ -66,6 +67,18 @@ import tempfile
 from peft import PeftModelForCausalLM
 
 
+def find_skipped_quantized_modules(model):
+    skipped_modules = []
+    quantized_modules = []
+    for name, module in model.named_modules():
+        if isinstance(module, bnb.nn.Linear4bit):
+            if hasattr(module.weight, 'quant_state') and module.weight.quant_state is not None:
+                quantized_modules.append(name)
+            else:
+                skipped_modules.append(name)
+        elif isinstance(module, torch.nn.Linear):
+            skipped_modules.append(name)
+    return skipped_modules, quantized_modules
 
 def create_huggingface_repo(
     model,
@@ -320,7 +333,7 @@ def create_lora_statistics(model, merge_into_original = False, return_state_dict
 
 
 @torch.inference_mode
-def _merge_and_overwrite_lora(save_directory, filename, lora_weights, output_dtype,):
+def _merge_and_overwrite_lora(save_directory, filename, lora_weights, output_dtype):
     # All Unsloth Zoo code licensed under LGPLv3
     # Merges LoRA and overwrites the safetensors file it was merged to
     filename = os.path.join(save_directory, filename)
@@ -525,6 +538,7 @@ def merge_and_overwrite_lora(
     push_to_hub = False,
     private = False,
     token = None,
+    save_method = "lora",
     output_dtype = None,
     low_disk_space_usage = False,
     use_temp_file = False,
@@ -535,6 +549,8 @@ def merge_and_overwrite_lora(
     inner_model = model.base_model.model if isinstance(model, PeftModelForCausalLM) else model
     inner_model = inner_model.base_model if hasattr(model, "base_model") else inner_model
 
+    base_model = model.base_model if isinstance(model, PeftModelForCausalLM) else model
+
     try:
         model_name = get_model_name(model.config._name_or_path, load_in_4bit = False)
     except:
@@ -596,66 +612,83 @@ def upload_items(filename = None):
 
     # Save config / generation_config via no state_dict and tokenizer
     if tokenizer is not None: tokenizer.save_pretrained(save_directory = save_directory,)
-    inner_model.save_pretrained(
-        save_directory = save_directory,
-        state_dict = {},
-    )
+
+    if save_method == "merged_16bit":
+        inner_model.save_pretrained(
+            save_directory = save_directory,
+            state_dict = {},
+        )
+        _remove_quantization_config(config_path = Path(save_directory) / "config.json")
+    elif save_method == "merged_4bit":
+        print(f"Unsloth: Saving model 4bit...")
+        base_model = base_model.merge_and_unload()
+        skipped_modules, quantized_modules = find_skipped_quantized_modules(base_model)
+        if len(skipped_modules) > 0:
+            # Reconstruct skipped modules so that it can be loaded
+            base_model.config.quantization_config["llm_int8_skip_modules"] = skipped_modules
+
+        base_model.save_pretrained(
+            save_directory = save_directory,
+        )
 
     # Remove the quantization_config in the config.json file if it exists,
     # as we are exporting the model in 16-bit format.
-    _remove_quantization_config(config_path = Path(save_directory) / "config.json")
 
     if push_to_hub: upload_items()
 
-    safe_tensor_index_files = ["model.safetensors.index.json"] if len(safetensors_list) > 1 else []
-    if not low_disk_space_usage:
-        # Download all safetensors in 1 go!
-        print(f"Downloading safetensors for {model_name}...")
-        snapshot_download(
-            repo_id = model_name,
-            local_dir = save_directory,
-            allow_patterns = safe_tensor_index_files + safetensors_list,
-        )
-    elif safe_tensor_index_files:
-        print(f"Downloading safetensors index for {model_name}...")
-        snapshot_download(
-            repo_id = model_name,
-            local_dir = save_directory,
-            allow_patterns = ["model.safetensors.index.json"],
-        )
-    for filename in ProgressBar(safetensors_list, desc = "Unsloth: Merging weights into 16bit"):
-        if low_disk_space_usage:
-            hf_hub_download(
+    if save_method == "merged_16bit":
+        safe_tensor_index_files = ["model.safetensors.index.json"] if len(safetensors_list) > 1 else []
+        if not low_disk_space_usage:
+            # Download all safetensors in 1 go!
+            print(f"Downloading safetensors for {model_name}...")
+            snapshot_download(
+                repo_id = model_name,
+                local_dir = save_directory,
+                allow_patterns = safe_tensor_index_files + safetensors_list,
+            )
+        elif safe_tensor_index_files:
+            print(f"Downloading safetensors index for {model_name}...")
+            snapshot_download(
                 repo_id = model_name,
-                filename = filename,
-                repo_type = "model",
                 local_dir = save_directory,
+                allow_patterns = ["model.safetensors.index.json"],
+            )
+
+        for filename in ProgressBar(safetensors_list, desc = "Unsloth: Merging weights into 16bit"):
+            if low_disk_space_usage:
+                hf_hub_download(
+                    repo_id = model_name,
+                    filename = filename,
+                    repo_type = "model",
+                    local_dir = save_directory,
+                )
+            pass
+            n_saved_modules += _merge_and_overwrite_lora(
+                save_directory = save_directory,
+                filename = filename,
+                lora_weights = lora_weights,
+                output_dtype = output_dtype,
             )
+            torch.cuda.empty_cache()
+            if low_disk_space_usage and push_to_hub:
+                upload_items(filename)
+                os.remove(os.path.join(save_directory, filename)) # Remove to conserve disk space
+            pass
         pass
-        n_saved_modules += _merge_and_overwrite_lora(
-            save_directory = save_directory,
-            filename = filename,
-            lora_weights = lora_weights,
-            output_dtype = output_dtype,
-        )
-        torch.cuda.empty_cache()
-        if low_disk_space_usage and push_to_hub:
-            upload_items(filename)
-            os.remove(os.path.join(save_directory, filename)) # Remove to conserve disk space
+
+        # Check for errors
+        if len(lora_weights) != n_saved_modules:
+            raise RuntimeError(
+                f"Unsloth: Saving LoRA finetune failed since # of LoRAs = {len(lora_weights)} "\
+                f"does not match # of saved modules = {n_saved_modules}. Please file a bug report!"
+            )
         pass
-    pass
 
     if not low_disk_space_usage and push_to_hub: upload_items()
 
-    # Check for errors
-    if len(lora_weights) != n_saved_modules:
-        raise RuntimeError(
-            f"Unsloth: Saving LoRA finetune failed since # of LoRAs = {len(lora_weights)} "\
-            f"does not match # of saved modules = {n_saved_modules}. Please file a bug report!"
-        )
-    pass
     if temp_file is not None:
         try: temp_file.cleanup()
         except: pass
     pass
+    return save_directory
 pass