
Commit 00c64bc

George authored and kylesayrs committed
[Test Fix] Fix/update test_run_compressed (#970)
~~Contingent on merge of huggingface/transformers#34719~~ ^ has been merged, but not yet released.

SUMMARY: Update the run_compressed tests from decompression tests to run_compressed tests -> test that run_compressed=True and run_compressed=False models generate the same output. Add decompression tests that copy attrs from the source dir path's model to the target model.

TEST PLAN: Ran the tests using transformers main; tests/llmcompressor/transformers/compression/test_decompress.py and tests/llmcompressor/transformers/compression/test_run_compressed.py must pass.

Signed-off-by: Kyle Sayers <[email protected]>
1 parent ab1b144 commit 00c64bc
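
In other words, the updated run_compressed tests load the same compressed checkpoint twice, once with run_compressed=True (executing compressed kernels) and once with run_compressed=False (decompressing on load), and assert that both produce identical generations. A minimal sketch of that check, not the test file itself (the stub is one of the run_compressed configs changed below; the prompt and max_length are illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils.quantization_config import CompressedTensorsConfig

stub = "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-compressed"
tokenizer = AutoTokenizer.from_pretrained(stub)

# Same checkpoint, two execution paths
run_compressed = AutoModelForCausalLM.from_pretrained(
    stub, torch_dtype="auto", device_map="auto",
    quantization_config=CompressedTensorsConfig(run_compressed=True),
)
decompressed = AutoModelForCausalLM.from_pretrained(
    stub, torch_dtype="auto", device_map="auto",
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)

# Assumes both models land on the same device under device_map="auto"
inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(run_compressed.device)
assert tokenizer.batch_decode(run_compressed.generate(**inputs, max_length=50)) == \
    tokenizer.batch_decode(decompressed.generate(**inputs, max_length=50))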

File tree

11 files changed: +299 −45

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+cadence: "commit"
+test_type: "regression"
+compressed_model_stub: "nm-testing/tinyllama-fp8-dynamic-compressed"
+skeleton_model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
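
Each decompression config pairs a compressed checkpoint with the dense base model ("skeleton") it is decompressed into; the test added in this commit consumes these files via parse_params(CONFIG_DIR) and parameterized_class. As plain YAML each config resolves to two strings, e.g. (a sketch only: the file path below is illustrative, since the diff does not show filenames, and the repo's actual loader is tests.testing_utils.parse_params):

import yaml

# Illustrative path; actual filenames are not shown in this diff
with open("decompression_configs/fp8_dynamic.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["compressed_model_stub"])  # checkpoint whose weights get decompressed
print(cfg["skeleton_model_stub"])    # dense base model the weights are written into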
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+cadence: "commit"
+test_type: "regression"
+compressed_model_stub: "nm-testing/tinyllama-w4a16-compressed"
+skeleton_model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+cadence: "commit"
+test_type: "regression"
+compressed_model_stub: "nm-testing/tinyllama-w8a16-dense"
+skeleton_model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+cadence: "commit"
+test_type: "regression"
+compressed_model_stub: "nm-testing/tinyllama-w8a8-compressed"
+skeleton_model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 cadence: "commit"
 test_type: "regression"
-model_stub: "nm-testing/tinyllama-fp8-dynamic-compressed"
-empty_model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-Dynamic-compressed
+uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-Dynamic-uncompressed
Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 cadence: "commit"
 test_type: "regression"
-model_stub: "nm-testing/tinyllama-w4a16-compressed"
-empty_model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-compressed
+uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-uncompressed
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+cadence: "commit"
+test_type: "regression"
+compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A16-G128-compressed
+uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A16-G128-uncompressed

tests/llmcompressor/transformers/compression/run_compressed_configs/w8a16_dense.yaml

Lines changed: 0 additions & 4 deletions
This file was deleted.
Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 cadence: "commit"
 test_type: "regression"
-model_stub: "nm-testing/tinyllama-w8a8-compressed"
-empty_model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A8-Dynamic-Per-Token-compressed
+uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A8-Dynamic-Per-Token-uncompressed
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+import copy
+import shutil
+import tempfile
+import unittest
+
+import torch
+from compressed_tensors import QUANTIZATION_CONFIG_NAME
+from compressed_tensors.compressors import ModelCompressor
+from compressed_tensors.quantization import QuantizationStatus
+from parameterized import parameterized_class
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers.utils.quantization_config import CompressedTensorsConfig
+
+from tests.testing_utils import parse_params, requires_gpu
+
+CONFIG_DIR = "tests/llmcompressor/transformers/compression/decompression_configs"
+
+
+@requires_gpu
+@parameterized_class(parse_params(CONFIG_DIR))
+class TestDecompression(unittest.TestCase):
+    """
+    Check that HFQuantizer decompression is working as expected.
+    Manually decompress a compressed model and compare the generations
+
+    Decompression:
+    Given a skeleton model and path to the optimized model,
+    write the optimized model's safetensors to the skeleton model and decompress
+    Ex. write weight_scale to the skeleton model and then convert from fp4 to fp16
+
+    """
+
+    compressed_model_stub = None
+    skeleton_model_stub = None
+
+    SAMPLE_INPUTS = [
+        "I love 4-bit quantization because",
+        "What is the capital of France?",
+        "def fibonacci(n):",
+    ]
+
+    @classmethod
+    def setUpClass(self):
+        self.test_dir = tempfile.mkdtemp()
+        self.tokenizer = AutoTokenizer.from_pretrained(self.compressed_model_stub)
+
+        # Decompress using HFQuantizer from AutoModelForCausalLM
+        self.decompressed_model_hf_quantizer = AutoModelForCausalLM.from_pretrained(
+            self.compressed_model_stub,
+            torch_dtype="auto",
+            device_map="auto",
+            quantization_config=CompressedTensorsConfig(run_compressed=False),
+        )
+
+        # Manually decompress this model
+        self.dense_model = AutoModelForCausalLM.from_pretrained(
+            self.skeleton_model_stub,
+            torch_dtype=self.decompressed_model_hf_quantizer.dtype,
+            device_map=self.decompressed_model_hf_quantizer.device,
+        )
+
+        assert not hasattr(
+            self.dense_model.model.layers[0].self_attn.q_proj, "weight_scale"
+        )
+
+        config = AutoConfig.from_pretrained(self.compressed_model_stub)
+
+        compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
+        self.compressor = ModelCompressor.from_compression_config(compression_config)
+        self.compressor.quantization_config.quantization_status = (
+            QuantizationStatus.FROZEN
+        )
+
+        # use the model_path to load the decompressed weights into dense_model
+        dense_model = copy.deepcopy(self.dense_model)
+
+        # overwrite the weights of the dense model
+        self.compressor.decompress(
+            model_path=self.compressed_model_stub,
+            model=self.dense_model,
+        )
+
+        # self.dense_model should be decompressed
+        assert dense_model is not self.dense_model
+
+        self.decompressed_model_manual = self.dense_model
+
+        assert hasattr(
+            self.decompressed_model_manual.model.layers[0].self_attn.q_proj,
+            "weight_scale",
+        )
+
+    def test_hf_quantizer_decompress_match_manual_decompress(self):
+        decompressed_model_manual = self.decompressed_model_manual.device
+        decompressed_model_hf_quantizer = self.decompressed_model_hf_quantizer.device
+
+        self.decompressed_model_manual = self.decompressed_model_manual.to(
+            decompressed_model_manual
+        )
+        self.decompressed_model_hf_quantizer = self.decompressed_model_hf_quantizer.to(
+            decompressed_model_hf_quantizer
+        )
+
+        for input in self.SAMPLE_INPUTS:
+            inputs = self.tokenizer(input, return_tensors="pt", padding=True).to(
+                self.decompressed_model_manual.device
+            )
+            inputs = inputs.to(self.decompressed_model_manual.device)
+
+            decompressed_model_manual_output = self.tokenizer.batch_decode(
+                self.decompressed_model_manual.generate(**inputs, max_length=50)
+            )
+
+            decompressed_model_hf_quantizer_out = self.tokenizer.batch_decode(
+                self.decompressed_model_hf_quantizer.generate(**inputs, max_length=50)
+            )
+
+            assert (
+                decompressed_model_hf_quantizer_out == decompressed_model_manual_output
+            )
+
+    @classmethod
+    def tearDownClass(self):
+        shutil.rmtree(self.test_dir)
+        del self.dense_model
+        del self.decompressed_model_hf_quantizer
+        del self.decompressed_model_manual
+        torch.cuda.empty_cache()
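
For readers skimming the diff, the manual decompression path that setUpClass compares against the HFQuantizer path condenses to roughly the following (a sketch mirroring the test above; the stubs stand in for whatever the parameterized config supplies):

from compressed_tensors import QUANTIZATION_CONFIG_NAME
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.quantization import QuantizationStatus
from transformers import AutoConfig, AutoModelForCausalLM

compressed_stub = "nm-testing/tinyllama-fp8-dynamic-compressed"  # any compressed_model_stub from the configs above
skeleton_stub = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"  # matching skeleton_model_stub

# Load the dense skeleton, then overwrite its weights from the compressed checkpoint
dense = AutoModelForCausalLM.from_pretrained(skeleton_stub, torch_dtype="auto", device_map="auto")

config = AutoConfig.from_pretrained(compressed_stub)
compressor = ModelCompressor.from_compression_config(getattr(config, QUANTIZATION_CONFIG_NAME, None))
compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN

# Writes decompressed weights (and quantization params such as weight_scale) into `dense`
compressor.decompress(model_path=compressed_stub, model=dense)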

0 commit comments