123 changes: 78 additions & 45 deletions unsloth_zoo/saving_utils.py
@@ -49,6 +49,7 @@
"""

import torch
import bitsandbytes as bnb
try:
from huggingface_hub import get_token
except:
@@ -66,6 +67,18 @@
import tempfile
from peft import PeftModelForCausalLM

def find_skipped_quantized_modules(model):
skipped_modules = []
quantized_modules = []
for name, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
if hasattr(module.weight, 'quant_state') and module.weight.quant_state is not None:
quantized_modules.append(name)
else:
skipped_modules.append(name)
elif isinstance(module, torch.nn.Linear):
skipped_modules.append(name)
return skipped_modules, quantized_modules
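
The new find_skipped_quantized_modules helper buckets every linear layer: a bnb.nn.Linear4bit whose weight carries a populated quant_state counts as quantized, while Linear4bit layers missing their quant state and plain torch.nn.Linear layers count as skipped. A minimal usage sketch, assuming `model` is a bitsandbytes 4-bit model already in memory (the variable and printed labels are illustrative):

# Illustrative usage: assumes `model` was loaded with load_in_4bit = True.
skipped, quantized = find_skipped_quantized_modules(model)
print(f"Quantized Linear4bit modules: {len(quantized)}")
print(f"Skipped / full-precision linear modules: {len(skipped)}")
# Layers such as lm_head are commonly left unquantized, so they appear
# in `skipped` rather than `quantized`.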

def create_huggingface_repo(
model,
@@ -320,7 +333,7 @@ def create_lora_statistics(model, merge_into_original = False, return_state_dict


@torch.inference_mode
def _merge_and_overwrite_lora(save_directory, filename, lora_weights, output_dtype,):
def _merge_and_overwrite_lora(save_directory, filename, lora_weights, output_dtype):
# All Unsloth Zoo code licensed under LGPLv3
# Merges LoRA and overwrites the safetensors file it was merged to
filename = os.path.join(save_directory, filename)
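
The remainder of _merge_and_overwrite_lora is collapsed in this diff. For orientation, the merge it performs is the standard LoRA fold, W' = W + (alpha / r) * B @ A, written back over the downloaded safetensors shard in output_dtype. A minimal sketch of that arithmetic (merge_lora_weight is an illustrative stand-in, not the actual function body):

import torch

# Standard LoRA merge arithmetic: fold the low-rank update into the base
# weight in float32, then cast to the requested output dtype.
def merge_lora_weight(W, A, B, alpha, r, output_dtype = torch.float16):
    update = (alpha / r) * (B.to(torch.float32) @ A.to(torch.float32))
    return (W.to(torch.float32) + update).to(output_dtype)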
@@ -525,6 +538,7 @@ def merge_and_overwrite_lora(
push_to_hub = False,
private = False,
token = None,
save_method = "lora",
output_dtype = None,
low_disk_space_usage = False,
use_temp_file = False,
@@ -535,6 +549,8 @@
inner_model = model.base_model.model if isinstance(model, PeftModelForCausalLM) else model
inner_model = inner_model.base_model if hasattr(model, "base_model") else inner_model

base_model = model.base_model if isinstance(model, PeftModelForCausalLM) else model

try:
model_name = get_model_name(model.config._name_or_path, load_in_4bit = False)
except:
@@ -596,66 +612,83 @@ def upload_items(filename = None):

# Save config / generation_config via no state_dict and tokenizer
if tokenizer is not None: tokenizer.save_pretrained(save_directory = save_directory,)
inner_model.save_pretrained(
save_directory = save_directory,
state_dict = {},
)

if save_method == "merged_16bit":
inner_model.save_pretrained(
save_directory = save_directory,
state_dict = {},
)
_remove_quantization_config(config_path = Path(save_directory) / "config.json")
elif save_method == "merged_4bit":
print(f"Unsloth: Saving model 4bit...")
base_model = base_model.merge_and_unload()
skipped_modules, quantized_modules = find_skipped_quantized_modules(base_model)
if len(skipped_modules) > 0:
# Record skipped modules in the config so the saved model can be reloaded
base_model.config.quantization_config["llm_int8_skip_modules"] = skipped_modules

base_model.save_pretrained(
save_directory = save_directory,
)
# Remove the quantization_config in the config.json file if it exists,
# as we are exporting the model in 16-bit format.
_remove_quantization_config(config_path = Path(save_directory) / "config.json")
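
_remove_quantization_config is referenced here and in the merged_16bit branch above, but its body sits outside this diff. A hedged sketch of what such a helper plausibly does (remove_quantization_config_sketch is an assumed stand-in, not the library's implementation):

import json
from pathlib import Path

# Assumed behaviour: drop the quantization_config key from config.json so
# the exported checkpoint is treated as a plain 16-bit model on reload.
def remove_quantization_config_sketch(config_path: Path):
    config = json.loads(config_path.read_text())
    config.pop("quantization_config", None)
    config_path.write_text(json.dumps(config, indent = 2))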

if push_to_hub: upload_items()

safe_tensor_index_files = ["model.safetensors.index.json"] if len(safetensors_list) > 1 else []
if not low_disk_space_usage:
# Download all safetensors in 1 go!
print(f"Downloading safetensors for {model_name}...")
snapshot_download(
repo_id = model_name,
local_dir = save_directory,
allow_patterns = safe_tensor_index_files + safetensors_list,
)
elif safe_tensor_index_files:
print(f"Downloading safetensors index for {model_name}...")
snapshot_download(
repo_id = model_name,
local_dir = save_directory,
allow_patterns = ["model.safetensors.index.json"],
)
for filename in ProgressBar(safetensors_list, desc = "Unsloth: Merging weights into 16bit"):
if low_disk_space_usage:
hf_hub_download(
if save_method == "merged_16bit":
safe_tensor_index_files = ["model.safetensors.index.json"] if len(safetensors_list) > 1 else []
if not low_disk_space_usage:
# Download all safetensors in 1 go!
print(f"Downloading safetensors for {model_name}...")
snapshot_download(
repo_id = model_name,
local_dir = save_directory,
allow_patterns = safe_tensor_index_files + safetensors_list,
)
elif safe_tensor_index_files:
print(f"Downloading safetensors index for {model_name}...")
snapshot_download(
repo_id = model_name,
filename = filename,
repo_type = "model",
local_dir = save_directory,
allow_patterns = ["model.safetensors.index.json"],
)

for filename in ProgressBar(safetensors_list, desc = "Unsloth: Merging weights into 16bit"):
if low_disk_space_usage:
hf_hub_download(
repo_id = model_name,
filename = filename,
repo_type = "model",
local_dir = save_directory,
)
pass
n_saved_modules += _merge_and_overwrite_lora(
save_directory = save_directory,
filename = filename,
lora_weights = lora_weights,
output_dtype = output_dtype,
)
torch.cuda.empty_cache()
if low_disk_space_usage and push_to_hub:
upload_items(filename)
os.remove(os.path.join(save_directory, filename)) # Remove to conserve disk space
pass
pass
n_saved_modules += _merge_and_overwrite_lora(
save_directory = save_directory,
filename = filename,
lora_weights = lora_weights,
output_dtype = output_dtype,
)
torch.cuda.empty_cache()
if low_disk_space_usage and push_to_hub:
upload_items(filename)
os.remove(os.path.join(save_directory, filename)) # Remove to conserve disk space

# Check for errors
if len(lora_weights) != n_saved_modules:
raise RuntimeError(
f"Unsloth: Saving LoRA finetune failed since # of LoRAs = {len(lora_weights)} "\
f"does not match # of saved modules = {n_saved_modules}. Please file a bug report!"
)
pass
pass
if not low_disk_space_usage and push_to_hub: upload_items()

# Check for errors
if len(lora_weights) != n_saved_modules:
raise RuntimeError(
f"Unsloth: Saving LoRA finetune failed since # of LoRAs = {len(lora_weights)} "\
f"does not match # of saved modules = {n_saved_modules}. Please file a bug report!"
)
pass
if temp_file is not None:
try: temp_file.cleanup()
except: pass
pass

return save_directory
pass
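
Putting the new save_method parameter together, here is a hedged sketch of the two merge paths this PR adds; the collapsed head of the signature is assumed to accept model, tokenizer, and save_directory, as the function body suggests, and the directory names are illustrative:

# Merge LoRA adapters into full 16-bit safetensors shards:
merge_and_overwrite_lora(
    model = model,
    tokenizer = tokenizer,
    save_directory = "merged-16bit-out",
    save_method = "merged_16bit",
)

# Merge adapters but keep the bitsandbytes 4-bit base weights; config.json
# then records llm_int8_skip_modules, so skipped layers (e.g. lm_head)
# reload in full precision:
merge_and_overwrite_lora(
    model = model,
    tokenizer = tokenizer,
    save_directory = "merged-4bit-out",
    save_method = "merged_4bit",
)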
