
Commit eebe0fd

chenqianfzh authored and jimpang committed
[Feature][Kernel] Support bitsandbytes quantization and QLoRA (vllm-project#4776)
1 parent 66822d5 commit eebe0fd

File tree

11 files changed: +752 −8 lines changed
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
"""
This example shows how to use LoRA with different quantization techniques
for offline inference.

Requires HuggingFace credentials for access.
"""

import gc
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    return [
        # this is an example of using quantization without LoRA
        ("My name is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        # the next three examples use quantization with LoRA
        ("my name is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-1", 1, lora_path)),
        ("The capital of USA is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-2", 1, lora_path)),
        ("The capital of France is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-3", 1, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        request_outputs: List[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print("----------------------------------------------------")
                print(f"Prompt: {request_output.prompt}")
                print(f"Output: {request_output.outputs[0].text}")


def initialize_engine(model: str, quantization: str,
                      lora_repo: Optional[str]) -> LLMEngine:
    """Initialize the LLMEngine."""

    if quantization == "bitsandbytes":
        # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique.
        # It quantizes the model at load time, using some config info from the
        # LoRA adapter repo, so load_format and qlora_adapter_name_or_path
        # must be set as below.
        engine_args = EngineArgs(
            model=model,
            quantization=quantization,
            qlora_adapter_name_or_path=lora_repo,
            load_format="bitsandbytes",
            enable_lora=True,
            max_lora_rank=64,
            # set this only on GPUs with limited memory
            enforce_eager=True)
    else:
        engine_args = EngineArgs(
            model=model,
            quantization=quantization,
            enable_lora=True,
            max_loras=4,
            # set this only on GPUs with limited memory
            enforce_eager=True)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""

    test_configs = [{
        "name": "qlora_inference_example",
        'model': "huggyllama/llama-7b",
        'quantization': "bitsandbytes",
        'lora_repo': 'timdettmers/qlora-flan-7b'
    }, {
        "name": "AWQ_inference_with_lora_example",
        'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
        'quantization': "awq",
        'lora_repo': 'jashing/tinyllama-colorist-lora'
    }, {
        "name": "GPTQ_inference_with_lora_example",
        'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
        'quantization': "gptq",
        'lora_repo': 'jashing/tinyllama-colorist-lora'
    }]

    for test_config in test_configs:
        print(
            f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
        )
        engine = initialize_engine(test_config['model'],
                                   test_config['quantization'],
                                   test_config['lora_repo'])
        lora_path = snapshot_download(repo_id=test_config['lora_repo'])
        test_prompts = create_test_prompts(lora_path)
        process_requests(engine, test_prompts)

        # Clean up the GPU memory for the next test
        del engine
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
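For reference, the same QLoRA-over-bitsandbytes path can also be driven through vLLM's high-level LLM class instead of the raw LLMEngine loop above. The snippet below is a minimal sketch, not part of this commit; it assumes that extra keyword arguments to LLM (such as the qlora_adapter_name_or_path field added to EngineArgs in this change) are forwarded to EngineArgs, and that LLM.generate accepts a lora_request as in the engine example.

# Hedged sketch (not part of this commit): the same QLoRA setup through the
# high-level LLM class, assuming its keyword arguments are forwarded to
# EngineArgs and that generate() accepts a lora_request.
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

lora_repo = "timdettmers/qlora-flan-7b"
lora_path = snapshot_download(repo_id=lora_repo)

llm = LLM(model="huggyllama/llama-7b",
          quantization="bitsandbytes",
          load_format="bitsandbytes",
          qlora_adapter_name_or_path=lora_repo,
          enable_lora=True,
          max_lora_rank=64,
          enforce_eager=True)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32),
                       lora_request=LoRARequest("qlora-test", 1, lora_path))
print(outputs[0].outputs[0].text)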

requirements-dev.txt

Lines changed: 3 additions & 0 deletions
@@ -35,3 +35,6 @@ aiohttp
 
 # Multimodal
 pillow
+
+# quantization
+bitsandbytes==0.42.0

tests/quantization/test_bitsandbytes.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
'''Tests whether bitsandbytes computation is enabled correctly.

Run `pytest tests/quantization/test_bitsandbytes.py`.
'''
import pytest
import torch

from vllm import SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]


@pytest.mark.skipif(
    capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(),
    reason='bitsandbytes is not supported on this GPU type.')
def test_load_bnb_model(vllm_runner) -> None:
    llm = vllm_runner('huggyllama/llama-7b',
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
                      enforce_eager=True)

    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model

    # check the weights in MLP & SelfAttention are quantized to torch.uint8
    qweight = model.model.layers[0].mlp.gate_up_proj.qweight
    assert qweight.dtype == torch.uint8, (
        f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}')

    qweight = model.model.layers[0].mlp.down_proj.qweight
    assert qweight.dtype == torch.uint8, (
        f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}')

    qweight = model.model.layers[0].self_attn.o_proj.qweight
    assert qweight.dtype == torch.uint8, (
        f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}')

    qweight = model.model.layers[0].self_attn.qkv_proj.qweight
    assert qweight.dtype == torch.uint8, (
        f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}')

    # some weights should not be quantized
    weight = model.lm_head.weight
    assert weight.dtype != torch.uint8, (
        'lm_head weight dtype should not be torch.uint8')

    weight = model.model.embed_tokens.weight
    assert weight.dtype != torch.uint8, (
        'embed_tokens weight dtype should not be torch.uint8')

    weight = model.model.layers[0].input_layernorm.weight
    assert weight.dtype != torch.uint8, (
        'input_layernorm weight dtype should not be torch.uint8')

    weight = model.model.layers[0].post_attention_layernorm.weight
    assert weight.dtype != torch.uint8, (
        'post_attention_layernorm weight dtype should not be torch.uint8')

    # check the output of the model is expected
    sampling_params = SamplingParams(temperature=0.0,
                                     logprobs=1,
                                     prompt_logprobs=1,
                                     max_tokens=8)

    prompts = ['That which does not kill us', 'To be or not to be,']
    expected_outputs = [
        'That which does not kill us makes us stronger.',
        'To be or not to be, that is the question.'
    ]
    outputs = llm.generate(prompts, sampling_params=sampling_params)

    assert len(outputs) == len(prompts)

    for index in range(len(outputs)):
        # compare the first line of the output
        actual_output = outputs[index][1][0].split('\n', 1)[0]
        expected_output = expected_outputs[index].split('\n', 1)[0]
        assert actual_output == expected_output, (
            f'Expected: {expected_output}, but got: {actual_output}')
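The test above is gated on GPU compute capability and relies on the bitsandbytes==0.42.0 pin added to requirements-dev.txt. As a hedged preflight sketch (not part of the commit), the same capability gate can be checked on its own before running the test:

# Hedged preflight sketch: report whether the current GPU meets the minimum
# compute capability that vLLM advertises for bitsandbytes, i.e. the same
# check the test uses in its skipif marker.
import torch

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor
min_capability = QUANTIZATION_METHODS['bitsandbytes'].get_min_capability()
print(f'GPU capability {capability}, bitsandbytes requires >= {min_capability}')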

vllm/config.py

Lines changed: 8 additions & 1 deletion
@@ -241,6 +241,12 @@ def verify_with_parallel_config(
                 "must be divisible by pipeline parallel size "
                 f"({pipeline_parallel_size}).")
 
+        if self.quantization == "bitsandbytes" and (
+                parallel_config.tensor_parallel_size > 1
+                or parallel_config.pipeline_parallel_size > 1):
+            raise ValueError(
+                "BitsAndBytes quantization with TP or PP is not supported yet.")
+
     def get_hf_config_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled.
         """
@@ -327,7 +333,7 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
     def get_num_attention_heads(self,
                                 parallel_config: "ParallelConfig") -> int:
         return self.hf_text_config.num_attention_heads // \
-            parallel_config.tensor_parallel_size
+                parallel_config.tensor_parallel_size
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_text_config.num_hidden_layers
@@ -487,6 +493,7 @@ class LoadFormat(str, enum.Enum):
     DUMMY = "dummy"
     TENSORIZER = "tensorizer"
     SHARDED_STATE = "sharded_state"
+    BITSANDBYTES = "bitsandbytes"
 
 
 @dataclass
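The new check in verify_with_parallel_config rejects bitsandbytes quantization combined with tensor or pipeline parallelism. A hedged sketch of what that looks like from the user side, assuming the verification is triggered while the engine config is being built:

# Hedged sketch (not part of this commit): bitsandbytes together with TP > 1
# is expected to fail the verify_with_parallel_config check added above.
from vllm import EngineArgs

engine_args = EngineArgs(model="huggyllama/llama-7b",
                         quantization="bitsandbytes",
                         load_format="bitsandbytes",
                         tensor_parallel_size=2)  # TP > 1 is rejected
try:
    engine_args.create_engine_config()
except ValueError as err:
    print(err)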

vllm/engine/arg_utils.py

Lines changed: 35 additions & 3 deletions
@@ -92,6 +92,8 @@ class EngineArgs:
     ngram_prompt_lookup_max: Optional[int] = None
     ngram_prompt_lookup_min: Optional[int] = None
 
+    qlora_adapter_name_or_path: Optional[str] = None
+
     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
@@ -159,7 +161,8 @@ def add_cli_args(
             type=str,
             default=EngineArgs.load_format,
             choices=[
-                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
+                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
+                'bitsandbytes'
             ],
             help='The format of the model weights to load.\n\n'
             '* "auto" will try to load the weights in the safetensors format '
@@ -173,7 +176,9 @@ def add_cli_args(
             'which is mainly for profiling.\n'
             '* "tensorizer" will load the weights using tensorizer from '
             'CoreWeave. See the Tensorize vLLM Model script in the Examples'
-            'section for more information.\n')
+            'section for more information.\n'
+            '* "bitsandbytes" will load the weights using bitsandbytes '
+            'quantization.\n')
         parser.add_argument(
             '--dtype',
             type=str,
@@ -543,7 +548,10 @@ def add_cli_args(
             "will also be used in `model_name` tag content of "
             "prometheus metrics, if multiple names provided, metrics"
             "tag will take the first one.")
-
+        parser.add_argument('--qlora-adapter-name-or-path',
+                            type=str,
+                            default=None,
+                            help='Name or path of the QLoRA adapter.')
         return parser
 
     @classmethod
@@ -555,6 +563,23 @@ def from_cli_args(cls, args: argparse.Namespace):
         return engine_args
 
     def create_engine_config(self, ) -> EngineConfig:
+
+        # bitsandbytes quantization needs a specific model loader
+        # so we make sure the quant method and the load format are consistent
+        if (self.quantization == "bitsandbytes" or
+                self.qlora_adapter_name_or_path is not None) and \
+                self.load_format != "bitsandbytes":
+            raise ValueError(
+                "BitsAndBytes quantization and QLoRA adapter only support "
+                f"'bitsandbytes' load format, but got {self.load_format}")
+
+        if (self.load_format == "bitsandbytes" or
+                self.qlora_adapter_name_or_path is not None) and \
+                self.quantization != "bitsandbytes":
+            raise ValueError(
+                "BitsAndBytes load format and QLoRA adapter only support "
+                f"'bitsandbytes' quantization, but got {self.quantization}")
+
         device_config = DeviceConfig(self.device)
         model_config = ModelConfig(
             self.model, self.tokenizer, self.tokenizer_mode,
@@ -622,6 +647,13 @@ def create_engine_config(self, ) -> EngineConfig:
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
 
+        if self.qlora_adapter_name_or_path is not None and \
+                self.qlora_adapter_name_or_path != "":
+            if self.model_loader_extra_config is None:
+                self.model_loader_extra_config = {}
+            self.model_loader_extra_config[
+                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
+
         load_config = LoadConfig(
             load_format=self.load_format,
             download_dir=self.download_dir,
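The two checks added to create_engine_config keep the bitsandbytes quantization method, the bitsandbytes load format, and the new --qlora-adapter-name-or-path option consistent with each other. A hedged sketch of the mismatch they guard against (not part of the commit):

# Hedged sketch: asking for bitsandbytes quantization while leaving the load
# format at "auto" should be rejected by the first consistency check above.
from vllm import EngineArgs

engine_args = EngineArgs(model="huggyllama/llama-7b",
                         quantization="bitsandbytes",
                         load_format="auto")  # mismatch with the quant method
try:
    engine_args.create_engine_config()
except ValueError as err:
    print(err)

On the command line, the matching configuration combines --quantization bitsandbytes, --load-format bitsandbytes, --enable-lora, and the new --qlora-adapter-name-or-path flag.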
