
Add HQQ quant loader for ooba #4888

Merged
oobabooga merged 9 commits into oobabooga:dev from waters222:main on Dec 19, 2023

Conversation

@waters222 (Contributor) commented Dec 11, 2023

Checklist:

@oobabooga (Owner)

This looks promising. Have you made any tests on perplexity, speed, and maximum context size?

model_dir = f'{shared.args.model_dir}/{model_name}'
logger.warning(f"loading HQQ model from {model_dir}")
model = HQQModelForCausalLM.from_quantized(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
Contributor:

You could probably just return model and let ooba handle the tokenizer options like AWQ does.

Contributor Author:

Good call, changed it.

Contributor:

Thanks, looks good.

@waters222 (Contributor Author) commented Dec 12, 2023

This looks promising. Have you made any tests on perplexity, speed, and maximum context size?

The performance is not great at the moment:

Output generated in 38.99 seconds (5.67 tokens/s, 221 tokens, context 221, seed 581501775)

For PPL results, refer to @mobicham's numbers. VRAM with just the model loaded is 19656MiB / 24564MiB, and the longest context I got without running into OOM was:

Output generated in 166.80 seconds (5.60 tokens/s, 934 tokens, context 3257, seed 1203111843)

so it's about 4096 for 24GB VRAM with the Mixtral model.

@mobicham commented Dec 12, 2023

You should not use the Instruct model for wikitext ppl, you should use the base model. Here are the numbers with a comparison vs. bitsandbytes:

Wikitext2 PPL/Memory: HQQ vs bitsandbytes (BNB)

#8-bit (group_size=128)
Mixtral-8x7B-v0.1 / BNB : 3.64 | (54.5 GB)
Mixtral-8x7B-v0.1 / HQQ : 3.63 | (47 GB)

#4-bit (group_size=64)
Mixtral-8x7B-v0.1 / BNB : 3.97 | (27 GB)
Mixtral-8x7B-v0.1 / HQQ : 3.79 | (26 GB)

#3-bit (group_size=128)
Mixtral-8x7B-v0.1 / HQQ : 4.76 | (21.8 GB)

#2-bit (group_size=16 | scale_g128/zero=8-bit):
Mixtral-8x7B-v0.1 / HQQ : 5.90 | (18 GB)

More numbers for Llama2 and OpenCLIP: https://mobiusml.github.io/hqq_blog/
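The memory figures above can be roughly sanity-checked from the bit budget. This is my own back-of-envelope sketch, not code from the hqq repo: it assumes each group of `group_size` weights stores one scale and one zero-point (widths assumed 16-bit here; the 2-bit config above explicitly uses an 8-bit zero), and ignores layers kept in fp16.

```python
def effective_bits(weight_bits, group_size, scale_bits=16, zero_bits=16):
    """Effective bits per weight when each group of `group_size` weights
    carries one scale and one zero-point (bit widths are assumptions)."""
    return weight_bits + (scale_bits + zero_bits) / group_size

def est_size_gb(n_params, weight_bits, group_size):
    """Rough quantized model size in GB (overhead widths assumed above)."""
    return n_params * effective_bits(weight_bits, group_size) / 8 / 1e9

# Mixtral-8x7B has roughly 46.7B parameters; 4-bit with group_size=64
# gives 4 + 32/64 = 4.5 effective bits per weight:
print(round(est_size_gb(46.7e9, 4, 64), 1))  # ~26.3, close to the 26 GB above
```

Under these assumptions the estimate lands close to the reported 4-bit Mixtral number; the smaller group sizes at low bit widths cost proportionally more overhead, which is why the 2-bit config compresses the scale/zero.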

@waters222 waters222 requested a review from cal066 December 12, 2023 19:26
@waters222 (Contributor Author) commented Dec 12, 2023

I just did some PPL tests; here are the results. The first one is BnB 4-bit:

(screenshot of PPL results)
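For reference, wikitext perplexity is just the exponential of the mean per-token negative log-likelihood over the evaluation set. A minimal sketch of that final step (model inference and tokenization omitted; the token log-probs here are placeholders):

```python
import math

def perplexity(token_logprobs):
    """exp(mean negative log-likelihood) over the evaluated tokens."""
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)

# Sanity check: a model assigning uniform probability over a
# 32000-token vocab has perplexity exactly 32000.
print(perplexity([math.log(1 / 32000)] * 100))  # ≈ 32000
```

Differences in tokenizer, context window, and sliding-window stride are why numbers from different eval scripts can disagree even for the same weights.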

@mobicham

Try with https://github.com/mobiusml/hqq/blob/master/examples/llama2_benchmark/eval_model.py , you should get the same numbers as the ones I posted.

@waters222 waters222 requested a review from cal066 December 14, 2023 03:19
@cal066 (Contributor) left a comment

Thanks, loader looks good.

@waters222 (Contributor Author)

@oobabooga any more comments before merge?

@oobabooga oobabooga changed the base branch from main to dev December 15, 2023 05:58
@oobabooga (Owner) commented Dec 15, 2023

Here is a test:

Model                                        Perplexity (wikitext)   VRAM (MB)   Model size (MB)
mobiuslabsgmbh_Llama-2-13b-hf-4bit_g64-HQQ   4.33972                 8463        7246

The methodology is the same as in this blog post of mine: https://oobabooga.github.io/blog/posts/gptq-awq-exl2-llamacpp/

The key takeaway is that no other quantization has both lower VRAM and lower perplexity than this one, so it is on the Pareto frontier for VRAM vs. perplexity. In terms of model size, though, llama-2-13b-EXL2-4.400b is 142MB smaller and has 0.001 lower perplexity.
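"Pareto frontier" here means no other quantization beats it on both axes at once. A quick sketch of that check, using hypothetical (VRAM, perplexity) entries just to illustrate the criterion:

```python
def pareto_frontier(points):
    """Keep (vram, ppl) points not strictly dominated by any other point,
    i.e. no other point has both lower VRAM and lower perplexity."""
    return [p for p in points
            if not any(q[0] < p[0] and q[1] < p[1] for q in points)]

# Hypothetical entries: the third is dominated by the first (more VRAM
# AND worse perplexity), so it is dropped.
quants = [(8463, 4.3397), (8600, 4.3387), (9000, 4.50)]
print(pareto_frontier(quants))  # [(8463, 4.3397), (8600, 4.3387)]
```

A dominated point is never the right pick regardless of how a user weighs memory against quality; points on the frontier each win for some trade-off preference.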

These are my thoughts:

  • The quantization accuracy at 4-bit precision is indeed very competitive, and it has the advantage of being fast and computationally cheap to quantize.
  • Since there is an option to use pure PyTorch for inference, without the need for compiled extensions, it should work on anything, including Intel Arc, AMD, and Metal GPUs, which is a huge plus, since Intel Arc in particular is severely limited at the moment.
  • The speed is very slow with the PYTORCH_COMPILED backend. I haven't tried the ATEN backend.
  • It runs Mixtral in 2-bit precision with 24GB VRAM, but in my subjective testing the degradation relative to the llama.cpp Q4_K_M is too big, and the speed is also slower (even though llama.cpp is computing part of the layers using the CPU).

@mobicham two questions:

  1. Any chance you could turn your repository into a pip-installable package? Requirements in git+ format are kind of cumbersome to include in a requirements.txt and will make the installation time for this project longer.
  2. Are there planned optimizations to increase the tokens/second somehow? Is ATEN a lot faster?

@oobabooga oobabooga mentioned this pull request Dec 15, 2023
@mobicham commented Dec 15, 2023

@oobabooga thank you very much for your comments!

  1. Pip-installable package: sure, will do!
  2. How is the speed measured, comparatively? The reference should be the exact same model in fp16. Also, there are many ways to implement Llama, for example. In VLLM we see a 10x speed-up compared to Hugging Face simply because the Llama architecture is well optimized in VLLM.
    For the same Hugging Face Llama architecture, our tests show that HQQ is actually a bit faster than GPTQ/AWQ. But the thing to understand is that in GPTQ/AWQ, when you load the model via from_quantized() they do other things to the model to make it faster (replace the PyTorch version of layernorm, merge some layers, etc.), so technically it's no longer the "same model".

Here are some solutions:

  • Use VLLM architecture instead of Hugging Face. We already have that for Llama and it's working fine with HQQ. The weights are not compatible with HF weights, but we support saving/loading for this so we can have separate weights for it: https://github.com/mobiusml/hqq/blob/master/examples/vllm/llama2_example.py
  • For the specific case of 4-bit with some constraints, there's a new PyTorch function that seems to do dequantize() + matmul, but there's no documentation so far. I will ask on their GitHub to see how we can use it:
    https://pytorch.org/cppdocs/api/function_namespaceat_1adeda9630914278ac02d7fd758da19e3d.html
    I think the best case in terms of speed with the pure PyTorch backend would be matching fp16 performance.
  • We are actively working on improving things and making them faster. We already have a version that is 30-40% faster for 2-bit; it's not pushed yet because I am still testing it and it would require some refactoring, but we are actively working on it.

@mobicham

@oobabooga I guess you tested the llama.cpp on Mac with unified memory? That might explain the fast cpu > gpu data transfer. On a Linux machine, that is not really an option, at least with Pytorch, the .cuda() is too slow.

@oobabooga (Owner)

@oobabooga I guess you tested the llama.cpp on Mac with unified memory? That might explain the fast cpu > gpu data transfer. On a Linux machine, that is not really an option, at least with Pytorch, the .cuda() is too slow.

No, I use Linux with a 3090. llama.cpp doesn't do CPU offloading. When you don't export all layers to the GPU, the remaining layers are computed in the CPU itself.

For Mixtral, since only 2 experts with 7b parameters are used at a time, the speed ends up decent even though the CPU layers are bottlenecked by the ~20 GB/s RAM bandwidth.
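The ~20 GB/s figure allows an easy back-of-envelope bound (my own estimate, with assumed numbers): token-by-token decoding has to stream every active weight through the bottleneck once per token, so tokens/s is at most bandwidth divided by active bytes. For Mixtral, with 2 of 8 experts active (~13B active parameters at roughly 4.5 bits/weight for a Q4-class quant), fully CPU-bound decode would cap out around 3 tokens/s:

```python
def decode_tok_s(bandwidth_gb_s, active_params_billions, bits_per_weight):
    """Upper bound on decode tokens/s, assuming every active weight
    is read from memory exactly once per generated token."""
    bytes_per_token = active_params_billions * 1e9 * bits_per_weight / 8
    return bandwidth_gb_s * 1e9 / bytes_per_token

# ~13B active params (2 of 8 experts), ~4.5 bits/weight, 20 GB/s RAM:
print(round(decode_tok_s(20, 13, 4.5), 1))  # roughly 2.7
```

This is only a ceiling (it ignores compute, caching, and activations), but it shows why offloading even part of the layers to the GPU keeps the overall speed decent for a sparse MoE model.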

@oobabooga (Owner)

About VLLM, another PR is open about integrating it (#4860); I don't know how practical that would be considering that the inference code in this project relies a lot on the transformers library. It's something that I have to investigate.

@mobicham

@oobabooga

About VLLM, another PR is open about integrating it (#4860); I don't know how practical that would be considering that the inference code in this project relies a lot on the transformers library. It's something that I have to investigate.

Cool! VLLM is indeed quite different. I have some ideas to make the HF models faster; I will give it a try in the upcoming days.

It runs Mixtral in 2-bit precision with 24GB VRAM, but in my subjective testing the degradation relative to the llama.cpp Q4_K_M is too big, and the speed is also slower (even though llama.cpp is computing part of the layers using the CPU).

Q4 is 4-bit; going from 4-bit to 2-bit indeed leads to a big drop in quality. How does it compare to Q2_K?
By the way, we will shortly publish a new version of the Mixtral 2-bit model that is much better in quality (ppl: 4.69 vs 5.90 @ ~18GB).

@oobabooga (Owner)

How does it compare to Q2_K ?

I haven't tested. In the 2-bit domain, I got impressive results with QuIP# for llama-2-70b-chat. Q2_K is huge in comparison (like 50% larger), as it's closer to 3-bit than 2-bit, but the author of the k-quants method used in llama.cpp claims to have a new version of the method that is closer to QuIP#. He pushed some examples to https://huggingface.co/ikawrakow/llama-v2-2bit-gguf but didn't release the code yet. See the discussion here.

I have some ideas to make the HF models faster, will give it a try in the upcoming days.

By the way, we will publish shortly a new version of the Mixtral 2-bit model that is much better in quality (ppl: 4.69 vs 5.90 @~18GB)

Exciting news -- I look forward to trying the updated version. Mixtral 3bit would be interesting as well, as it should fit in 24GB VRAM.

@mobicham

I did a quick comparison with QuIP# 4-bit. I forgot to compress the scaling, which should reduce the memory further, but here's a rough comparison (PPL / memory). Will play with QuIP# 2-bit later:

Llama-2-7B: 
HQQ - 4-bit (g_64)   : 5.30 | 4.6  GB
QUIP-sharp 4-bit     : 5.37 | 4.3  GB

Llama-2-13B: 
HQQ - 4-bit (g_128)  : 4.74 | 7.9  GB 
QUIP-sharp 4-bit     : 4.74 | 7.42 GB

Llama-2-70B: 
HQQ - 4-bit (g_128)  : 3.21 | 35.97 GB
QUIP-sharp 4-bit     : 3.22 | 34.4  GB

Regarding the new Mixtral models, here are the links:
Base: https://huggingface.co/mobiuslabsgmbh/Mixtral-8x7B-v0.1-hf-attn-4bit-moe-2bit-HQQ
Instruct: https://huggingface.co/mobiuslabsgmbh/Mixtral-8x7B-Instruct-v0.1-hf-attn-4bit-moe-2bit-HQQ

Numbers:

HQQ 3-bit (group_size=128): 4.76 | (21.8 GB)
HQQ Old 2-bit model:        5.90 | (18 GB)  
HQQ New 2-bit model:        4.69 | (19.2 GB)

It's true that the 3-bit model should work on 24 GB, but with a smaller context window. On disk, the new model is only 0.2 GB larger (18.2GB vs. 18GB), but for some reason it takes an extra 1 GB of VRAM.

By the way, now you can install hqq via pip: pip install hqq

@oobabooga (Owner)

0.2GB on disk for a ~1.2 drop in ppl is a massive improvement. Very impressive.

With the pypi package, the PR looks good to merge now.

@mobicham If I could have two additional suggestions for future HQQ versions:

  1. Exporting the model weights as .safetensors instead of .pt would make them faster to load.
  2. According to Update IPEX to 2.1.10+xpu #4931 (comment), HQQ doesn't work on an Intel Arc GPU, so there may be some CUDA-specific code somewhere in the repository.

@oobabooga oobabooga merged commit 674be9a into oobabooga:dev Dec 19, 2023
@mobicham

Thanks @oobabooga !

  1. So the reason we ended up using torch.save is that safetensors only supports a dictionary in the form string:torch.tensor, but we need to store some string metadata. I believe we can switch to safetensors once the bitpacking/unpacking logic is consolidated.
  2. Oh I see, I think the problem is that the code uses .cuda() to transfer data to the GPU; we should use .to(device) instead and specify the device somewhere (like 'xpu' for Arc). I think we should also specify the dtype, because apparently some AMD GPUs don't support half-precision.
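The `.cuda()` → `.to(device)` fix boils down to picking a device string once and threading it through every transfer. A dependency-free sketch of the selection logic (in real code the availability flags would come from `torch.cuda.is_available()` and the backend-specific equivalents; the helper name is illustrative):

```python
def pick_device(available):
    """Return the first available backend from a preference order;
    'cpu' is always a valid fallback."""
    for dev in ("cuda", "xpu", "mps", "cpu"):
        if available.get(dev, dev == "cpu"):
            return dev

print(pick_device({"cuda": False, "xpu": True}))  # xpu
```

Model and tensor moves then become `model.to(pick_device(flags))` instead of hard-coded `.cuda()`, which is what lets the same loader run on CUDA, Intel Arc (xpu), and Apple (mps) without code changes.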
