Merged

37 commits
03f7476
add meta info
ZX-ModelCloud Dec 4, 2024
2fbd311
cleanup
ZX-ModelCloud Dec 4, 2024
1087c77
cleanup
ZX-ModelCloud Dec 4, 2024
0b2b3e1
The value of quantizer should be an array
ZX-ModelCloud Dec 4, 2024
10bd2e4
Update quantizer.py
Qubitium Dec 4, 2024
86c9363
If is_auto_gptq_available() also writes "auto_gptq:version" to "quant…
ZX-ModelCloud Dec 4, 2024
9cce331
Merge remote-tracking branch 'origin/add_meta_info' into add_meta_info
ZX-ModelCloud Dec 4, 2024
007676d
If is_auto_gptq_available() also writes "auto_gptq:version" to "quant…
ZX-ModelCloud Dec 4, 2024
48e947a
Update quantizer.py
Qubitium Dec 4, 2024
180987d
cleanup
ZX-ModelCloud Dec 4, 2024
cb7e522
Merge remote-tracking branch 'origin/add_meta_info' into add_meta_info
ZX-ModelCloud Dec 4, 2024
c4a48c6
comment on meta
Qubitium Dec 4, 2024
aa92e66
hf_select_quant_linear pass checkpoint_format
ZX-ModelCloud Dec 4, 2024
ca354b3
add todo fix
Qubitium Dec 4, 2024
4d28581
move convert code to quantizer.save()
LRL-ModelCloud Dec 4, 2024
bdfc2b3
Update quantizer.py
Qubitium Dec 4, 2024
fb28f74
Optimize hf_convert_gptq_v2_to_v1_format()
ZX-ModelCloud Dec 4, 2024
216c1a6
Optimize hf_convert_gptq_v1_to_v2_format()
ZX-ModelCloud Dec 4, 2024
71e08f6
fix GPTQTestCUDA
ZX-ModelCloud Dec 4, 2024
20d5e8b
hf_select_quant_linear() always set pack=True
ZX-ModelCloud Dec 4, 2024
80bc085
gptqmodel.hf_select_quant_linear() now does not select ExllamaV2
ZX-ModelCloud Dec 4, 2024
caa499b
gptqmodel.hf_select_quant_linear() now does not select ExllamaV2
ZX-ModelCloud Dec 4, 2024
6acf8b9
GPTQQuantizer add backend
LRL-ModelCloud Dec 5, 2024
6a7a266
lower checkpoint_format and backend
LRL-ModelCloud Dec 5, 2024
88b2f99
cleanup
LRL-ModelCloud Dec 5, 2024
3679a42
move backend to bottom
LRL-ModelCloud Dec 5, 2024
1484ad4
no need to check gptqmodel version for ipex support
Qubitium Dec 5, 2024
6140129
Update import_utils.py
Qubitium Dec 5, 2024
71faf1a
Update quantizer.py
Qubitium Dec 5, 2024
bb754bc
fix UnboundLocalError: cannot access local variable 'version' where i…
ZX-ModelCloud Dec 5, 2024
4d32b48
Merge remote-tracking branch 'origin/add_meta_info' into add_meta_info
ZX-ModelCloud Dec 5, 2024
c09da17
make version var short
Qubitium Dec 5, 2024
e9c5358
Update import_utils.py
Qubitium Dec 5, 2024
77dec80
fix unittest
ZX-ModelCloud Dec 5, 2024
d5857f8
Merge remote-tracking branch 'origin/add_meta_info' into add_meta_info
ZX-ModelCloud Dec 5, 2024
556002c
use assertLessEqual
LRL-ModelCloud Dec 5, 2024
d90dad0
Merge remote-tracking branch 'origin/add_meta_info' into add_meta_info
LRL-ModelCloud Dec 5, 2024
optimum/gptq/quantizer.py (71 changes: 50 additions & 21 deletions)
@@ -32,6 +32,7 @@
from .constants import GPTQ_CONFIG
from .data import get_dataset, prepare_dataset
from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen
from ..version import __version__ as optimum_version


if is_accelerate_available():
@@ -46,13 +47,15 @@
from auto_gptq.modeling._utils import autogptq_post_init as gptq_post_init
from auto_gptq.quantization import GPTQ
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear
from auto_gptq import __version__ as autogptq_version

if is_gptqmodel_available():
from gptqmodel import exllama_set_max_input_length
from gptqmodel.quantization import GPTQ
from gptqmodel.utils.importer import hf_select_quant_linear
from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format
from gptqmodel.utils.model import hf_gptqmodel_post_init as gptq_post_init
from gptqmodel.version import __version__ as gptqmodel_version

logger = getLogger(__name__)

@@ -80,15 +83,17 @@ def __init__(
desc_act: bool = False,
sym: bool = True,
true_sequential: bool = True,
use_cuda_fp16: bool = False,
checkpoint_format: str = "gptq",
meta: Optional[Dict[str, any]] = None,
backend: Optional[str] = None,
use_cuda_fp16: bool = False,
model_seqlen: Optional[int] = None,
block_name_to_quantize: Optional[str] = None,
module_name_preceding_first_block: Optional[List[str]] = None,
batch_size: int = 1,
pad_token_id: Optional[int] = None,
disable_exllama: bool = False,
exllama_config: Dict[str, Any] = None,
exllama_config: Optional[Dict[str, Any]] = None,
max_input_length: Optional[int] = None,
cache_block_outputs: Optional[bool] = True,
modules_in_block_to_quantize: Optional[List[List[str]]] = None,
@@ -117,6 +122,14 @@ def __init__(
Whether to perform sequential quantization even within a single Transformer block.
Instead of quantizing the entire block at once, we perform layer-wise quantization.
As a result, each layer undergoes quantization using inputs that have passed through the previously quantized layers.
checkpoint_format (`str`, *optional*, defaults to `gptq`):
GPTQ weight format. `gptq` (v1) is supported by both gptqmodel and auto-gptq. `gptq_v2` is gptqmodel only.
meta (`Dict[str, any]`, *optional*):
Properties, such as tooling:version, that do not directly contribute to quantization or quantized inference are stored in meta.
e.g. `meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"]
backend (`str`, *optional*):
Controls which GPTQ kernel is used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For auto-gptq, the only
valid values are None and `auto_trainable`. See the gptqmodel backends for reference: https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py
use_cuda_fp16 (`bool`, defaults to `False`):
Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16.
model_seqlen (`Optional[int]`, defaults to `None`):
@@ -152,6 +165,9 @@ def __init__(
self.desc_act = desc_act
self.sym = sym
self.true_sequential = true_sequential
self.checkpoint_format = checkpoint_format.lower()
self.meta = meta
self.backend = backend.lower() if backend is not None else None
self.use_cuda_fp16 = use_cuda_fp16
self.model_seqlen = model_seqlen
self.block_name_to_quantize = block_name_to_quantize
@@ -164,7 +180,6 @@ def __init__(
self.quant_method = QuantizationMethod.GPTQ
self.cache_block_outputs = cache_block_outputs
self.modules_in_block_to_quantize = modules_in_block_to_quantize
self.checkpoint_format = checkpoint_format

self.serialization_keys = [
"bits",
@@ -177,6 +192,7 @@ def __init__(
"quant_method",
"modules_in_block_to_quantize",
"checkpoint_format",
"meta",
]

if self.bits not in [2, 3, 4, 8]:
@@ -198,15 +214,17 @@ def __init__(
)
self.exllama_version = self.exllama_config["version"]

def select_quant_linear(self, pack: bool, device_map: Union[str, dict]):
def select_quant_linear(self, device_map: Union[str, dict]):
if is_gptqmodel_available():
self.quant_linear = hf_select_quant_linear(
bits=self.bits,
group_size=self.group_size,
desc_act=self.desc_act,
sym=self.sym,
checkpoint_format=self.checkpoint_format,
meta=self.meta,
device_map=device_map,
pack=pack,
backend=self.backend,
)
else:
self.quant_linear = hf_select_quant_linear(
@@ -225,6 +243,20 @@ def to_dict(self):
gptq_dict = {}
for key in self.serialization_keys:
gptq_dict[key] = getattr(self, key)

if gptq_dict.get("meta") is None:
gptq_dict["meta"] = {}

meta = gptq_dict["meta"]
# store both optimum:version and gptq_lib:version into quantize_config.meta.quantizer
if meta.get("quantizer") is None:
meta["quantizer"] = [f"optimum:{optimum_version}"]

if is_gptqmodel_available():
meta["quantizer"].append(f"gptqmodel:{gptqmodel_version}")
elif is_auto_gptq_available():
meta["quantizer"].append(f"auto_gptq:{autogptq_version}")

return gptq_dict

@classmethod
@@ -263,7 +295,7 @@ def convert_model(self, model: nn.Module, **kwargs):
)
del layers_to_be_replaced[name]

self.select_quant_linear(pack=False, device_map=kwargs.get("device_map", None))
self.select_quant_linear(device_map=kwargs.get("device_map", None))

self._replace_by_quant_layers(model, layers_to_be_replaced)

@@ -379,10 +411,7 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
gptq_supports_cpu = (
is_auto_gptq_available()
and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
) or (
is_gptqmodel_available()
and version.parse(importlib.metadata.version("gptqmodel")) > version.parse("1.3.1")
)
) or is_gptqmodel_available()

if not gptq_supports_cpu and not torch.cuda.is_available():
raise RuntimeError(
@@ -663,18 +692,12 @@ def tmp(_, input, output):
# Step 5: Any post-initialization that require device information, for example buffers initialization on device.
model = self.post_init_model(model)

# convert gptqmodel internal gptq_v2 format to v1 for saving/compat
# sym=False is valid for gptq_v2 format only. for sym=True, need to convert to v1 before save.
if self.sym and self.checkpoint_format == "gptq_v2":
model = hf_convert_gptq_v2_to_v1_format(model, self.bits, self.quant_linear)
self.checkpoint_format = "gptq"

torch.cuda.empty_cache()
if hasattr(torch, "xpu"):
torch.xpu.empty_cache()
return model

def post_init_model(self, model, **kwargs):
def post_init_model(self, model):
"""
Post-initialization that require device information, for example buffers initialization on device.

@@ -695,8 +718,8 @@ def post_init_model(self, model, **kwargs):
class StoreAttr(object):
pass

if is_gptqmodel_available() and self.checkpoint_format == "gptq":
model = hf_convert_gptq_v1_to_v2_format(model, self.bits, self.quant_linear)
if is_gptqmodel_available():
model, _ = hf_convert_gptq_v1_to_v2_format(model, self.bits, self.quant_linear, self.checkpoint_format, self.meta)

model.quantize_config = StoreAttr()
model.quantize_config.desc_act = self.desc_act
@@ -727,7 +750,7 @@ def pack_model(
layers = get_layers(model)
layers = {n: layers[n] for n in quantizers}

self.select_quant_linear(pack=True, device_map=model.hf_device_map)
self.select_quant_linear(device_map=model.hf_device_map)

self._replace_by_quant_layers(model, quantizers)
qlayers = get_layers(model, [self.quant_linear])
@@ -765,6 +788,12 @@ def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", sa
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).

"""

# convert gptqmodel internal gptq_v2 format to v1 for max compatibility
model, converted = hf_convert_gptq_v2_to_v1_format(model, self.sym, self.bits, self.quant_linear, self.checkpoint_format, self.meta)
if converted:
self.checkpoint_format = "gptq"

os.makedirs(save_dir, exist_ok=True)
model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f:
@@ -871,7 +900,7 @@ def load_quantized_model(
quantizer.exllama_version = quantizer.exllama_config["version"]
quantizer.max_input_length = max_input_length

model = quantizer.convert_model(model)
model = quantizer.convert_model(model, device_map=device_map)

if no_split_module_classes is None:
no_split_module_classes = quantizer.get_no_split_module_classes(model)
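To make the new serialization concrete, here is a minimal sketch (not part of the diff) of how the added `checkpoint_format`, `backend`, and `meta` options surface through `GPTQQuantizer.to_dict()`. The constructor arguments and the version strings in the comments are illustrative assumptions; actual values depend on the installed optimum and gptqmodel/auto-gptq packages.

from optimum.gptq import GPTQQuantizer

# Illustrative configuration only; backend="auto" assumes gptqmodel is installed.
quantizer = GPTQQuantizer(
    bits=4,
    group_size=128,
    checkpoint_format="gptq",  # "gptq" (v1) or, with gptqmodel, "gptq_v2"
    backend="auto",            # kernel selection hint for gptqmodel; leave as None with auto-gptq
)

config = quantizer.to_dict()
# to_dict() now injects the tooling versions into meta.quantizer, roughly:
#   {"quantizer": ["optimum:<version>", "gptqmodel:<version>"]}
# ("auto_gptq:<version>" is appended instead when only auto-gptq is available)
print(config["checkpoint_format"], config["meta"])

Note that, per the save() change above, a model quantized internally in gptqmodel's gptq_v2 format is converted back to the v1 gptq format before the weights and this config are written out, so the saved quantize_config reports checkpoint_format "gptq".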
optimum/utils/import_utils.py (16 changes: 12 additions & 4 deletions)
@@ -52,6 +52,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0")
DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0")
AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0
GPTQMODEL_MINIMUM_VERSION = version.parse("1.3.99") # Allows 1.4.0.dev0


# This is the minimal required version to support some ONNX Runtime features
@@ -139,17 +140,24 @@ def is_datasets_available():

def is_auto_gptq_available():
if _auto_gptq_available:
version_autogptq = version.parse(importlib_metadata.version("auto_gptq"))
if AUTOGPTQ_MINIMUM_VERSION < version_autogptq:
v = version.parse(importlib_metadata.version("auto_gptq"))
if v >= AUTOGPTQ_MINIMUM_VERSION:
return True
else:
raise ImportError(
f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, but only version above {AUTOGPTQ_MINIMUM_VERSION} are supported"
f"Found an incompatible version of auto-gptq. Found version {v}, but only version >= {AUTOGPTQ_MINIMUM_VERSION} are supported"
)


def is_gptqmodel_available():
return _gptqmodel_available
if _gptqmodel_available:
v = version.parse(importlib_metadata.version("gptqmodel"))
if v >= GPTQMODEL_MINIMUM_VERSION:
return True
else:
raise ImportError(
f"Found an incompatible version of gptqmodel. Found version {v}, but only version >= {GPTQMODEL_MINIMUM_VERSION} are supported"
)


@contextmanager
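As a rough illustration (not part of the diff) of how these availability helpers are meant to be consumed, the selection pattern used in optimum/gptq/quantizer.py above boils down to the sketch below. Both helpers now raise ImportError when an incompatible version of the corresponding package is installed, rather than silently returning False.

from optimum.utils.import_utils import is_auto_gptq_available, is_gptqmodel_available

# gptqmodel takes precedence over auto-gptq when both are installed,
# mirroring the order of the checks in optimum/gptq/quantizer.py.
if is_gptqmodel_available():
    gptq_library = "gptqmodel"
elif is_auto_gptq_available():
    gptq_library = "auto-gptq"
else:
    raise RuntimeError("Neither gptqmodel nor auto-gptq is available in a compatible version.")

print(f"Using {gptq_library} for GPTQ quantization.")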
tests/gptq/test_quantization.py (43 changes: 27 additions & 16 deletions)
@@ -55,6 +55,7 @@ class GPTQTest(unittest.TestCase):
bits = 4
group_size = 128
desc_act = False
sym = True
disable_exllama = True
exllama_config = None
cache_block_outputs = True
@@ -73,6 +74,7 @@ def setUpClass(cls):
"""

cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.config = AutoConfig.from_pretrained(cls.model_name)

cls.model_fp16 = AutoModelForCausalLM.from_pretrained(
cls.model_name, torch_dtype=torch.float16, device_map=cls.device_map_for_quantization
@@ -87,6 +89,7 @@ def setUpClass(cls):
dataset=cls.dataset,
group_size=cls.group_size,
desc_act=cls.desc_act,
sym=cls.sym,
disable_exllama=cls.disable_exllama,
exllama_config=cls.exllama_config,
cache_block_outputs=cls.cache_block_outputs,
@@ -116,13 +119,20 @@ def test_quantized_layers_class(self):
"""

if is_gptqmodel_available():
if hasattr(self.config, "quantization_config"):
checkpoint_format = self.config.quantization_config.get("checkpoint_format")
meta = self.config.quantization_config.get("meta")
else:
checkpoint_format = "gptq"
meta = None
QuantLinear = hf_select_quant_linear(
bits=self.bits,
group_size=self.group_size,
desc_act=self.desc_act,
sym=True,
sym=self.sym,
device_map=self.device_map_for_quantization,
pack=False,
checkpoint_format=checkpoint_format,
meta=meta,
)
else:
QuantLinear = hf_select_quant_linear(
@@ -133,10 +143,10 @@ def test_serialization(self):
disable_exllama=self.disable_exllama or self.exllama_config["version"] != 1,
disable_exllamav2=self.disable_exllama or self.exllama_config["version"] != 2,
)
self.assertTrue(self.quantized_model.model.layers[0].mlp.gate_proj.__class__ == QuantLinear)
self.assertEqual(self.quantized_model.model.layers[0].mlp.gate_proj.__class__, QuantLinear)

def check_quantized_layers_type(self, model, value):
self.assertTrue(model.model.layers[0].mlp.gate_proj.QUANT_TYPE == value)
self.assertEqual(model.model.layers[0].mlp.gate_proj.QUANT_TYPE, value)

def test_serialization(self):
"""
@@ -161,7 +171,7 @@ def test_serialization(self):
if is_auto_gptq_available() and not is_gptqmodel_available():
quant_type = "cuda-old" if self.disable_exllama else "exllama"
else:
quant_type = "ipex" if self.device_map_for_quantization == "cpu" else "cuda"
quant_type = "ipex" if self.device_map_for_quantization == "cpu" else "exllama"

self.check_quantized_layers_type(quantized_model_from_saved, quant_type)

@@ -179,16 +189,19 @@ class GPTQTestCUDA(GPTQTest):
class GPTQTestCUDA(GPTQTest):
device_map_for_quantization = "cuda"
device_for_inference = 0
expected_compression_ratio = 1.66
expected_compression_ratio = 1.2577
expected_fp16_perplexity = 38
expected_quantized_perplexity = 45


def test_perplexity(self):
"""
A simple test to check if the model conversion has been done correctly by checking
the perplexity of the converted models
"""

self.assertEqual(int(self.fp16_ppl), self.expected_fp16_perplexity)
self.assertEqual(int(self.quantized_ppl), self.expected_quantized_perplexity)
self.assertLessEqual(int(self.fp16_ppl), self.expected_fp16_perplexity)
self.assertLessEqual(int(self.quantized_ppl), self.expected_quantized_perplexity)


class GPTQTestExllama(GPTQTestCUDA):
@@ -199,6 +212,7 @@ class GPTQTestExllama(GPTQTestCUDA):
class GPTQTestActOrder(GPTQTestCUDA):
disable_exllama = True
desc_act = True
expected_quantized_perplexity = 46

def test_serialization(self):
# act_order doesn't work with the qlinear_cuda kernel
@@ -282,7 +296,6 @@ def test_exllama_serialization(self):
"""
Test the serialization of the model and the loading of the quantized weights with exllamav2 kernel
"""

with tempfile.TemporaryDirectory() as tmpdirname:
self.quantizer.save(self.quantized_model, tmpdirname)
self.quantized_model.config.save_pretrained(tmpdirname)
@@ -296,16 +309,13 @@ def test_exllama_serialization(self):
save_folder=tmpdirname,
device_map={"": self.device_for_inference},
)
self.check_quantized_layers_type(quantized_model_from_saved, "exllamav2")
self.check_quantized_layers_type(quantized_model_from_saved, "exllama" if is_gptqmodel_available else "exllamav2")

# transformers and auto-gptq compatibility
# quantized models are more compatible with device map than
# device context managers (they're never used in transformers testing suite)
_ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference})
if is_gptqmodel_available():
_ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference})
else:
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})


class GPTQTestNoBlockCaching(GPTQTestCUDA):
@@ -318,11 +328,12 @@ class GPTQTestModuleQuant(GPTQTestCUDA):
["self_attn.q_proj"],
["mlp.gate_proj"],
]
expected_compression_ratio = 1.577
expected_compression_ratio = 1.068
expected_quantized_perplexity = 39

def test_not_converted_layers(self):
# self_attention.dense should not be converted
self.assertTrue(self.quantized_model.model.layers[0].self_attn.k_proj.__class__.__name__ == "Linear")
self.assertEqual(self.quantized_model.model.layers[0].self_attn.k_proj.__class__.__name__, "Linear")


class GPTQUtilsTest(unittest.TestCase):