Commit 6c234d5

Merge pull request #2381 from Erland366/fix/saving_vlm_4bit
Fix saving 4bit for VLM
2 parents 73d6fb2 + c1155dc

File tree

2 files changed: +181 −0 lines changed

tests/saving/test_unsloth_save.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
import json
import os
import shutil
import tempfile
import pytest

from unsloth import FastLanguageModel, FastModel

model_to_test = [
    # Text Models
    "unsloth/tinyllama",
    "unsloth/tinyllama-bnb-4bit",
    "unsloth/Qwen2.5-0.5B-Instruct",
    "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit",
    "unsloth/Phi-4-mini-instruct",
    "unsloth/Phi-4-mini-instruct-bnb-4bit",
    "unsloth/Qwen2.5-0.5B",
    # Vision Models
    "unsloth/gemma-3-1b-it",
    "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
    "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit",
]

# Variables
save_file_sizes = {}
save_file_sizes["merged_16bit"] = {}
save_file_sizes["merged_4bit"] = {}

tokenizer_files = [
    "tokenizer_config.json",
    "special_tokens_map.json",
]

@pytest.fixture(scope="session", params=model_to_test)
def loaded_model_tokenizer(request):
    model_name = request.param
    print("Loading model and tokenizer...")

    model, tokenizer = FastModel.from_pretrained(
        model_name,  # use small model
        max_seq_length=128,
        dtype=None,
        load_in_4bit=True,
    )

    # Apply LoRA
    model = FastModel.get_peft_model(
        model,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_alpha=16,
        use_gradient_checkpointing="unsloth",
    )

    return model, tokenizer

@pytest.fixture(scope="session")
def model(loaded_model_tokenizer):
    return loaded_model_tokenizer[0]

@pytest.fixture(scope="session")
def tokenizer(loaded_model_tokenizer):
    return loaded_model_tokenizer[1]

@pytest.fixture
def temp_save_dir():
    dir = tempfile.mkdtemp()
    print(f"Temporary directory created at: {dir}")
    yield dir
    print(f"Temporary directory deleted: {dir}")
    shutil.rmtree(dir)


def delete_quantization_config(model):
    # Since merged, edit quantization_config
    old_config = model.config
    new_config = model.config.to_dict()
    if "quantization_config" in new_config:
        del new_config["quantization_config"]
    original_model = model
    new_config = type(model.config).from_dict(new_config)
    while hasattr(original_model, "model"):
        original_model = original_model.model
        original_model.config = new_config
    model.config = new_config

def test_save_merged_16bit(model, tokenizer, temp_save_dir: str):
    save_path = os.path.join(temp_save_dir, "unsloth_merged_16bit", model.config._name_or_path.replace("/", "_"))

    model.save_pretrained_merged(
        save_path,
        tokenizer=tokenizer,
        save_method="merged_16bit",
    )

    # Check model files
    assert os.path.isdir(save_path), f"Directory {save_path} does not exist."
    assert os.path.isfile(os.path.join(save_path, "config.json")), "config.json not found."

    weight_files = [f for f in os.listdir(save_path) if f.endswith(".bin") or f.endswith(".safetensors")]
    assert len(weight_files) > 0, "No weight files found in the save directory."

    # Check tokenizer files
    for file in tokenizer_files:
        assert os.path.isfile(os.path.join(save_path, file)), f"{file} not found in the save directory."

    # Check config to see if it is 16bit by checking for quantization config
    config_path = os.path.join(save_path, "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    assert "quantization_config" not in config, "Quantization config unexpectedly found in the model config."

    # Store the size of the model files
    total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files)
    save_file_sizes["merged_16bit"][model.config._name_or_path] = total_size
    print(f"Total size of merged_16bit files: {total_size} bytes")

    # Test loading the model from the saved path
    loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
        save_path,
        max_seq_length=128,
        dtype=None,
        load_in_4bit=True,
    )

def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
    save_path = os.path.join(temp_save_dir, "unsloth_merged_4bit", model.config._name_or_path.replace("/", "_"))

    model.save_pretrained_merged(
        save_path,
        tokenizer=tokenizer,
        save_method="merged_4bit_forced",
    )

    # Check model files
    assert os.path.isdir(save_path), f"Directory {save_path} does not exist."
    assert os.path.isfile(os.path.join(save_path, "config.json")), "config.json not found."

    weight_files = [f for f in os.listdir(save_path) if f.endswith(".bin") or f.endswith(".safetensors")]
    assert len(weight_files) > 0, "No weight files found in the save directory."

    # Check tokenizer files
    for file in tokenizer_files:
        assert os.path.isfile(os.path.join(save_path, file)), f"{file} not found in the save directory."

    # Store the size of the model files
    total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files)
    save_file_sizes["merged_4bit"][model.config._name_or_path] = total_size

    print(f"Total size of merged_4bit files: {total_size} bytes")

    assert total_size < save_file_sizes["merged_16bit"][model.config._name_or_path], "Merged 4bit files are larger than merged 16bit files."

    # Check config to see if it is 4bit
    config_path = os.path.join(save_path, "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    assert "quantization_config" in config, "Quantization config not found in the model config."

    # Test loading the model from the saved path
    loaded_model, loaded_tokenizer = FastModel.from_pretrained(
        save_path,
        max_seq_length=128,
        dtype=None,
        load_in_4bit=True,
    )
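
For context, the flow these tests exercise can be sketched outside pytest as follows (a minimal sketch, not part of the diff; the output directory name is a placeholder, and the checkpoint is one of the 4bit VLM entries from the list above):

from unsloth import FastModel

# Load a 4bit vision checkpoint and attach LoRA adapters, as the session fixture does.
model, tokenizer = FastModel.from_pretrained(
    "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit",
    max_seq_length = 128,
    load_in_4bit = True,
)
model = FastModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_alpha = 16,
)

# "merged_4bit_forced" opts into the lossy 4bit merge; plain "merged_4bit"
# now raises a RuntimeError on this path (see unsloth/save.py below).
model.save_pretrained_merged(
    "qwen25_vl_merged_4bit",  # placeholder output directory
    tokenizer = tokenizer,
    save_method = "merged_4bit_forced",
)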

unsloth/save.py

Lines changed: 12 additions & 0 deletions
@@ -2301,6 +2301,17 @@ def unsloth_generic_save(
     maximum_memory_usage : float = 0.9,
 ):
     if token is None and push_to_hub: token = get_token()
+
+    if save_method == "merged_4bit":
+        raise RuntimeError(
+            "Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\n"\
+            "to merge to GGUF or others later on. I suggest you do this as a final step\n"\
+            "if you're planning to do multiple saves.\n"\
+            "If you are certain, change `save_method` to `merged_4bit_forced`."
+        )
+    elif save_method == "merged_4bit_forced":
+        save_method = "merged_4bit"
+
     merge_and_overwrite_lora(
         get_model_name,
         model = model,
@@ -2309,6 +2320,7 @@ def unsloth_generic_save(
         push_to_hub = push_to_hub,
         private = private,
         token = token,
+        save_method = save_method,
         output_dtype = None,
         low_disk_space_usage = True,
         use_temp_file = False,
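
The guard turns the lossy 4bit merge into an explicit opt-in on the generic (VLM) save path, and the merge method is now forwarded to merge_and_overwrite_lora. A caller-side sketch of the resulting behavior (the try/except wrapper is illustrative only, not part of this commit):

# Plain "merged_4bit" now fails loudly instead of silently losing accuracy.
try:
    model.save_pretrained_merged(save_path, tokenizer = tokenizer, save_method = "merged_4bit")
except RuntimeError as error:
    print(error)  # "Unsloth: Merging into 4bit will cause your model to lose accuracy ..."

# The explicit opt-in is rewritten back to "merged_4bit" internally and passed
# through the new save_method argument to merge_and_overwrite_lora.
model.save_pretrained_merged(save_path, tokenizer = tokenizer, save_method = "merged_4bit_forced")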
