Commit ade989a

mgoin authored and jimpang committed

[Doc] Add documentation for FP8 W8A8 (vllm-project#5388)

1 parent a0b26ad commit ade989a

File tree

2 files changed: +203 -0 lines changed

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -96,6 +96,7 @@ Documentation
     :caption: Quantization

     quantization/auto_awq
+    quantization/fp8
     quantization/fp8_e5m2_kvcache
     quantization/fp8_e4m3_kvcache
docs/source/quantization/fp8.rst

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
.. _fp8:

FP8
==================

vLLM supports FP8 (8-bit floating point) computation using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x. Currently, only Hopper and Ada Lovelace GPUs are supported. Quantization of models with FP8 allows for a 2x reduction in model memory requirements and up to a 1.6x improvement in throughput with minimal impact on accuracy.

Please visit the HF collection of `quantized FP8 checkpoints of popular LLMs ready to use with vLLM <https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127>`_.

The FP8 types typically supported in hardware have two distinct representations, each useful in different scenarios:

- **E4M3**: Consists of 1 sign bit, 4 exponent bits, and 3 bits of mantissa. It can store values up to +/-448 and ``nan``.
- **E5M2**: Consists of 1 sign bit, 5 exponent bits, and 2 bits of mantissa. It can store values up to +/-57344, +/- ``inf``, and ``nan``. The tradeoff for the increased dynamic range is lower precision of the stored values.
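If you want to sanity-check these ranges yourself, recent PyTorch releases (2.1 and later) expose both formats as native dtypes. The snippet below is purely illustrative and is not something vLLM requires you to run:

.. code-block:: python

    import torch

    # Inspect the numeric limits of the two FP8 formats (requires PyTorch >= 2.1).
    # PyTorch's E4M3 variant is the "fn" flavor (finite + nan, no inf), matching the
    # +/-448 range described above.
    for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        info = torch.finfo(dtype)
        print(f"{dtype}: max={info.max}, smallest normal={info.tiny}")

    # torch.float8_e4m3fn: max=448.0
    # torch.float8_e5m2:   max=57344.0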
Quick Start with Online Dynamic Quantization
--------------------------------------------

Dynamic quantization of an original precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data required. You can enable the feature by specifying ``--quantization="fp8"`` in the command line or setting ``quantization="fp8"`` in the LLM constructor.
In this mode, all Linear modules (except for the final ``lm_head``) have their weights quantized down to FP8_E4M3 precision with a per-tensor scale. Activations have their minimum and maximum values calculated during each forward pass to provide a dynamic per-tensor scale for high accuracy. As a result, latency improvements are limited in this mode.

.. code-block:: python

    from vllm import LLM
    model = LLM("facebook/opt-125m", quantization="fp8")
    # INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB
    result = model.generate("Hello, my name is")

.. warning::

    Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model.
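To make the dynamic per-tensor scaling described above concrete, here is a minimal PyTorch sketch of the underlying idea. It is not vLLM's actual fused kernel, only an illustration: the scale is derived from the tensor's absolute maximum on each forward pass, values are quantized to E4M3 with that scale, and then dequantized here just to show the round-trip error.

.. code-block:: python

    import torch

    # Illustrative only: per-tensor dynamic FP8 (E4M3) quantization.
    x = torch.randn(16, 4096, dtype=torch.bfloat16)

    fp8_max = torch.finfo(torch.float8_e4m3fn).max        # 448.0
    scale = x.abs().max().float() / fp8_max               # dynamic per-tensor scale
    x_fp8 = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    x_dequant = x_fp8.float() * scale                     # dequantize to inspect error

    print("max abs error:", (x.float() - x_dequant).abs().max().item())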
Offline Quantization
--------------------

For offline quantization to FP8, please install the `AutoFP8 library <https://github.com/neuralmagic/autofp8>`_.

.. code-block:: bash

    git clone https://github.com/neuralmagic/AutoFP8.git
    pip install -e AutoFP8

This package introduces the ``AutoFP8ForCausalLM`` and ``BaseQuantizeConfig`` objects for managing how your model will be compressed.
Offline Quantization with Dynamic Activation Scaling Factors
-------------------------------------------------------------

You can use AutoFP8 to produce checkpoints with their weights quantized to FP8 ahead of time and let vLLM handle calculating dynamic scales for the activations at runtime for maximum accuracy. You can enable this with the ``activation_scheme="dynamic"`` argument.

.. warning::

    Please note that, although this mode doesn't give you better performance, it reduces the memory footprint compared to online quantization.
.. code-block:: python

    from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

    pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
    quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8-Dynamic"

    # Define quantization config with dynamic activation scales
    quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="dynamic")
    # For dynamic activation scales, there is no need for calibration examples
    examples = []

    # Load the model, quantize, and save checkpoint
    model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
    model.quantize(examples)
    model.save_quantized(quantized_model_dir)
In the output of the above script, you should see the quantized Linear modules replaced with ``FP8DynamicLinear`` in the printed model definition.
Note that the ``lm_head`` Linear module at the end is currently skipped by default.
.. code-block:: text

    LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): FP8DynamicLinear()
              (k_proj): FP8DynamicLinear()
              (v_proj): FP8DynamicLinear()
              (o_proj): FP8DynamicLinear()
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): FP8DynamicLinear()
              (up_proj): FP8DynamicLinear()
              (down_proj): FP8DynamicLinear()
              (act_fn): SiLU()
            )
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
          )
        )
        (norm): LlamaRMSNorm()
      )
      (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
    )
    Saving the model to Meta-Llama-3-8B-Instruct-FP8-Dynamic
Your model checkpoint with quantized weights should be available at ``Meta-Llama-3-8B-Instruct-FP8-Dynamic/``.
We can see that the quantized weights take up about half the space of the original BF16 weights.
.. code-block:: bash

    ls -lh Meta-Llama-3-8B-Instruct-FP8-Dynamic/
    total 8.5G
    -rw-rw-r-- 1 user user  869 Jun 7 14:43 config.json
    -rw-rw-r-- 1 user user  194 Jun 7 14:43 generation_config.json
    -rw-rw-r-- 1 user user 4.7G Jun 7 14:43 model-00001-of-00002.safetensors
    -rw-rw-r-- 1 user user 3.9G Jun 7 14:43 model-00002-of-00002.safetensors
    -rw-rw-r-- 1 user user  43K Jun 7 14:43 model.safetensors.index.json
    -rw-rw-r-- 1 user user  296 Jun 7 14:43 special_tokens_map.json
    -rw-rw-r-- 1 user user  50K Jun 7 14:43 tokenizer_config.json
    -rw-rw-r-- 1 user user 8.7M Jun 7 14:43 tokenizer.json
Finally, you can load the quantized model checkpoint directly in vLLM.

.. code-block:: python

    from vllm import LLM
    model = LLM(model="Meta-Llama-3-8B-Instruct-FP8-Dynamic/")
    # INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
    result = model.generate("Hello, my name is")
Offline Quantization with Static Activation Scaling Factors
------------------------------------------------------------

For the best inference performance, you can use AutoFP8 with calibration data to produce per-tensor static scales for both the weights and activations by enabling the ``activation_scheme="static"`` argument.
.. code-block:: python

    from datasets import load_dataset
    from transformers import AutoTokenizer
    from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

    pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
    quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token

    # Load and tokenize 512 dataset samples for calibration of activation scales
    ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
    examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
    examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")

    # Define quantization config with static activation scales
    quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

    # Load the model, quantize, and save checkpoint
    model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
    model.quantize(examples)
    model.save_quantized(quantized_model_dir)
Your model checkpoint with quantized weights and activations should be available at ``Meta-Llama-3-8B-Instruct-FP8/``.
Finally, you can load the quantized model checkpoint directly in vLLM.

.. code-block:: python

    from vllm import LLM
    model = LLM(model="Meta-Llama-3-8B-Instruct-FP8/")
    # INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
    result = model.generate("Hello, my name is")
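As a side note on the ``generate`` calls in the examples above, ``LLM.generate`` returns a list of ``RequestOutput`` objects. A minimal way to print the completions, continuing the previous snippet (assuming the standard vLLM Python API):

.. code-block:: python

    # Each prompt yields one RequestOutput; the generated text lives in .outputs.
    for request_output in result:
        print(request_output.prompt, "->", request_output.outputs[0].text)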
FP8 checkpoint structure explanation
------------------------------------

Here we detail the structure of FP8 checkpoints.

The following must be present in the model's ``config.json``:
.. code-block:: yaml

    "quantization_config": {
        "quant_method": "fp8",
        "activation_scheme": "static" or "dynamic"
    },
Each quantized layer in the state_dict will have these tensors:

* If the config has ``"activation_scheme": "static"``:

  .. code-block:: text

      model.layers.0.mlp.down_proj.weight < F8_E4M3
      model.layers.0.mlp.down_proj.input_scale < F32
      model.layers.0.mlp.down_proj.weight_scale < F32

* If the config has ``"activation_scheme": "dynamic"``:

  .. code-block:: text

      model.layers.0.mlp.down_proj.weight < F8_E4M3
      model.layers.0.mlp.down_proj.weight_scale < F32
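To verify which of these tensors a given checkpoint actually contains, you can list them with the ``safetensors`` library. The sketch below is illustrative only; the shard filename is an assumption and may differ for your checkpoint, and reading FP8 tensors requires a recent ``safetensors``/PyTorch combination.

.. code-block:: python

    from safetensors import safe_open

    # Illustrative path: pick any shard of your quantized checkpoint.
    shard = "Meta-Llama-3-8B-Instruct-FP8/model-00001-of-00002.safetensors"

    with safe_open(shard, framework="pt") as f:
        for name in f.keys():
            # Show the quantization-related tensors of one Linear module.
            if name.startswith("model.layers.0.mlp.down_proj"):
                tensor = f.get_tensor(name)
                print(name, tensor.dtype, tuple(tensor.shape))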
Additionally, there can be `FP8 kv-cache scaling factors <https://github.com/vllm-project/vllm/pull/4893>`_ contained within quantized checkpoints specified through the ``.kv_scale`` parameter present on the Attention Module, such as:

.. code-block:: text

    model.layers.0.self_attn.kv_scale < F32
