
Commit 4f20e54

Revert "Revert "Add Qwen2.5-VL-32B-Instruct mapping to fix quantized model me…" (#2990)
This reverts commit 204fc46.
1 parent 2682c9b commit 4f20e54

File tree

2 files changed: +260 -0 lines changed

Lines changed: 255 additions & 0 deletions
@@ -0,0 +1,255 @@
# -*- coding: utf-8 -*-

from unsloth import FastVisionModel

import torch
from qwen_vl_utils import process_vision_info
import os
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

import sys
from pathlib import Path


REPO_ROOT = Path(__file__).parents[3]
sys.path.insert(0, str(REPO_ROOT))

from tests.utils.cleanup_utils import safe_remove_directory
from tests.utils.ocr_eval import OCRModelEvaluator


## Dataset Preparation
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", 'en', split="train")

# Select the first 2000 examples for training
train_dataset = dataset.select(range(2000))

# Select the next 200 examples for evaluation
eval_dataset = dataset.select(range(2000, 2200))

system_message = "You are an expert French OCR system."

# Convert dataset samples to OpenAI-style messages
def format_data(sample):
    return {"messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": sample["question"],
                    },
                    {
                        "type": "image",
                        "image": sample["image"],
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["answer"]}],
            },
        ],
    }

# List comprehensions (not .map) are needed here to keep the PIL.Image type;
# .map would convert the images to bytes.
train_dataset = [format_data(sample) for sample in train_dataset]
eval_dataset = [format_data(sample) for sample in eval_dataset]
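
# Illustrative sanity check (not part of the original test): each converted
# sample is a dict whose "messages" hold system/user/assistant turns, with
# the user turn carrying the raw PIL image the vision collator expects.
assert train_dataset[0]["messages"][0]["role"] == "system"
assert train_dataset[0]["messages"][1]["content"][1]["type"] == "image"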

## Setup OCR main evaluation function and helpers
from tqdm import tqdm
import pandas as pd
from jiwer import wer, cer

ocr_evaluator = OCRModelEvaluator()
model_comparison_results = {}
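
# Quick illustrative look at the metrics jiwer provides (the evaluator
# presumably aggregates these over the eval set); a sketch only, not the
# OCRModelEvaluator internals.
_ref, _hyp = "Bonjour le monde", "Bonjour la monde"
print(f"example WER={wer(_ref, _hyp):.2f}, CER={cer(_ref, _hyp):.2f}")  # WER=0.33, CER=0.06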

## Finetuning Setup and Run
# Load Base Model
model, tokenizer = FastVisionModel.from_pretrained(
    model_name = "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit",
    max_seq_length = 2048,   # Choose any for long context!
    load_in_4bit = True,     # 4-bit quantization to reduce memory
    load_in_8bit = False,    # [NEW!] A bit more accurate, uses 2x memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
)
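
# Rough memory arithmetic (illustrative): 32B parameters at 4 bits/param is
# about 32e9 * 0.5 bytes ~= 16 GB of weights, versus ~64 GB in fp16, which
# is why load_in_4bit=True is used at this model size.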

# Benchmark base model performance
model_name = "Unsloth Base model"
FastVisionModel.for_inference(model)
avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_base_model_results")
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)

## LoRA Finetuning
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True, # Turn off for just text!
    finetune_language_layers   = True, # Should leave on!
    finetune_attention_modules = True, # Attention good for GRPO
    finetune_mlp_modules       = True, # Should always leave on!

    r = 16,           # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    #target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
    #                  "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 32,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)
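
# LoRA background: each adapted weight becomes W + (lora_alpha / r) * B @ A,
# so with r=16 and lora_alpha=32 the update is scaled by 32/16 = 2.0. A quick
# illustrative check of how few parameters are actually trainable:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")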

from unsloth import is_bf16_supported
from unsloth.trainer import UnslothVisionDataCollator
FastVisionModel.for_training(model) # Enable for training!
model.config.use_cache = False


trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    data_collator = UnslothVisionDataCollator(model, tokenizer),
    train_dataset = train_dataset,
    args = SFTConfig(
        #per_device_train_batch_size = 4,
        #gradient_accumulation_steps = 8,
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        gradient_checkpointing = True,
        gradient_checkpointing_kwargs = {"use_reentrant": False}, # use non-reentrant checkpointing
        max_grad_norm = 0.3, # max gradient norm, from the QLoRA paper
        warmup_ratio = 0.03,
        #num_train_epochs = 2, # Set this instead of max_steps for full training runs
        max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not is_bf16_supported(),
        bf16 = is_bf16_supported(),
        logging_steps = 5,
        save_strategy = "epoch",
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "unsloth-qwen2.5-vl-32b-french-ocr-checkpoints",
        report_to = "none", # Set to "wandb" for Weights and Biases

        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        dataset_num_proc = 4,
        max_seq_length = 2048,
    ),
)
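
# Effective batch size = per_device_train_batch_size (2) x
# gradient_accumulation_steps (4) = 8 examples per optimizer step, so
# max_steps=60 touches roughly 60 * 8 = 480 of the 2000 training samples.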

# Run training
trainer_stats = trainer.train()

model.save_pretrained("unsloth-qwen2.5-vl-32b-french-ocr-adapter")
tokenizer.save_pretrained("unsloth-qwen2.5-vl-32b-french-ocr-adapter")

## Measure Adapter Performance

# Benchmark LoRA adapter model performance
model_name = "Unsloth lora adapter model"
FastVisionModel.for_inference(model)
avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_lora_model_results")
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
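
# Minimal single-sample inference sketch (illustrative; the evaluator wraps
# this same pattern). process_vision_info pulls the PIL images out of the
# OpenAI-style messages for the processor.
_messages = eval_dataset[0]["messages"][:2]  # system + user turns only
_text = tokenizer.apply_chat_template(_messages, tokenize=False, add_generation_prompt=True)
_images, _ = process_vision_info(_messages)
_inputs = tokenizer(text=[_text], images=_images, return_tensors="pt").to(model.device)
_out = model.generate(**_inputs, max_new_tokens=128, use_cache=True)
print(tokenizer.batch_decode(_out[:, _inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0])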

## Merge Model

def find_lora_base_model(model_to_inspect):
    current = model_to_inspect
    if hasattr(current, "base_model"):
        current = current.base_model
    if hasattr(current, "model"):
        current = current.model
    return current
pass

base = find_lora_base_model(model)

print(base.__class__.__name__)

# Merge to the default 16 bits
model.save_pretrained_merged(save_directory="qwen2.5-ocr-merged-finetune-merge-16bit", tokenizer=tokenizer)


## Benchmark merged model performance

### 16-bit merged model

model, tokenizer = FastVisionModel.from_pretrained("./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit=False, load_in_8bit=False)

# Benchmark 16-bit loaded, 16-bit merged model performance
model_name = "Unsloth 16bits-merged model load-16bits"
model.config.use_cache = True

avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_16bits_results")
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)

# Load the 16-bit merged model in 4 bits
model, tokenizer = FastVisionModel.from_pretrained("./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit=True, load_in_8bit=False)

# Benchmark 4-bit loaded, 16-bit merged model performance
model_name = "Unsloth 16bits-merged model load-4bits"
model.config.use_cache = True

avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_4bits_results")
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)

# Load the 16-bit merged model in 8 bits
model, tokenizer = FastVisionModel.from_pretrained("./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit=False, load_in_8bit=True)

# Benchmark 8-bit loaded, 16-bit merged model performance
model_name = "Unsloth 16bits-merged model load-8bits"
avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_8bits_results")
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)

# """### 4-bit merged model"""
#
# # Load the 4-bit merged model in 4 bits
# model, tokenizer = FastVisionModel.from_pretrained("./qwen2-ocr-merged-finetune-merge-4bit", load_in_4bit=True, load_in_8bit=False)
#
# # Benchmark 4-bit loaded, 4-bit merged model performance
# model_name = "Unsloth 4bits-merged model load-4bits"
#
# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_4bits_merged_model_load_4bits_results")
# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
#
# # Load the 4-bit merged model in 8 bits
# model, tokenizer = FastVisionModel.from_pretrained("./qwen2-ocr-merged-finetune-merge-4bit", load_in_4bit=False, load_in_8bit=True)
#
# # Benchmark 8-bit loaded, 4-bit merged model performance
# model_name = "Unsloth 4bits-merged model load-8bits"
#
# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_4bits_merged_model_load_8bits_results")
# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)

# Model comparison report
ocr_evaluator.print_model_comparison()


# Final cleanup
print("\n🧹 Cleaning up temporary files...")
safe_remove_directory("./unsloth-qwen2.5-vl-32b-french-ocr-adapter")
safe_remove_directory("./unsloth-qwen2.5-vl-32b-french-ocr-checkpoints")
safe_remove_directory("./unsloth_compiled_cache")
safe_remove_directory("./qwen2.5-ocr-merged-finetune-merge-16bit")

print("\n🎯 Pipeline completed successfully!")
print("=" * 80)

unsloth/models/mapper.py

Lines changed: 5 additions & 0 deletions
@@ -618,6 +618,11 @@
         "Qwen/Qwen2.5-VL-7B-Instruct",
         "unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit",
     ),
+    "unsloth/Qwen2.5-VL-32B-Instruct-unsloth-bnb-4bit" : (
+        "unsloth/Qwen2.5-VL-32B-Instruct",
+        "Qwen/Qwen2.5-VL-32B-Instruct",
+        "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit",
+    ),
     "unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit" : (
         "unsloth/Qwen2.5-VL-72B-Instruct",
         "Qwen/Qwen2.5-VL-72B-Instruct",
