Commit 56f2087

Fix QAT + LoRA fast path, add tests
**Summary:** The existing QAT + LoRA path only applied fake quantization on the original slow path, but the default is the fast path, which calls unsloth's fast LoRA primitives. This commit integrates fake quantization into these fast primitives as well, and adds unit tests asserting that fake quantization actually takes place.

**Test Plan:**

Unit tests:
```
pytest tests/utils/test_qat.py
```

End-to-end test: https://gist.github.com/andrewor14/6360dd69b5784c71c46e80c14f53e6b6

Fine-tuned Llama3.1-8B with LoRA, with and without QAT, on yahma/alpaca-cleaned for 1 epoch:
- Batch size = 8 (no grad accum)
- Learning rate = 2e-4
- Quantization scheme = int4 weight only (with bf16 activations)

Wikitext perplexity:
- Baseline = int4 quantized model fine-tuned without QAT
- QAT int4 quantized model (with this PR) closed ~33% of the perplexity gap between the int4 baseline and the float model
- QAT int4 quantized model without this PR was worse than the int4 baseline

```
==> unsloth_model_lora_baseline_output/lm_eval_float.log <==
| | |none | 0|word_perplexity|↓ |7.5551|± | N/A|

==> unsloth_model_lora_baseline_output/lm_eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |8.7655|± | N/A|

==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |8.3548|± | N/A|
```
1 parent f06200d commit 56f2087
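
For context, a minimal sketch of how the QAT + LoRA path fixed by this commit is enabled from user code. It mirrors the `_get_model` helper in the new `tests/utils/test_qat.py`; the model name and scheme strings come from that test, other LoRA arguments are omitted, and it should be read as an illustration rather than canonical API documentation:

```python
# Minimal sketch, mirroring tests/utils/test_qat.py: load a model with unsloth,
# then attach LoRA adapters with a QAT scheme so fake quantization is applied
# on the (default) fast LoRA path fixed by this commit.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-1.7B",   # small model used by the new unit tests
    load_in_4bit = False,
    full_finetuning = False,             # LoRA path
)
model = FastLanguageModel.get_peft_model(
    model,
    qat_scheme = "fp8-int4",             # or "fp8-fp8"; schemes exercised by the tests
    # other LoRA arguments (r, target_modules, ...) omitted for brevity
)

inputs = tokenizer("How are you?", return_tensors="pt").to("cuda")
outputs = model(**inputs)                # forward pass now runs through fake quantizers
```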

File tree

4 files changed: +208 −2 lines

tests/utils/test_qat.py

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
from unsloth import FastLanguageModel

from typing import Dict

import pytest
import torch
from torchao.quantization.qat import FakeQuantizedLinear
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizerBase,
    Float8FakeQuantizer,
    Int4WeightPreshuffledFakeQuantizer,
)


class _CountingFakeQuantizer(torch.nn.Module):
    """
    Dummy fake quantizer that counts the number of times it has been called.
    """
    def __init__(self):
        super().__init__()
        self.count = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.count += 1
        return x


def _get_model(qat_scheme: str, full_finetuning: bool):
    """
    Return a 2-tuple of (model, tokenizer), where the model has been configured
    to use QAT. If `full_finetuning` is False, return the PEFT (LoRA) model.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/Qwen3-1.7B",
        load_in_4bit = False,
        full_finetuning = full_finetuning,
        qat_scheme = qat_scheme if full_finetuning else None,
    )
    if not full_finetuning:
        model = FastLanguageModel.get_peft_model(
            model,
            qat_scheme = qat_scheme,
        )
    return model, tokenizer


def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
    """
    Verify that the given linear contains fake quantizers according to the `qat_scheme`.
    """
    if qat_scheme == "fp8-int4":
        act_fq_class = Float8FakeQuantizer
        weight_fq_class = Int4WeightPreshuffledFakeQuantizer
        min_in_features = 128
    elif qat_scheme == "fp8-fp8":
        act_fq_class = Float8FakeQuantizer
        weight_fq_class = Float8FakeQuantizer
        min_in_features = -1
    else:
        raise ValueError(f"Unknown qat_scheme: {qat_scheme}")

    # Check base layer activations and weights
    base_layer = getattr(linear, "base_layer", linear)
    if base_layer.in_features >= min_in_features:
        assert isinstance(base_layer, FakeQuantizedLinear)
        assert isinstance(base_layer.activation_fake_quantizer, act_fq_class)
        assert isinstance(base_layer.weight_fake_quantizer, weight_fq_class)

    # Check lora A and B (only for full_finetuning=False)
    if hasattr(linear, "lora_A") and hasattr(linear, "lora_B"):
        lora_A = linear.lora_A.default
        lora_B = linear.lora_B.default
        if lora_A.in_features >= min_in_features:
            assert isinstance(lora_A, FakeQuantizedLinear)
            assert isinstance(lora_A.activation_fake_quantizer, act_fq_class)
            assert isinstance(lora_A.weight_fake_quantizer, weight_fq_class)
        if lora_B.in_features >= min_in_features:
            assert isinstance(lora_B, FakeQuantizedLinear)
            assert isinstance(lora_B.activation_fake_quantizer, act_fq_class)
            assert isinstance(lora_B.weight_fake_quantizer, weight_fq_class)


def _test_fake_quantizers_are_called(
    model: torch.nn.Module,
    example_inputs: Dict,
    full_finetuning: bool,
):
    """
    Verify that the fake quantizers are actually called when the model is called.
    """
    def _swap_fake_quantizers(model: torch.nn.Module):
        for name, child in model.named_children():
            if isinstance(child, FakeQuantizerBase):
                setattr(model, name, _CountingFakeQuantizer())

    def _assert_fake_quantizers_are_called(model: torch.nn.Module):
        for name, child in model.named_children():
            if full_finetuning:
                if isinstance(child, FakeQuantizedLinear):
                    assert child.activation_fake_quantizer.count == 1
                    assert child.weight_fake_quantizer.count == 1
            else:
                # For LoRA, we only fake quantize the input activations once per block:
                # For self_attn, we only fake quantize the q_proj's input activations
                # For mlp, we only fake quantize the gate_proj's input activations
                if name == "self_attn":
                    base_layer = child.q_proj.base_layer
                    assert hasattr(base_layer, "activation_fake_quantizer")
                    assert base_layer.activation_fake_quantizer.count == 1
                elif name == "mlp":
                    base_layer = child.gate_proj.base_layer
                    assert hasattr(base_layer, "activation_fake_quantizer")
                    assert base_layer.activation_fake_quantizer.count == 1
                elif isinstance(child, FakeQuantizedLinear):
                    # Weight fake quantizers should always be called
                    assert child.weight_fake_quantizer.count == 1

    for k, v in example_inputs.items():
        example_inputs[k] = v.cuda()
    model.apply(_swap_fake_quantizers)
    model(**example_inputs)
    model.apply(_assert_fake_quantizers_are_called)


def _test_model_fake_quantize(qat_scheme: str, full_finetuning: bool):
    """
    Test that all linear layers in the model are fake quantized according to the `qat_scheme`.
    """
    model, tokenizer = _get_model(qat_scheme, full_finetuning)
    if full_finetuning:
        model = model.model
    else:
        model = model.base_model.model.model
    for layer in model.layers:
        _test_linear_is_fake_quantized(layer.self_attn.q_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.self_attn.k_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.self_attn.v_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.gate_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.up_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.down_proj, qat_scheme)
    inputs = tokenizer("How are you?", return_tensors="pt")
    _test_fake_quantizers_are_called(model, inputs, full_finetuning)


# TODO: there are bad interactions across tests right now, need to figure out
# how to disable model caching before re-enabling this test
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
def _test_full_model_fake_quantize(qat_scheme: str):
    _test_model_fake_quantize(qat_scheme, full_finetuning=True)


@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
def test_lora_model_fake_quantize(qat_scheme: str):
    _test_model_fake_quantize(qat_scheme, full_finetuning=False)

unsloth/kernels/fast_lora.py

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 import torch
 from .utils import (
+    _maybe_fake_quantize_activations,
     fast_dequantize,
     QUANT_STATE,
     get_lora_parameters,
@@ -175,6 +176,7 @@ def backward(ctx, dY : torch.Tensor):
 
 from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
 def apply_lora_mlp_swiglu(self, X, inplace = True):
+    X = _maybe_fake_quantize_activations(X, self.gate_proj)
     gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
     upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
     downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
@@ -190,6 +192,7 @@ def apply_lora_mlp_swiglu(self, X, inplace = True):
 
 from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
 def apply_lora_mlp_geglu_exact(self, X, inplace = True):
+    X = _maybe_fake_quantize_activations(X, self.gate_proj)
     gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
     upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
     downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
@@ -205,6 +208,7 @@ def apply_lora_mlp_geglu_exact(self, X, inplace = True):
 
 from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
 def apply_lora_mlp_geglu_approx(self, X):
+    X = _maybe_fake_quantize_activations(X, self.gate_proj)
     gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
     upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
     downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
@@ -360,6 +364,7 @@ def backward(ctx, dQ, dK, dV):
 
 
 def apply_lora_qkv(self, X, inplace = True):
+    X = _maybe_fake_quantize_activations(X, self.q_proj)
     QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
     KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
     VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
@@ -453,6 +458,7 @@ def backward(ctx, dY : torch.Tensor):
 
 
 def apply_lora_o(self, X):
+    X = _maybe_fake_quantize_activations(X, self.o_proj)
     OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
     O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
     return O

unsloth/kernels/utils.py

Lines changed: 38 additions & 2 deletions
@@ -188,10 +188,19 @@ def cgemm_4bit_inference_naive_bf16(*args, **kwargs):
 def QUANT_STATE(W): return getattr(W, "quant_state", None)
 
 def get_lora_parameters(proj):
+    """
+    Return a 5-tuple of (weight, weight quant_state, lora A, lora B, and lora scale).
+    If QAT is enabled, additionally fake quantize the base layer and lora weights.
+    """
     # For DPO or disabled adapters
     base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
     W = base_layer.weight
 
+    # Optionally apply fake quantization to base layer weights for QAT
+    weight_fake_quantizer = getattr(base_layer, "weight_fake_quantizer", None)
+    if weight_fake_quantizer is not None:
+        W = weight_fake_quantizer(W)
+
     # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
     if getattr(proj, "disable_adapters", True) or proj.merged:
         return W, getattr(W, "quant_state", None), None, None, None
@@ -201,11 +210,23 @@ def get_lora_parameters(proj):
     if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
     adapter = adapter[0]
 
+    # Optionally apply fake quantization to lora weights for QAT
+    lora_A_linear = proj.lora_A[adapter]
+    lora_B_linear = proj.lora_B[adapter]
+    lora_A_fake_quantizer = getattr(lora_A_linear, "weight_fake_quantizer", None)
+    lora_B_fake_quantizer = getattr(lora_B_linear, "weight_fake_quantizer", None)
+    A = lora_A_linear.weight
+    B = lora_B_linear.weight
+    if lora_A_fake_quantizer is not None:
+        A = lora_A_fake_quantizer(A)
+    if lora_B_fake_quantizer is not None:
+        B = lora_B_fake_quantizer(B)
+
     return (
         W,
         getattr(W, "quant_state", None),
-        proj.lora_A [adapter].weight,
-        proj.lora_B [adapter].weight,
+        A,
+        B,
         proj.scaling[adapter],
     )
 pass
@@ -235,6 +256,21 @@ def get_lora_parameters_bias(proj):
     )
 pass
 
+
+def _maybe_fake_quantize_activations(X: torch.Tensor, proj: torch.nn.Module) -> torch.Tensor:
+    """
+    If QAT is enabled, fake quantize the input activations.
+    Otherwise, just return the input activations as is.
+    Weights are fake quantized separately in `get_lora_parameters`.
+    """
+    base_layer = getattr(proj, "base_layer", proj)
+    activation_fake_quantizer = getattr(base_layer, "activation_fake_quantizer", None)
+    if activation_fake_quantizer is not None:
+        X = activation_fake_quantizer(X)
+    return X
+pass
+
+
 # INTEL GPU Specific Logic
 if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
     @torch.inference_mode
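
Taken together, the `fast_lora.py` and `utils.py` changes mean the fast path fake quantizes the input activations once per fused projection group (via the q_proj / gate_proj base layer) and fake quantizes the base and LoRA weights inside `get_lora_parameters`. Below is a minimal PyTorch sketch of the resulting math, ignoring the fused Triton kernels, biases, and 4-bit dequantization; the function name and the "default" adapter key are illustrative, not part of the actual kernel code:

```python
import torch

def lora_forward_with_fake_quant(X, proj):
    """Conceptual sketch of what one fast-LoRA projection computes under QAT.

    `proj` is assumed to be a PEFT-style module that may carry
    `activation_fake_quantizer` / `weight_fake_quantizer` attributes
    (as installed by torchao's FakeQuantizedLinear); all names are illustrative.
    """
    base = getattr(proj, "base_layer", proj)

    # Activations are fake quantized once, before the base and LoRA matmuls
    act_fq = getattr(base, "activation_fake_quantizer", None)
    if act_fq is not None:
        X = act_fq(X)

    # The base weight is fake quantized inside get_lora_parameters
    W = base.weight
    weight_fq = getattr(base, "weight_fake_quantizer", None)
    if weight_fq is not None:
        W = weight_fq(W)

    Y = X @ W.t()

    # LoRA A/B weights are also fake quantized when wrapped as FakeQuantizedLinear
    if hasattr(proj, "lora_A"):
        lora_A, lora_B = proj.lora_A["default"], proj.lora_B["default"]
        A, B, s = lora_A.weight, lora_B.weight, proj.scaling["default"]
        a_fq = getattr(lora_A, "weight_fake_quantizer", None)
        b_fq = getattr(lora_B, "weight_fake_quantizer", None)
        if a_fq is not None:
            A = a_fq(A)
        if b_fq is not None:
            B = b_fq(B)
        Y = Y + s * ((X @ A.t()) @ B.t())
    return Y
```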

unsloth/models/_utils.py

Lines changed: 10 additions & 0 deletions
@@ -1544,6 +1544,8 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: str) -> torch.nn.
     from torchao.quantization import (
         Float8DynamicActivationInt4WeightConfig,
         Float8DynamicActivationFloat8WeightConfig,
+        Int8DynamicActivationInt4WeightConfig,
+        Int4WeightOnlyConfig,
         PerRow,
         quantize_,
     )
@@ -1555,6 +1557,14 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: str) -> torch.nn.
         filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
     elif qat_scheme == "fp8-fp8":
         base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+    elif qat_scheme == "int8-int4":
+        group_size = 32
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
+    elif qat_scheme == "int4":
+        group_size = 128
+        base_config = Int4WeightOnlyConfig(group_size=group_size)
+        filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
     else:
         raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
     pass
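
The `_utils.py` change extends `_prepare_model_for_qat` with two more schemes, `int8-int4` and `int4` (the latter matching the int4 weight-only scheme used in the end-to-end eval above). A condensed, illustrative view of the resulting scheme-to-config dispatch, reproducing only the branches visible in this diff; the `fp8-int4` branch and the downstream QAT preparation and `quantize_` call are omitted, and the helper name is hypothetical:

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    PerRow,
)

def _base_config_for_scheme(qat_scheme: str):
    # Returns (base_config, filter_fn); filter_fn=None means "all linear layers".
    filter_fn = None
    if qat_scheme == "fp8-fp8":
        base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    elif qat_scheme == "int8-int4":
        group_size = 32
        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
        filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
    elif qat_scheme == "int4":
        group_size = 128
        base_config = Int4WeightOnlyConfig(group_size=group_size)
        filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
    else:
        raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
    return base_config, filter_fn
```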
