docs/source/features/quantization/fp8.md: 15 additions & 71 deletions
@@ -19,24 +19,6 @@ FP8 computation is supported on NVIDIA GPUs with compute capability > 8.9 (Ada L
FP8 models will run on compute capability > 8.0 (Ampere) as weight-only W8A16, utilizing FP8 Marlin.
:::
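A quick way to check which of these paths applies on a given machine is to query the GPU's compute capability, for example with PyTorch:

```python
import torch

if torch.cuda.is_available():
    # (major, minor) compute capability, e.g. (8, 9) for Ada Lovelace, (9, 0) for Hopper
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) >= (8, 9):
        print("Native FP8 W8A8 is supported on this GPU")
    elif (major, minor) >= (8, 0):
        print("FP8 checkpoints will run as weight-only W8A16 via FP8 Marlin")
    else:
        print("FP8 quantization is not supported on this GPU")
```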
-## Quick Start with Online Dynamic Quantization

-Dynamic quantization of an original-precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data. You can enable the feature by specifying `--quantization="fp8"` on the command line or by setting `quantization="fp8"` in the `LLM` constructor.

-In this mode, all Linear modules (except for the final `lm_head`) have their weights quantized down to FP8_E4M3 precision with a per-tensor scale. Activations have their minimum and maximum values calculated during each forward pass to provide a dynamic per-tensor scale for high accuracy. As a result, latency improvements are limited in this mode.

-```python
-from vllm import LLM
-model = LLM("facebook/opt-125m", quantization="fp8")
-# INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB
-result = model.generate("Hello, my name is")
-print(result[0].outputs[0].text)
-```

-:::{warning}
-Currently, we load the model at original precision before quantizing down to 8 bits, so you need enough memory to load the whole model.
-:::
## Installation
To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
@@ -45,12 +27,6 @@ To produce performant FP8 quantized models with vLLM, you'll need to install the
```console
pip install llmcompressor
```
-Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:

-```console
-pip install vllm lm-eval==0.4.4
-```
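Once these are installed, a quantized checkpoint can be scored through lm-evaluation-harness's vLLM backend; the checkpoint path, task, and sample limit below are illustrative:

```console
lm_eval --model vllm \
  --model_args pretrained=Meta-Llama-3-8B-Instruct-FP8,add_bos_token=True \
  --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250
```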
## Quantization Process
The quantization process involves three main steps:
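The steps themselves sit outside this hunk. As a rough sketch of the flow they describe (load a model, apply an FP8 recipe with llm-compressor, save a checkpoint that vLLM can load), assuming the `QuantizationModifier`/`oneshot` API and an illustrative model ID:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot  # import path may differ across llm-compressor versions

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative model

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# FP8 weights with dynamic per-token activation scales; keep lm_head in original precision
recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
oneshot(model=model, recipe=recipe)

# Save a compressed checkpoint that vLLM can load directly
SAVE_DIR = MODEL_ID.split("/")[-1] + "-FP8-Dynamic"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```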
The AutoFP8 package introduces the `AutoFP8ForCausalLM` and `BaseQuantizeConfig` objects for managing how your model will be compressed.
-## Offline Quantization with Static Activation Scaling Factors

-You can use AutoFP8 with calibration data to produce per-tensor static scales for both the weights and activations by enabling the `activation_scheme="static"` argument.
-```python
-from datasets import load_dataset
-from transformers import AutoTokenizer
-from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

# ... the intervening lines are collapsed in this diff; they define pretrained_model_dir,
# quantized_model_dir, the tokenizer, the calibration examples, and quantize_config ...

-model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
-model.quantize(examples)
-model.save_quantized(quantized_model_dir)
-```
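The setup for this example is collapsed in the hunk above. Purely as an illustration of the static-calibration flow, here is a self-contained sketch using the AutoFP8 API shown here, with a placeholder model path, calibration text, and output directory:

```python
from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir)
tokenizer.pad_token = tokenizer.eos_token

# A tiny calibration batch; real calibration should use a few hundred representative samples
examples = tokenizer(
    ["auto_fp8 is an easy-to-use model quantization library"],
    return_tensors="pt",
).to("cuda")

# Static per-tensor scales for both weights and activations
quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
model.quantize(examples)
model.save_quantized(quantized_model_dir)
```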
+Dynamic quantization of an original-precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data. You can enable the feature by specifying `--quantization="fp8"` on the command line or by setting `quantization="fp8"` in the `LLM` constructor.
-Your model checkpoint with quantized weights and activations should be available at `Meta-Llama-3-8B-Instruct-FP8/`.

-Finally, you can load the quantized model checkpoint directly in vLLM.
+In this mode, all Linear modules (except for the final `lm_head`) have their weights quantized down to FP8_E4M3 precision with a per-tensor scale. Activations have their minimum and maximum values calculated during each forward pass to provide a dynamic per-tensor scale for high accuracy. As a result, latency improvements are limited in this mode.
```python
from vllm import LLM
-model = LLM(model="Meta-Llama-3-8B-Instruct-FP8/")
-# INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
+model = LLM("facebook/opt-125m", quantization="fp8")
+# INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB
result = model.generate("Hello, my name is")
print(result[0].outputs[0].text)
```
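The same dynamic FP8 path can be enabled when serving from the command line; for example, assuming the `vllm serve` entrypoint:

```console
vllm serve facebook/opt-125m --quantization fp8
```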
+:::{warning}
+Currently, we load the model at original precision before quantizing down to 8 bits, so you need enough memory to load the whole model.
+:::