-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
[Feature][Kernel] Support bitsandbytes quantization and QLoRA #4776
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
126e816
[model] support bitsandbytes/QLoRA (#4033)
chenqianfzh f4743d0
Merge branch 'main' into qian/qlora
chenqianfzh eba8541
Merge branch 'main' into qian/qlora
chenqianfzh 64ad1a3
update error messages
chenqianfzh aacab4a
Merge branch 'main' into qian/qlora
chenqianfzh 973fd63
Merge branch 'main' into qian/qlora
chenqianfzh 1f8aea9
Merge branch 'main' into qian/qlora
chenqianfzh 161c792
add comments about bitandbytes bug
chenqianfzh 5264d57
Merge branch 'main' into qian/qlora
chenqianfzh fbdff73
Revert "add comments about bitandbytes bug"
chenqianfzh 5c25ae3
add comment about bitsandbytes bug
chenqianfzh 25b7e75
Merge branch 'main' into qian/qlora
chenqianfzh e16bcb6
update per comments
chenqianfzh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| """ | ||
| This example shows how to use LoRA with different quantization techniques | ||
| for offline inference. | ||
|
|
||
| Requires HuggingFace credentials for access. | ||
| """ | ||
|
|
||
| import gc | ||
| from typing import List, Optional, Tuple | ||
|
|
||
| import torch | ||
| from huggingface_hub import snapshot_download | ||
|
|
||
| from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams | ||
| from vllm.lora.request import LoRARequest | ||
|
|
||
|
|
||
| def create_test_prompts( | ||
| lora_path: str | ||
| ) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: | ||
| return [ | ||
| # this is an example of using quantization without LoRA | ||
| ("My name is", | ||
| SamplingParams(temperature=0.0, | ||
| logprobs=1, | ||
| prompt_logprobs=1, | ||
| max_tokens=128), None), | ||
| # the next three examples use quantization with LoRA | ||
| ("my name is", | ||
| SamplingParams(temperature=0.0, | ||
| logprobs=1, | ||
| prompt_logprobs=1, | ||
| max_tokens=128), | ||
| LoRARequest("lora-test-1", 1, lora_path)), | ||
| ("The capital of USA is", | ||
| SamplingParams(temperature=0.0, | ||
| logprobs=1, | ||
| prompt_logprobs=1, | ||
| max_tokens=128), | ||
| LoRARequest("lora-test-2", 1, lora_path)), | ||
| ("The capital of France is", | ||
| SamplingParams(temperature=0.0, | ||
| logprobs=1, | ||
| prompt_logprobs=1, | ||
| max_tokens=128), | ||
| LoRARequest("lora-test-3", 1, lora_path)), | ||
| ] | ||
|
|
||
|
|
||
| def process_requests(engine: LLMEngine, | ||
| test_prompts: List[Tuple[str, SamplingParams, | ||
| Optional[LoRARequest]]]): | ||
| """Continuously process a list of prompts and handle the outputs.""" | ||
| request_id = 0 | ||
|
|
||
| while test_prompts or engine.has_unfinished_requests(): | ||
| if test_prompts: | ||
| prompt, sampling_params, lora_request = test_prompts.pop(0) | ||
| engine.add_request(str(request_id), | ||
| prompt, | ||
| sampling_params, | ||
| lora_request=lora_request) | ||
| request_id += 1 | ||
|
|
||
| request_outputs: List[RequestOutput] = engine.step() | ||
| for request_output in request_outputs: | ||
| if request_output.finished: | ||
| print("----------------------------------------------------") | ||
| print(f"Prompt: {request_output.prompt}") | ||
| print(f"Output: {request_output.outputs[0].text}") | ||
|
|
||
|
|
||
| def initialize_engine(model: str, quantization: str, | ||
| lora_repo: Optional[str]) -> LLMEngine: | ||
| """Initialize the LLMEngine.""" | ||
|
|
||
| if quantization == "bitsandbytes": | ||
| # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique. | ||
| # It quantizes the model when loading, with some config info from the | ||
| # LoRA adapter repo. So need to set the parameter of load_format and | ||
| # qlora_adapter_name_or_path as below. | ||
| engine_args = EngineArgs( | ||
| model=model, | ||
| quantization=quantization, | ||
| qlora_adapter_name_or_path=lora_repo, | ||
| load_format="bitsandbytes", | ||
| enable_lora=True, | ||
| max_lora_rank=64, | ||
| # set it only in GPUs of limited memory | ||
| enforce_eager=True) | ||
| else: | ||
| engine_args = EngineArgs( | ||
| model=model, | ||
| quantization=quantization, | ||
| enable_lora=True, | ||
| max_loras=4, | ||
| # set it only in GPUs of limited memory | ||
| enforce_eager=True) | ||
| return LLMEngine.from_engine_args(engine_args) | ||
|
|
||
|
|
||
| def main(): | ||
| """Main function that sets up and runs the prompt processing.""" | ||
|
|
||
| test_configs = [{ | ||
| "name": "qlora_inference_example", | ||
| 'model': "huggyllama/llama-7b", | ||
| 'quantization': "bitsandbytes", | ||
| 'lora_repo': 'timdettmers/qlora-flan-7b' | ||
| }, { | ||
| "name": "AWQ_inference_with_lora_example", | ||
| 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ', | ||
| 'quantization': "awq", | ||
| 'lora_repo': 'jashing/tinyllama-colorist-lora' | ||
| }, { | ||
| "name": "GPTQ_inference_with_lora_example", | ||
| 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ', | ||
| 'quantization': "gptq", | ||
| 'lora_repo': 'jashing/tinyllama-colorist-lora' | ||
| }] | ||
|
|
||
| for test_config in test_configs: | ||
| print( | ||
| f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~" | ||
| ) | ||
| engine = initialize_engine(test_config['model'], | ||
| test_config['quantization'], | ||
| test_config['lora_repo']) | ||
| lora_path = snapshot_download(repo_id=test_config['lora_repo']) | ||
| test_prompts = create_test_prompts(lora_path) | ||
| process_requests(engine, test_prompts) | ||
|
|
||
| # Clean up the GPU memory for the next test | ||
| del engine | ||
| gc.collect() | ||
| torch.cuda.empty_cache() | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,3 +35,6 @@ aiohttp | |
|
|
||
| # Multimodal | ||
| pillow | ||
|
|
||
| # quantization | ||
| bitsandbytes==0.42.0 | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| '''Tests whether bitsandbytes computation is enabled correctly. | ||
|
|
||
| Run `pytest tests/quantization/test_bitsandbytes.py`. | ||
| ''' | ||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm import SamplingParams | ||
| from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS | ||
|
|
||
| capability = torch.cuda.get_device_capability() | ||
| capability = capability[0] * 10 + capability[1] | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(), | ||
| reason='bitsandbytes is not supported on this GPU type.') | ||
| def test_load_bnb_model(vllm_runner) -> None: | ||
| llm = vllm_runner('huggyllama/llama-7b', | ||
| quantization='bitsandbytes', | ||
| load_format='bitsandbytes', | ||
| enforce_eager=True) | ||
|
|
||
| model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model | ||
|
|
||
| # check the weights in MLP & SelfAttention are quantized to torch.uint8 | ||
| qweight = model.model.layers[0].mlp.gate_up_proj.qweight | ||
| assert qweight.dtype == torch.uint8, ( | ||
| f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}') | ||
|
|
||
| qweight = model.model.layers[0].mlp.down_proj.qweight | ||
| assert qweight.dtype == torch.uint8, ( | ||
| f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}') | ||
|
|
||
| qweight = model.model.layers[0].self_attn.o_proj.qweight | ||
| assert qweight.dtype == torch.uint8, ( | ||
| f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}') | ||
|
|
||
| qweight = model.model.layers[0].self_attn.qkv_proj.qweight | ||
| assert qweight.dtype == torch.uint8, ( | ||
| f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}') | ||
|
|
||
| # some weights should not be quantized | ||
| weight = model.lm_head.weight | ||
| assert weight.dtype != torch.uint8, ( | ||
| 'lm_head weight dtype should not be torch.uint8') | ||
|
|
||
| weight = model.model.embed_tokens.weight | ||
| assert weight.dtype != torch.uint8, ( | ||
| 'embed_tokens weight dtype should not be torch.uint8') | ||
|
|
||
| weight = model.model.layers[0].input_layernorm.weight | ||
| assert weight.dtype != torch.uint8, ( | ||
| 'input_layernorm weight dtype should not be torch.uint8') | ||
|
|
||
| weight = model.model.layers[0].post_attention_layernorm.weight | ||
| assert weight.dtype != torch.uint8, ( | ||
| 'input_layernorm weight dtype should not be torch.uint8') | ||
|
|
||
| # check the output of the model is expected | ||
| sampling_params = SamplingParams(temperature=0.0, | ||
| logprobs=1, | ||
| prompt_logprobs=1, | ||
| max_tokens=8) | ||
|
|
||
| prompts = ['That which does not kill us', 'To be or not to be,'] | ||
| expected_outputs = [ | ||
| 'That which does not kill us makes us stronger.', | ||
| 'To be or not to be, that is the question.' | ||
| ] | ||
| outputs = llm.generate(prompts, sampling_params=sampling_params) | ||
|
|
||
| assert len(outputs) == len(prompts) | ||
|
|
||
| for index in range(len(outputs)): | ||
| # compare the first line of the output | ||
| actual_output = outputs[index][1][0].split('\n', 1)[0] | ||
| expected_output = expected_outputs[index].split('\n', 1)[0] | ||
| assert actual_output == expected_output, ( | ||
| f'Expected: {expected_output}, but got: {actual_output}') | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.