Commit 9050d6f

Tests for the mxfp4 and quantized-model merge fix (unsloth-zoo PR 254) (#3223)
Parent: b753ec0
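
For context, here is the behavior these tests pin down, as a minimal sketch (the `save_pretrained_merged` calls mirror those in the test files below; `model`, `tokenizer`, and the paths are illustrative placeholders, not part of the commit):

# Default merge writes a 16-bit checkpoint. It is valid for 16-bit or mxfp4
# base models and raises TypeError for a pre-quantized 4-bit base model.
model.save_pretrained_merged(save_directory="./merged-16bit", tokenizer=tokenizer)

# Explicit 4-bit merge: the supported path when the base model is 4-bit.
model.save_pretrained_merged(
    save_directory="./merged-4bit",
    tokenizer=tokenizer,
    save_method="forced_merged_4bit",
)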

File tree

4 files changed: +367 −0 lines changed
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
#!/bin/bash
set -e

echo "================================================================"
echo "🚀 STEP 1: Running the training and merging script..."
echo "================================================================"
python train_and_merge.py

echo ""
echo "================================================================"
echo "✅ STEP 2: Training complete. Running the inference script..."
echo "================================================================"
python test_merged_model.py

echo ""
echo "================================================================"
echo "🎉 All steps completed successfully!"
echo "================================================================"
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# inference_on_merged.py
from unsloth import FastLanguageModel
from transformers import TextStreamer
import torch
import gc
import os
import shutil

def safe_remove_directory(path):
    try:
        if os.path.exists(path) and os.path.isdir(path):
            shutil.rmtree(path)
            return True
        else:
            print(f"Path {path} is not a valid directory")
            return False
    except Exception as e:
        print(f"Failed to remove directory {path}: {e}")
        return False
    pass

print("🔥 Loading the 16-bit merged model from disk...")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
    model_name="./gpt-oss-finetuned-merged",
    max_seq_length=1024,
    load_in_4bit=True,
    load_in_8bit=False,
)
print("✅ Merged model loaded successfully.")

# --- Run Inference ---
print("\n🚀 Running inference...")
messages = [
    {"role": "user", "content": "Solve x^5 + 3x^4 - 10 = 3."},
]
inputs = merged_tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True,
    return_tensors = "pt",
    return_dict = True,
    reasoning_effort = "low",  # **NEW!** Set reasoning effort to low, medium, or high
).to(merged_model.device)

_ = merged_model.generate(**inputs, max_new_tokens = 512, streamer = TextStreamer(merged_tokenizer))
print("\n✅ Inference complete.")

# --- Final Cleanup ---
print("\n🧹 Cleaning up merged model directory and cache...")
del merged_model, merged_tokenizer
torch.cuda.empty_cache()
gc.collect()

safe_remove_directory("./gpt-oss-finetuned-merged")
safe_remove_directory("./unsloth_compiled_cache")  # Clean up cache created by this process
print("✅ Final cleanup complete. Exiting inference script.")
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# train_and_merge.py
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import torch
import gc
import os
import shutil

def safe_remove_directory(path):
    try:
        if os.path.exists(path) and os.path.isdir(path):
            shutil.rmtree(path)
            return True
        else:
            print(f"Path {path} is not a valid directory")
            return False
    except Exception as e:
        print(f"Failed to remove directory {path}: {e}")
        return False
    pass

# This tokenizer will be used by the mapping function
tokenizer = None
def formatting_prompts_func(examples):
    convos = examples["messages"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    return {"text": texts}

# --- Load 4-bit Model and Train ---
print("Loading 4-bit Mxfp4 gpt-oss model for training...")
max_seq_length = 1024
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/gpt-oss-20b", max_seq_length=max_seq_length, load_in_4bit=True
)

dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train[:50]").map(
    formatting_prompts_func, batched=True
)

model = FastLanguageModel.get_peft_model(
    model, r=8, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16, use_gradient_checkpointing="unsloth", random_state=3407,
)

trainer = SFTTrainer(
    model=model, tokenizer=tokenizer, train_dataset=dataset,
    args=SFTConfig(
        per_device_train_batch_size=1, gradient_accumulation_steps=4, max_steps=10,
        learning_rate=2e-4, output_dir="outputs", report_to="none",
    ),
)

print("Starting fine-tuning...")
trainer.train()
print("Fine-tuning complete.")

# --- Merge and Save ---
print("\n💾 Merging and saving the 16-bit model to './gpt-oss-finetuned-merged'...")
model.save_pretrained_merged(save_directory="./gpt-oss-finetuned-merged", tokenizer=tokenizer)
print("✅ Model merged and saved.")

# --- Cleanup ---
print("\n🧹 Cleaning up training artifacts...")
del model, trainer, tokenizer, dataset
torch.cuda.empty_cache()
gc.collect()

safe_remove_directory("./outputs")
safe_remove_directory("./unsloth_compiled_cache")  # Clean up the cache created by this process
print("✅ Cleanup complete. Exiting training script.")
Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from trl import SFTTrainer, SFTConfig
from transformers import DataCollatorForSeq2Seq, TrainingArguments
from datasets import load_dataset
import torch
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).parents[3]
sys.path.insert(0, str(REPO_ROOT))

from tests.utils.cleanup_utils import safe_remove_directory

def formatting_prompts_func(examples):
    convos = examples["messages"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    return {"text": texts}

print(f"\n{'='*80}")
print("🔍 PHASE 1: Loading Base Model and Initial Training")
print(f"{'='*80}")

if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
    attn_implementation = 'flash_attention_2'
else:
    compute_dtype = torch.float16
    attn_implementation = 'sdpa'

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.1-8B-Instruct",
    max_seq_length=2048,
    dtype=compute_dtype,
    load_in_4bit=True,
    load_in_8bit=False,
    full_finetuning=False,
    attn_implementation=attn_implementation,
)

tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3.1",
)

# Load small dataset for quick training
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train[:100]")
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)

print("✅ Base model loaded successfully!")

print(f"\n{'='*80}")
print("🔍 PHASE 2: First Fine-tuning")
print(f"{'='*80}")

model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset_train,
    dataset_text_field="text",
    max_seq_length=2048,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        max_steps=10,  # Very short training for test
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=5,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        report_to="none",
    ),
)

trainer_stats = trainer.train()
print("✅ First fine-tuning completed!")

print(f"\n{'='*80}")
print("🔍 PHASE 3: Save with Forced 4bit Merge")
print(f"{'='*80}")

model.save_pretrained_merged(
    save_directory='./test_4bit_model',
    tokenizer=tokenizer,
    save_method="forced_merged_4bit"
)

print("✅ Model saved with forced 4bit merge!")

print(f"\n{'='*80}")
print("🔍 PHASE 4: Loading 4bit Model and Second Fine-tuning")
print(f"{'='*80}")

# Clean up first model
del model
del tokenizer
torch.cuda.empty_cache()

# Load the 4bit merged model
model_4bit, tokenizer_4bit = FastLanguageModel.from_pretrained(
    model_name="./test_4bit_model",
    max_seq_length=2048,
    load_in_4bit=True,
    load_in_8bit=False,
)

tokenizer_4bit = get_chat_template(
    tokenizer_4bit,
    chat_template="llama-3.1",
)

print("✅ 4bit model loaded successfully!")

# Add LoRA adapters to the 4bit model
model_4bit = FastLanguageModel.get_peft_model(
    model_4bit,
    r=16,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

# Second fine-tuning
trainer_4bit = SFTTrainer(
    model=model_4bit,
    tokenizer=tokenizer_4bit,
    train_dataset=dataset_train,
    dataset_text_field="text",
    max_seq_length=2048,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer_4bit),
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        max_steps=10,  # Very short training for test
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=5,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs_4bit",
        report_to="none",
    ),
)

trainer_4bit.train()
print("✅ Second fine-tuning on 4bit model completed!")

print(f"\n{'='*80}")
print("🔍 PHASE 5: Testing TypeError on Regular Merge (Should Fail)")
print(f"{'='*80}")

try:
    model_4bit.save_pretrained_merged(
        save_directory='./test_should_fail',
        tokenizer=tokenizer_4bit
        # No save_method specified, should default to regular merge
    )
    assert False, "Expected TypeError but merge succeeded!"
except TypeError as e:
    expected_error = "Base model should be a 16bits or mxfp4 base model for a 16bit model merge. Use `save_method=forced_merged_4bit` instead"
    assert expected_error in str(e), f"Unexpected error message: {str(e)}"
    print("✅ Correct TypeError raised for 4bit base model regular merge attempt!")
    print(f"Error message: {str(e)}")

print(f"\n{'='*80}")
print("🔍 PHASE 6: Successful Save with Forced 4bit Method")
print(f"{'='*80}")

try:
    model_4bit.save_pretrained_merged(
        save_directory='./test_4bit_second',
        tokenizer=tokenizer_4bit,
        save_method="forced_merged_4bit"
    )
    print("✅ Successfully saved 4bit model with forced 4bit method!")
except Exception as e:
    assert False, f"Phase 6 failed unexpectedly: {e}"

print(f"\n{'='*80}")
print("🔍 CLEANUP")
print(f"{'='*80}")

# Cleanup
safe_remove_directory("./outputs")
safe_remove_directory("./outputs_4bit")
safe_remove_directory("./unsloth_compiled_cache")
safe_remove_directory("./test_4bit_model")
safe_remove_directory("./test_4bit_second")
safe_remove_directory("./test_should_fail")

print("✅ All tests passed successfully!")
