Commit b105c55

George and dsikka authored
[Test Fix] Add Quantization then finetune tests (#964)
~~Contingent on merge of huggingface/transformers#34719~~ ^ has been merged upstream but is not yet in a release.

SUMMARY:

Add a test that:
* Given a model, applies oneshot (PTQ) quantization, then runs training on it. The model must be loaded with `run_compressed = False` for training to run.

Note:
* When running finetune on an already optimized (one-shotted) model, the model needs to be decompressed explicitly using `CompressedTensorsConfig`. See https://github.com/vllm-project/llm-compressor/pull/964/files#diff-e480ed475c0a5b2beb4052c1dd2aca671999634ace41a5ea017fdff1ce68be0bR130-R135
* Tests passed using 2x H100s.

Also fixes a bug where `log_sparsification` fails because the layer name is not recognized. In this test nothing is being sparsified, so the parameter count is simply zero.

TEST PLAN:

Ran the tests against transformers `main`; tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py must pass.

---------

Co-authored-by: Dipika Sikka <[email protected]>
1 parent fb01d66 commit b105c55
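
The decompression step called out in the summary above, as a minimal standalone sketch. The checkpoint path is a placeholder; the API calls are the ones exercised by the test added in this commit.

```python
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

# run_compressed=False asks transformers to decompress the compressed-tensors
# checkpoint at load time, leaving plain trainable weights that finetuning can update.
quantization_config = CompressedTensorsConfig(run_compressed=False)

model = AutoModelForCausalLM.from_pretrained(
    "./finetune_output/oneshot_out",  # placeholder: path to a one-shotted checkpoint
    device_map="auto",
    quantization_config=quantization_config,
)
```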

File tree: 2 files changed, +93 / −24 lines changed


src/llmcompressor/pytorch/utils/sparsification.py

Lines changed: 8 additions & 9 deletions
```diff
@@ -105,15 +105,14 @@ def params_quantized(self) -> int:
         """
         :return: number of parameters across quantized layers
         """
-        return sum(
-            torch.numel(self.trainable_params[f"{name}.weight"])
-            + (
-                torch.numel(self.trainable_params[f"{name}.bias"])
-                if hasattr(layer, "bias") and layer.bias is not None
-                else 0
-            )
-            for (name, layer) in get_quantized_layers(self.module)
-        )
+        num_params = 0
+        for name, layer in get_quantized_layers(self.module):
+            if getattr(layer, "weight", None) is not None:
+                num_params += torch.numel(layer.weight)
+            if getattr(layer, "bias", None) is not None:
+                num_params += torch.numel(layer.bias)
+
+        return num_params
 
     @property
     def params_quantized_percent(self) -> float:
```
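
The old implementation looked each layer name up in `self.trainable_params`, so an unrecognized name raised a `KeyError`; the new version reads tensors off the layer itself and treats missing attributes as contributing zero parameters. A minimal sketch of the difference, using a toy layer list and an empty parameter dict (`toy_layers` and `toy_params` are illustrative names, not part of the library):

```python
import torch
from torch.nn import Linear

toy_layers = [("model.layers.0.q_proj", Linear(8, 8))]  # stand-in for get_quantized_layers(...)
toy_params = {}  # trainable-params dict that does not contain the layer's key

# Old behavior: dict lookup by reconstructed key -> KeyError when the name is unrecognized.
try:
    total = sum(torch.numel(toy_params[f"{name}.weight"]) for name, _ in toy_layers)
except KeyError:
    total = None  # this is the failure the commit fixes

# New behavior: count directly from the layer; absent weights/biases simply add zero.
num_params = 0
for _, layer in toy_layers:
    if getattr(layer, "weight", None) is not None:
        num_params += torch.numel(layer.weight)
    if getattr(layer, "bias", None) is not None:
        num_params += torch.numel(layer.bias)
```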

tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py

Lines changed: 85 additions & 15 deletions
```diff
@@ -1,28 +1,23 @@
-import os
 import shutil
 import unittest
 from pathlib import Path
 
 import pytest
+from transformers import AutoModelForCausalLM
+from transformers.utils.quantization_config import CompressedTensorsConfig
+
+from llmcompressor.core import create_session
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.transformers import oneshot, train
 
 
 @pytest.mark.unit
-@pytest.mark.skipif(
-    "CADENCE" in os.environ
-    and (os.environ["CADENCE"] == "weekly" or os.environ["CADENCE"] == "nightly"),
-    reason="Don't run for weekly and nightly tests as those use multi gpu "
-    "runners and this test fails when ngpu>1",
-)
 class TestOneshotThenFinetune(unittest.TestCase):
     def setUp(self):
         self.output = Path("./finetune_output")
+        self.quantization_config = CompressedTensorsConfig(run_compressed=False)
 
-    def test_oneshot_then_finetune(self):
-        from transformers import AutoModelForCausalLM
-
-        from llmcompressor.core import create_session
-        from llmcompressor.transformers import oneshot, train
-
+    def test_oneshot_sparsification_then_finetune(self):
         recipe_str = "tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml"
         model = AutoModelForCausalLM.from_pretrained(
             "Xenova/llama2.c-stories15M", device_map="auto"
@@ -47,8 +42,12 @@ def test_oneshot_then_finetune(self):
         recipe_str = (
             "tests/llmcompressor/transformers/finetune/test_finetune_recipe.yaml"
         )
+
+        # Explicitly decompress the model for training using quantization_config
         model = AutoModelForCausalLM.from_pretrained(
-            self.output / "oneshot_out", device_map="auto"
+            self.output / "oneshot_out",
+            device_map="auto",
+            quantization_config=self.quantization_config,
         )
         distill_teacher = AutoModelForCausalLM.from_pretrained(
             "Xenova/llama2.c-stories15M", device_map="auto"
@@ -73,7 +72,12 @@ def test_oneshot_then_finetune(self):
             )
 
         # test reloading checkpoint and final model
-        model = AutoModelForCausalLM.from_pretrained(output_dir, device_map="auto")
+        # verify checkpoint reloading and can carry out finetune
+        # with the saved model
+        # Explicitly decompress the model for training using quantization_config
+        model = AutoModelForCausalLM.from_pretrained(
+            output_dir, device_map="auto", quantization_config=self.quantization_config
+        )
         with create_session():
             train(
                 model=model,
@@ -88,5 +92,71 @@ def test_oneshot_then_finetune(self):
                 resume_from_checkpoint=True,  # use last checkpoint
             )
 
+    def test_oneshot_quantization_then_finetune(self):
+        recipe = QuantizationModifier(
+            targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+            device_map="auto",
+        )
+        dataset = "open_platypus"
+        concatenate_data = False
+        num_calibration_samples = 64
+        output_dir = self.output / "oneshot_out"
+        splits = {"calibration": "train[:10%]"}
+
+        with create_session():
+            oneshot(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+            )
+
+        from transformers.utils.quantization_config import CompressedTensorsConfig
+
+        quantization_config = CompressedTensorsConfig(run_compressed=False)
+        model = AutoModelForCausalLM.from_pretrained(
+            output_dir,
+            device_map="auto",
+            quantization_config=quantization_config,
+        )
+        dataset = "open_platypus"
+        concatenate_data = False
+        output_dir = self.output / "finetune_out"
+        splits = {"calibration": "train[:10%]", "train": "train[:10%]"}
+
+        with create_session():
+            train(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+            )
+
+        # test reloading checkpoint and final model
+        model = AutoModelForCausalLM.from_pretrained(
+            output_dir, device_map="auto", quantization_config=quantization_config
+        )
+        with create_session():
+            train(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+                resume_from_checkpoint=True,  # use last checkpoint
+            )
+
     def tearDown(self):
         shutil.rmtree(self.output)
```
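
A condensed, standalone sketch of the flow the new `test_oneshot_quantization_then_finetune` exercises. Model, dataset, and arguments are copied from the test; running it outside the unittest harness is an assumption, so treat it as illustrative rather than a supported recipe.

```python
from pathlib import Path

from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

from llmcompressor.core import create_session
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot, train

output = Path("./finetune_output")
recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# 1) One-shot (PTQ) quantize the dense model and save a compressed checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto"
)
with create_session():
    oneshot(
        model=model,
        dataset="open_platypus",
        output_dir=output / "oneshot_out",
        num_calibration_samples=64,
        recipe=recipe,
        concatenate_data=False,
        splits={"calibration": "train[:10%]"},
    )

# 2) Reload the checkpoint decompressed (run_compressed=False) and finetune it.
model = AutoModelForCausalLM.from_pretrained(
    output / "oneshot_out",
    device_map="auto",
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)
with create_session():
    train(
        model=model,
        dataset="open_platypus",
        output_dir=output / "finetune_out",
        num_calibration_samples=64,
        recipe=recipe,
        concatenate_data=False,
        splits={"calibration": "train[:10%]", "train": "train[:10%]"},
    )
```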
