
Commit 9e0c3da

thesues authored and mgoin committed
[bitsandbytes]: support read bnb pre-quantized model (vllm-project#5753)
Co-authored-by: Michael Goin <[email protected]>
1 parent 573c43f commit 9e0c3da

File tree: 8 files changed (+142, -39 lines)

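For orientation, here is a minimal usage sketch of what this commit enables, mirroring the example added in docs/source/quantization/bnb.rst below; the prompt and sampling settings are illustrative, and a CUDA GPU with bitsandbytes>=0.42.0 installed is assumed.

from vllm import LLM, SamplingParams
import torch

# unsloth/tinyllama-bnb-4bit is a bitsandbytes pre-quantized checkpoint;
# its config.json carries a quantization_config section.
llm = LLM(model="unsloth/tinyllama-bnb-4bit",
          dtype=torch.bfloat16,
          trust_remote_code=True,
          quantization="bitsandbytes",
          load_format="bitsandbytes")

# Illustrative prompt and sampling parameters, not taken from this commit.
outputs = llm.generate(["To be or not to be,"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)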

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ Documentation
 
    quantization/supported_hardware
    quantization/auto_awq
+   quantization/bnb
    quantization/fp8
    quantization/fp8_e5m2_kvcache
    quantization/fp8_e4m3_kvcache

docs/source/quantization/bnb.rst

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+.. _bits_and_bytes:
+
+BitsAndBytes
+==================
+
+vLLM now supports `BitsAndBytes <https://github.com/TimDettmers/bitsandbytes>`_ for more efficient model inference.
+BitsAndBytes quantizes models to reduce memory usage and enhance performance without significantly sacrificing accuracy.
+Compared to other quantization methods, BitsAndBytes eliminates the need to calibrate the quantized model with input data.
+
+Below are the steps to utilize BitsAndBytes with vLLM.
+
+.. code-block:: console
+
+    $ pip install "bitsandbytes>=0.42.0"
+
+vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoints.
+
+You can find bitsandbytes quantized models on https://huggingface.co/models?other=bitsandbytes.
+Usually, these repositories have a config.json file that includes a quantization_config section.
+
+Read quantized checkpoint
+-------------------------
+
+.. code-block:: python
+
+    from vllm import LLM
+    import torch
+    # unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint.
+    model_id = "unsloth/tinyllama-bnb-4bit"
+    llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
+              quantization="bitsandbytes", load_format="bitsandbytes")
+
+In-flight quantization: load as 4-bit quantization
+--------------------------------------------------
+
+.. code-block:: python
+
+    from vllm import LLM
+    import torch
+    model_id = "huggyllama/llama-7b"
+    llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
+              quantization="bitsandbytes", load_format="bitsandbytes")
+
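As the documentation above notes, pre-quantized repositories ship a config.json with a quantization_config section. Below is a small sketch of how one might check for it before picking a loading mode; hf_hub_download comes from the huggingface_hub package, and the repo id is only an example.

import json

from huggingface_hub import hf_hub_download


def is_bnb_prequantized(repo_id: str) -> bool:
    # Fetch only config.json and look for a bitsandbytes quantization_config.
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    with open(config_path) as f:
        config = json.load(f)
    quant_config = config.get("quantization_config", {})
    return quant_config.get("quant_method") == "bitsandbytes"


print(is_bnb_prequantized("unsloth/tinyllama-bnb-4bit"))  # expected: True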

tests/quantization/test_bitsandbytes.py

Lines changed: 14 additions & 4 deletions
@@ -8,15 +8,20 @@
 from tests.quantization.utils import is_quant_method_supported
 from vllm import SamplingParams
 
+models_to_test = [
+    ('huggyllama/llama-7b', 'quantize model inflight'),
+    ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
+]
+
 
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
-def test_load_bnb_model(vllm_runner) -> None:
-    with vllm_runner('huggyllama/llama-7b',
+@pytest.mark.parametrize("model_name, description", models_to_test)
+def test_load_bnb_model(vllm_runner, model_name, description) -> None:
+    with vllm_runner(model_name,
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
                      enforce_eager=True) as llm:
-
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
 
         # check the weights in MLP & SelfAttention are quantized to torch.uint8
@@ -65,12 +70,17 @@ def test_load_bnb_model(vllm_runner) -> None:
             'To be or not to be, that is the question.'
         ]
         outputs = llm.generate(prompts, sampling_params=sampling_params)
-
         assert len(outputs) == len(prompts)
 
         for index in range(len(outputs)):
             # compare the first line of the output
             actual_output = outputs[index][1][0].split('\n', 1)[0]
             expected_output = expected_outputs[index].split('\n', 1)[0]
+
+            assert len(actual_output) >= len(expected_output), (
+                f'Actual {actual_output} should be larger than or equal to '
+                f'expected {expected_output}')
+            actual_output = actual_output[:len(expected_output)]
+
             assert actual_output == expected_output, (
                 f'Expected: {expected_output}, but got: {actual_output}')
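The new assertions compare only a prefix of the generated text: the output must be at least as long as the expected sentence, and its leading characters must match exactly. A standalone sketch of that comparison with hypothetical strings (not outputs from the test run):

def assert_prefix_matches(actual_output: str, expected_output: str) -> None:
    # Generation may continue past the expected sentence, so the output must
    # be at least as long as the expectation ...
    assert len(actual_output) >= len(expected_output)
    # ... and its prefix must match exactly.
    assert actual_output[:len(expected_output)] == expected_output


# Hypothetical strings for illustration.
assert_prefix_matches(
    'To be or not to be, that is the question. Whether it is nobler in the mind',
    'To be or not to be, that is the question.')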

vllm/config.py

Lines changed: 1 addition & 0 deletions
@@ -582,6 +582,7 @@ class LoadConfig:
            mainly for profiling.
        "tensorizer" will use CoreWeave's tensorizer library for
            fast weight loading.
+       "bitsandbytes" will load nf4 type weights.
    """
 
    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO

vllm/engine/arg_utils.py

Lines changed: 2 additions & 2 deletions
@@ -621,8 +621,8 @@ def create_engine_config(self, ) -> EngineConfig:
        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
-           self.qlora_adapter_name_or_path is not None) and \
-           self.load_format != "bitsandbytes":
+                self.qlora_adapter_name_or_path is not None) and \
+                self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

vllm/model_executor/layers/quantization/bitsandbytes.py

Lines changed: 4 additions & 21 deletions
@@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig):
     Reference: https://arxiv.org/abs/2305.14314
     """
 
-    def __init__(
-        self,
-        adapter_name_or_path: str,
-        target_modules: List[str],
-    ) -> None:
-
-        self.adapter_name_or_path = adapter_name_or_path
-        self.target_modules = target_modules
+    def __init__(self, ) -> None:
+        pass
 
     def __repr__(self) -> str:
-        return (
-            f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}"
-        )
+        return "BitsAndBytesConfig"
 
     @classmethod
     def get_name(self) -> str:
@@ -49,16 +41,7 @@ def get_config_filenames() -> List[str]:
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
-        adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"])
-        default_target_modules = [
-            "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
-            "o_proj"
-        ]
-        if adapter_name == "":
-            target_modules = default_target_modules
-        else:
-            target_modules = cls.get_from_keys(config, ["target_modules"])
-        return cls(adapter_name, target_modules)
+        return cls()
 
     def get_quant_method(
         self,
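With the constructor arguments removed, from_config now ignores the checkpoint's quantization_config contents and simply returns cls(). For reference, an illustrative quantization_config section written as a Python dict; the field names follow transformers' BitsAndBytesConfig serialization and the values are examples, not taken from this commit.

# Illustrative only: typical quantization_config of a bitsandbytes nf4 checkpoint.
quantization_config = {
    "quant_method": "bitsandbytes",
    "load_in_4bit": True,
    "load_in_8bit": False,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": "bfloat16",
    "bnb_4bit_use_double_quant": True,
}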

vllm/model_executor/model_loader/loader.py

Lines changed: 76 additions & 12 deletions
@@ -685,8 +685,14 @@ def _prepare_weights(self, model_name_or_path: str,
 
        return hf_weights_files, matched_pattern == "*.safetensors"
 
+    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
+        if use_safetensors:
+            return safetensors_weights_iterator(hf_weights_files)
+        else:
+            return pt_weights_iterator(hf_weights_files)
+
    def _get_quantized_weights_iterator(
-        self, model_name_or_path: str, revision: Optional[str]
+        self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
@@ -695,6 +701,7 @@ def _get_quantized_weights_iterator(
        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes
+            from bitsandbytes.functional import QuantState
            if bitsandbytes.__version__ < "0.42.0":
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.42.0.")
@@ -708,17 +715,63 @@
            model_name_or_path, revision)
 
        quant_state_dict = {}
-        if use_safetensors:
-            weight_iterator = safetensors_weights_iterator(hf_weights_files)
-        else:
-            weight_iterator = pt_weights_iterator(hf_weights_files)
 
-        def generator():
+        def quantized_checkpoint() -> Generator:
+            # First iterate over all quant state weights
+            weight_iterator = self._hf_weight_iter(hf_weights_files,
+                                                   use_safetensors)
+            temp_state_dict = {}
            for weight_name, weight_tensor in weight_iterator:
+                if weight_name.endswith(".weight"):
+                    continue
+                # TODO: only nf4 quantization is supported for now
+                if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
+                    raise NotImplementedError(
+                        "Only bitsandbytes_nf4 quantization "
+                        f"is supported for now. {weight_name} is fp4 quantized"
+                    )
+                temp_state_dict[weight_name] = weight_tensor
+
+            # Closure to parse quant_state for each prequant weight
+            def _parse_quant_state(param_name: str,
+                                   temp_state_dict: Dict) -> QuantState:
+                quant_state = {}
+                for k in temp_state_dict:
+                    if param_name + "." in k:
+                        quant_state[k] = temp_state_dict[k]
+                # bitsandbytes library requires
+                # weight.quant_state.bitsandbytes__nf4 in CPU
+                quant_state[param_name +
+                            ".quant_state.bitsandbytes__nf4"] = quant_state[
+                                param_name +
+                                ".quant_state.bitsandbytes__nf4"].cpu().data
+                return QuantState.from_dict(quant_state, device="cuda")
+
+            # Second iterate over all prequant and normal weights
+            # pre quantized weights would have a quant_state
+            for weight_name, weight_tensor in self._hf_weight_iter(
+                    hf_weights_files, use_safetensors):
+                # Filter out all weights whose suffix is not ".weight"
+                if not weight_name.endswith(".weight"):
+                    continue
+                if weight_name + ".quant_state.bitsandbytes__nf4" \
+                        in temp_state_dict:
+                    quant_state = _parse_quant_state(weight_name,
+                                                     temp_state_dict)
+                    weight_name = weight_name.replace(".weight", ".qweight")
+                    quant_state_dict[weight_name] = quant_state
+                    yield weight_name.replace(".weight",
+                                              ".qweight"), weight_tensor
+                else:
+                    yield weight_name, weight_tensor
+
+        def generator() -> Generator:
+            for weight_name, weight_tensor in self._hf_weight_iter(
+                    hf_weights_files, use_safetensors):
                if any(target_module in weight_name
                       for target_module in self.target_modules):
                    weight_name = weight_name.replace(".weight", ".qweight")
-                     # bitsandbytes requires data in GPU
+                    # bitsandbytes requires data in GPU
                    loaded_weight = weight_tensor.cuda().data
                    with set_default_torch_dtype(torch.float32):
                        processed_weight, quant_state = quantize_4bit(
@@ -732,6 +785,8 @@ def generator():
 
                yield weight_name, processed_weight
 
+        if pre_quant:
+            return quantized_checkpoint(), quant_state_dict
        return generator(), quant_state_dict
 
    def _load_weights(self, model_config: ModelConfig,
@@ -749,12 +804,21 @@ def _load_weights(self, model_config: ModelConfig,
        logger.info("Loading weights with BitsAndBytes quantization. "
                    " May take a while ...")
 
-        qweight_iterator, quant_state_dict = (
-            self._get_quantized_weights_iterator(model_config.model,
-                                                 model_config.revision))
+        is_quantized_checkpoint = False
+        quant_config = getattr(model_config.hf_config, "quantization_config",
+                               None)
+        if quant_config is not None and quant_config.get(
+                'quant_method') == "bitsandbytes":
+            is_quantized_checkpoint = True
+
+        qweight_iterator, quant_state_dict = \
+            self._get_quantized_weights_iterator(
+                model_config.model, model_config.revision, is_quantized_checkpoint)
 
        model.load_weights(qweight_iterator)
 
+        torch.cuda.empty_cache()
+
        param_dict = dict(model.named_parameters())
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
        for quant_param_name in quant_state_dict:
@@ -792,9 +856,9 @@ def _load_weights(self, model_config: ModelConfig,
                    f"pack_factor not set for parameter {param_name}.")
 
            num_elements = [0] * len(quant_states)
-            for seq, quant_state in enumerate(quant_states.items()):
+            for seq, quant_state in quant_states.items():
                num_elements[seq] = math.prod(
-                    quant_state[1].shape) // pack_ratio
+                    quant_state.shape) // pack_ratio
 
            offsets = np.concatenate(([0], np.cumsum(num_elements)))
            set_weight_attrs(param, {"bnb_shard_offsets": offsets})
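The loader above makes two passes over a pre-quantized checkpoint: it first collects the auxiliary quant-state tensors (keys such as layer.weight.absmax or layer.weight.quant_state.bitsandbytes__nf4), then pairs each packed .weight tensor with its quant state and renames it to .qweight. A simplified, self-contained sketch of that grouping on a toy state dict; the key names are illustrative of bitsandbytes' serialization, and the tensors are stand-ins.

import torch

# Toy checkpoint: one packed nf4 weight plus its serialized quant state,
# and one regular unquantized weight.
state_dict = {
    "model.layers.0.mlp.gate_proj.weight":
        torch.empty(1024, 1, dtype=torch.uint8),
    "model.layers.0.mlp.gate_proj.weight.absmax": torch.empty(64),
    "model.layers.0.mlp.gate_proj.weight.quant_map": torch.empty(16),
    "model.layers.0.mlp.gate_proj.weight.quant_state.bitsandbytes__nf4":
        torch.empty(100, dtype=torch.uint8),
    "model.embed_tokens.weight": torch.empty(8, 8),
}

# Pass 1: anything that does not end in ".weight" is quant-state metadata.
temp_state_dict = {
    k: v
    for k, v in state_dict.items() if not k.endswith(".weight")
}

# Pass 2: weights with a matching nf4 quant state are renamed to ".qweight";
# everything else is treated as a regular weight.
for name, tensor in state_dict.items():
    if not name.endswith(".weight"):
        continue
    if name + ".quant_state.bitsandbytes__nf4" in temp_state_dict:
        print(name.replace(".weight", ".qweight"), "-> pre-quantized")
    else:
        print(name, "-> regular weight")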

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 1 addition & 0 deletions
@@ -117,6 +117,7 @@ def convert_bin_to_safetensor_file(
 # TODO(woosuk): Move this to other place.
 def get_quant_config(model_config: ModelConfig,
                      load_config: LoadConfig) -> QuantizationConfig:
+
     quant_cls = get_quantization_config(model_config.quantization)
     # Read the quantization config from the HF model config, if available.
     hf_quant_config = getattr(model_config.hf_config, "quantization_config",
