18 changes: 18 additions & 0 deletions tests/saving/gpt-oss-merge/run_test.sh
@@ -0,0 +1,18 @@
#!/bin/bash
set -e
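# Abort immediately if either step below fails.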

echo "================================================================"
echo "🚀 STEP 1: Running the training and merging script..."
echo "================================================================"
python train_and_merge.py

echo ""
echo "================================================================"
echo "✅ STEP 2: Training complete. Running the inference script..."
echo "================================================================"
python test_merged_model.py

echo ""
echo "================================================================"
echo "🎉 All steps completed successfully!"
echo "================================================================"
55 changes: 55 additions & 0 deletions tests/saving/gpt-oss-merge/test_merged_model.py
@@ -0,0 +1,55 @@
# test_merged_model.py
from unsloth import FastLanguageModel
from transformers import TextStreamer
import torch
import gc
import os
import shutil

def safe_remove_directory(path):
    try:
        if os.path.exists(path) and os.path.isdir(path):
            shutil.rmtree(path)
            return True
        else:
            print(f"Path {path} is not a valid directory")
            return False
    except Exception as e:
        print(f"Failed to remove directory {path}: {e}")
        return False
pass

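# train_and_merge.py saved a merged 16-bit checkpoint; load_in_4bit=True below re-quantizes it
# at load time, which checks that the merged artifact round-trips like a normal model directory.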
print("🔥 Loading the 16-bit merged model from disk...")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
    model_name="./gpt-oss-finetuned-merged",
    max_seq_length=1024,
    load_in_4bit=True,
    load_in_8bit=False,
)
print("✅ Merged model loaded successfully.")

# --- Run Inference ---
print("\n🚀 Running inference...")
messages = [
    {"role": "user", "content": "Solve x^5 + 3x^4 - 10 = 3."},
]
inputs = merged_tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True,
    return_tensors = "pt",
    return_dict = True,
    reasoning_effort = "low",  # Set reasoning effort to low, medium or high
).to(merged_model.device)

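# TextStreamer prints the decoded tokens to stdout as they are generated.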
_ = merged_model.generate(**inputs, max_new_tokens = 512, streamer = TextStreamer(merged_tokenizer))
print("\n✅ Inference complete.")

# --- Final Cleanup ---
print("\n🧹 Cleaning up merged model directory and cache...")
del merged_model, merged_tokenizer
torch.cuda.empty_cache()
gc.collect()

safe_remove_directory("./gpt-oss-finetuned-merged")
safe_remove_directory("./unsloth_compiled_cache") # Clean up cache created by this process
print("✅ Final cleanup complete. Exiting inference script.")
71 changes: 71 additions & 0 deletions tests/saving/gpt-oss-merge/train_and_merge.py
@@ -0,0 +1,71 @@
# train_and_merge.py
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import torch
import gc
import os
import shutil

def safe_remove_directory(path):
    try:
        if os.path.exists(path) and os.path.isdir(path):
            shutil.rmtree(path)
            return True
        else:
            print(f"Path {path} is not a valid directory")
            return False
    except Exception as e:
        print(f"Failed to remove directory {path}: {e}")
        return False
pass

# Module-level placeholder; re-bound once the real tokenizer is loaded below.
# formatting_prompts_func reads this global at call time.
tokenizer = None
def formatting_prompts_func(examples):
    convos = examples["messages"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    return {"text": texts}

# --- Load 4-bit Model and Train ---
print("Loading 4-bit Mxfp4 gpt-oss model for training...")
max_seq_length = 1024
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/gpt-oss-20b", max_seq_length=max_seq_length, load_in_4bit=True
)

dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train[:50]").map(
    formatting_prompts_func, batched=True
)

model = FastLanguageModel.get_peft_model(
    model, r=8, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16, use_gradient_checkpointing="unsloth", random_state=3407,
)

trainer = SFTTrainer(
    model=model, tokenizer=tokenizer, train_dataset=dataset,
    args=SFTConfig(
        per_device_train_batch_size=1, gradient_accumulation_steps=4, max_steps=10,
        learning_rate=2e-4, output_dir="outputs", report_to="none",
    ),
)

print("Starting fine-tuning...")
trainer.train()
print("Fine-tuning complete.")

# --- Merge and Save ---
print("\n💾 Merging and saving the 16-bit model to './gpt-oss-finetuned-merged'...")
model.save_pretrained_merged(save_directory="./gpt-oss-finetuned-merged", tokenizer=tokenizer)
print("✅ Model merged and saved.")

# --- Cleanup ---
print("\n🧹 Cleaning up training artifacts...")
del model, trainer, tokenizer, dataset
torch.cuda.empty_cache()
gc.collect()

safe_remove_directory("./outputs")
safe_remove_directory("./unsloth_compiled_cache") # Clean up the cache created by this process
print("✅ Cleanup complete. Exiting training script.")
223 changes: 223 additions & 0 deletions tests/saving/language_models/test_merge_4bit_validation.py
@@ -0,0 +1,223 @@
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from trl import SFTTrainer, SFTConfig
from transformers import DataCollatorForSeq2Seq, TrainingArguments
from datasets import load_dataset
import torch
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).parents[3]
sys.path.insert(0, str(REPO_ROOT))

from tests.utils.cleanup_utils import safe_remove_directory

def formatting_prompts_func(examples):
    convos = examples["messages"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    return {"text": texts}

print(f"\n{'='*80}")
print("🔍 PHASE 1: Loading Base Model and Initial Training")
print(f"{'='*80}")

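# Prefer bfloat16 with FlashAttention-2 when the GPU supports bf16; otherwise fall back to float16 with SDPA.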
if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
    attn_implementation = 'flash_attention_2'
else:
    compute_dtype = torch.float16
    attn_implementation = 'sdpa'

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.1-8B-Instruct",
    max_seq_length=2048,
    dtype=compute_dtype,
    load_in_4bit=True,
    load_in_8bit=False,
    full_finetuning=False,
    attn_implementation=attn_implementation
)

tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3.1",
)

# Load small dataset for quick training
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train[:100]")
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)

print("✅ Base model loaded successfully!")

print(f"\n{'='*80}")
print("🔍 PHASE 2: First Fine-tuning")
print(f"{'='*80}")

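# Attach rank-16 LoRA adapters to the attention and MLP projection layers.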
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset_train,
    dataset_text_field="text",
    max_seq_length=2048,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        max_steps=10,  # Very short training for test
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=5,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        report_to="none",
    ),
)

trainer_stats = trainer.train()
print("✅ First fine-tuning completed!")

print(f"\n{'='*80}")
print("🔍 PHASE 3: Save with Forced 4bit Merge")
print(f"{'='*80}")

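# 'forced_merged_4bit' merges the adapters while keeping the base weights 4-bit quantized,
# so the saved model can be reloaded with load_in_4bit=True for a further round of training.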
model.save_pretrained_merged(
    save_directory='./test_4bit_model',
    tokenizer=tokenizer,
    save_method="forced_merged_4bit"
)

print("✅ Model saved with forced 4bit merge!")

print(f"\n{'='*80}")
print("🔍 PHASE 4: Loading 4bit Model and Second Fine-tuning")
print(f"{'='*80}")

# Clean up first model
del model
del tokenizer
torch.cuda.empty_cache()

# Load the 4bit merged model
model_4bit, tokenizer_4bit = FastLanguageModel.from_pretrained(
    model_name="./test_4bit_model",
    max_seq_length=2048,
    load_in_4bit=True,
    load_in_8bit=False,
)

tokenizer_4bit = get_chat_template(
    tokenizer_4bit,
    chat_template="llama-3.1",
)

print("✅ 4bit model loaded successfully!")

# Add LoRA adapters to the 4bit model
model_4bit = FastLanguageModel.get_peft_model(
    model_4bit,
    r=16,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

# Second fine-tuning
trainer_4bit = SFTTrainer(
    model=model_4bit,
    tokenizer=tokenizer_4bit,
    train_dataset=dataset_train,
    dataset_text_field="text",
    max_seq_length=2048,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer_4bit),
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        max_steps=10,  # Very short training for test
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=5,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs_4bit",
        report_to="none",
    ),
)

trainer_4bit.train()
print("✅ Second fine-tuning on 4bit model completed!")

print(f"\n{'='*80}")
print("🔍 PHASE 5: Testing TypeError on Regular Merge (Should Fail)")
print(f"{'='*80}")

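# A 4-bit quantized base has no 16-bit weights to merge into, so the default merge is expected
# to refuse with a TypeError.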
try:
    model_4bit.save_pretrained_merged(
        save_directory='./test_should_fail',
        tokenizer=tokenizer_4bit
        # No save_method specified, should default to regular merge
    )
    assert False, "Expected TypeError but merge succeeded!"
except TypeError as e:
    expected_error = "Base model should be a 16bits or mxfp4 base model for a 16bit model merge. Use `save_method=forced_merged_4bit` instead"
    assert expected_error in str(e), f"Unexpected error message: {str(e)}"
    print("✅ Correct TypeError raised for 4bit base model regular merge attempt!")
    print(f"Error message: {str(e)}")

print(f"\n{'='*80}")
print("🔍 PHASE 6: Successful Save with Forced 4bit Method")
print(f"{'='*80}")

try:
    model_4bit.save_pretrained_merged(
        save_directory='./test_4bit_second',
        tokenizer=tokenizer_4bit,
        save_method="forced_merged_4bit"
    )
    print("✅ Successfully saved 4bit model with forced 4bit method!")
except Exception as e:
    assert False, f"Phase 6 failed unexpectedly: {e}"

print(f"\n{'='*80}")
print("🔍 CLEANUP")
print(f"{'='*80}")

# Cleanup
safe_remove_directory("./outputs")
safe_remove_directory("./outputs_4bit")
safe_remove_directory("./unsloth_compiled_cache")
safe_remove_directory("./test_4bit_model")
safe_remove_directory("./test_4bit_second")
safe_remove_directory("./test_should_fail")

print("✅ All tests passed successfully!")