add gptqmodel support #2247
Changes from 19 commits
@@ -107,6 +107,32 @@ QLoRA adds trainable weights to all the linear layers in the transformer architecture

```py
config = LoraConfig(target_modules="all-linear", ...)
```

## GPTQ quantization

You can learn more about GPTQ-based `[2, 3, 4, 8]`-bit quantization at [GPTQModel](https://github.com/ModelCloud/GPTQModel) and the [HF GPTQ documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/quantization/gptq.md). PEFT post-quantization training can use either the [GPTQModel](https://github.com/ModelCloud/GPTQModel) or the [AutoGPTQ](https://github.com/autogptq/autogptq) library, but we recommend GPTQModel because AutoGPTQ will be deprecated in a future release.
```bash
# gptqmodel install
pip install gptqmodel --no-build-isolation
```
```py
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)

gptq_config = GPTQConfig(bits=4, group_size=128, dataset="wikitext2", tokenizer=tokenizer)

quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=gptq_config)

# save quantized model
quantized_model.save_pretrained("./opt-125m-gptq")
tokenizer.save_pretrained("./opt-125m-gptq")
```

Member
When trying to run this locally with 2 CUDA devices, I encountered a CUDA error after 50% progress. Is this a known problem? Using 1 CUDA device or setting
I suspect that the error occurs at the "switch" from GPU 0 to GPU 1, since that's exactly after half the layers when using

Contributor
We will double check this to see whether it is a) accelerate specific or b) OPT specific.
For the next GPTQModel CI tests PR, I would recommend we move all model testing from OPT to Llama 1B. I believe OPT was chosen for its tiny size, but in our experience there are some strange issues with the OPT modeling code (that I can't recall) that cause odd failures here and there. We recently dropped all CI OPT tests in favor of Llama for this reason. Again, I can't seem to recall the exact reasons. =( Basically, no one uses OPT anymore and modeling changes heavily favor Llama, so any fringe bugs are much less likely to occur on

Member
Agreed that OPT is very outdated at this point, and we mainly use it since it's small, but at least for PEFT it hasn't caused any problems yet. I ran the code above using
Thus it's unlikely to be related to the model architecture.
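For readers hitting the same multi-GPU calibration error, here is a minimal sketch of one way to keep the whole quantization pass on a single device. The `device_map={"": 0}` pinning and the choice of device index are assumptions for illustration only, not a fix confirmed in this thread.

```py
# Minimal sketch (assumption): pin the entire model to one GPU so the calibration
# pass never crosses a device boundary. Not a fix confirmed in the thread above.
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
gptq_config = GPTQConfig(bits=4, group_size=128, dataset="wikitext2", tokenizer=tokenizer)

# device_map={"": 0} places every module on GPU 0 instead of sharding across GPUs
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map={"": 0},
    quantization_config=gptq_config,
)
```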
Once quantized, you can post-train GPTQ models using the normal PEFT APIs.
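As a minimal sketch of that post-training step, the snippet below reloads the `./opt-125m-gptq` checkpoint saved above and attaches a LoRA adapter with `get_peft_model`. The LoRA rank, alpha, dropout, and target modules are illustrative assumptions, not values prescribed by this PR.

```py
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

# reload the GPTQ-quantized checkpoint saved in the example above
model = AutoModelForCausalLM.from_pretrained("./opt-125m-gptq", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("./opt-125m-gptq")

# illustrative LoRA settings (assumptions): tune rank, alpha and target modules for your task
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# from here, train with the usual Trainer or a custom training loop
```

Only the LoRA parameters are trained; the 4-bit GPTQ weights stay frozen.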
## AQLM quantization

Additive Quantization of Language Models ([AQLM](https://arxiv.org/abs/2401.06118)) is a compression method for large language models. It quantizes multiple weights together and takes advantage of the interdependencies between them, representing groups of 8-16 weights as a sum of multiple vector codes. This allows it to compress models down to as low as 2 bits with considerably low accuracy loss.
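A hedged sketch of how an AQLM checkpoint plugs into PEFT is shown below. The checkpoint ID is a placeholder, and the `aqlm[gpu]` install extra and LoRA settings are assumptions, not part of this diff.

```py
# requires the AQLM runtime, e.g. `pip install aqlm[gpu]` (assumption)
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# placeholder ID: substitute any AQLM-prequantized checkpoint from the Hub
quantized_model = AutoModelForCausalLM.from_pretrained(
    "<org>/<model>-AQLM-2Bit-1x16-hf",
    device_map="auto",
)

# illustrative LoRA settings (assumptions)
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
peft_model = get_peft_model(quantized_model, lora_config)
peft_model.print_trainable_parameters()
```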
This won't work, as this make argument is never called anywhere. What I meant is to just add the line to `tests_common_gpu` above, which is already called in the appropriate setting.