FP8 KV-cache quantization for PagedAttention (#1400)
* Add most of paged attn kv quant
* It builds a bit
* All the functionality at least
* Small fix
* Add a scale
* Fix bf16 usage
* Make k_v_scale optional
* Collector
* Tweak collection
* Refactor
* Add to apis
* Add cuda impl
* Fix compilation
* Fixes
* Handle ENABLE_FP8
* Format
* Tweak
* Fix scaled_convert usage
* Fix cache_t size
* Fixed scale collection
* Actual fix
* Fix fp8 for CC<8
* Fix the usual String != &str bit (#1483)
Co-authored-by: RageLtMan <rageltman [at] sempervictus>
* chore: `Dockerfile` - Drop runtime rayon thread ENV (#1465)
* chore: Dockerfile - Remove rayon threads env
* chore: Dockerfile - Improve formatting for `apt-get`
* Remove duplicate calls for api_dir_list (#1474)
* Remove duplicate calls for api_dir_list
* Support local cache for api_dir_list
* Fix home folder for metal
* Capitalized
* Fix transient pyo3 dep (#1478)
Co-authored-by: Eric Buehler <[email protected]>
* Fix objc dep with non macos (#1480)
* Fix phi 3/4 + nccl issue (#1481)
* Fix log
* Fix n kv heads
* Fix phi3.5 moe (#1482)
* Fix phi3.5 moe accum device
* Fix again
* Fix again
* Support GLM4 model! (#1437)
* Support GLM4 model
* Mention GLM4 model in ReadMe
* glm4 type hint
* Typo fix
* Fix unsupported chat_template function
* Clippy fix
* Refactor distributed backend (#1484)
* Refactor distributed backend, check power of 2
* Fix compilation
* Cap metal paged attn kv allocation (#1485)
* Better paged attn metal cap (#1486)
* Better paged attn metal cap
* Small fix
* Comment
* Small fix
* Refactor
* Server core: consolidate and unify route handlers and API surface (#1423)
* Start working on consolidating completion and chat_completion underlying implementations
* Move response channel to util mod for now (since it's used with streaming and non streaming)
* More work on consolidating completions and chat completions
* More WIP consolidation of server core handlers
* More WIP consolidation of server core handlers
* More WIP consolidation of server core handlers
* Update docs and restrict completion core visibility
* CodeRabbit feedback: remove logprobs warn from route handler since parse request also checks this
* Use consistent var name for completions mod
* Make route handler modules public API consistent (same fn names, etc.) and provide proxy fn that wrap core fns so core mod doesn't have to be pub
Make lib.rs example compile checked and update example
* Code formatting
* Typo
* Sync fork
* Sync fork
* Docs example fix
* Support qwen3 gguf (#1488)
* Add qwen3 gguf
* Template fixup
* Make bos/eos token IDs optional (#1493)
* Remove python deps from CUDA dockerfiles (#1487)
* Handle USE_FP8 for cuda
* Fix cuda warn
* Add readme
* Saturating sub in sequence state
---------
Co-authored-by: Eric Buehler <[email protected]>
Co-authored-by: RageLtMan <[email protected]>
Co-authored-by: Brennan Kinney <[email protected]>
Co-authored-by: Guoqing Bao <[email protected]>
Co-authored-by: Matthew Haynes <[email protected]>
Our PagedAttention implementation has two inputs: the GPU KV cache memory size and the block size. This gives you fine-grained control over the available context length by configuring the memory available for the KV cache. When using a CUDA device, PagedAttention is activated by default, but it can be disabled with `no_paged_attn` for Python or `no-paged-attn` for the CLI tools.
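To make the relationship concrete, here is a small illustrative sketch (not part of mistral.rs; the helper name, model shape, and constants are assumptions) of how the reserved KV cache memory and the block size bound the number of cacheable tokens, and therefore the usable context length:

```rust
/// Rough estimate of how many tokens fit in a PagedAttention KV cache.
/// Illustrative only; the real allocator accounts for more details.
fn kv_cache_token_capacity(
    cache_mem_bytes: usize, // memory reserved for the KV cache
    block_size: usize,      // tokens per block (e.g. 32)
    num_layers: usize,
    num_kv_heads: usize,
    head_dim: usize,
    bytes_per_elem: usize, // 2 for f16/bf16, 1 for f8e4m3
) -> usize {
    // Each cached token stores one key and one value vector per layer.
    let bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem;
    let bytes_per_block = block_size * bytes_per_token;
    // The cache is allocated in whole blocks.
    (cache_mem_bytes / bytes_per_block) * block_size
}

fn main() {
    // Hypothetical model shape: 32 layers, 8 KV heads, head_dim 128, f16 cache.
    let tokens = kv_cache_token_capacity(8 * 1024 * 1024 * 1024, 32, 32, 8, 128, 2);
    println!("~{tokens} tokens fit in an 8 GiB f16 KV cache");
}
```

Halving `bytes_per_elem` (as FP8 quantization does) roughly doubles the token capacity for the same memory budget.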
## KV Cache Quantization
PagedAttention now supports KV cache quantization to reduce memory usage and potentially improve performance. The KV cache can be quantized to FP8 (F8E4M3 format) instead of using the model's native dtype, significantly reducing memory requirements while maintaining model quality.
**Available cache types:**
- `auto` (default): Uses the model's native dtype for the KV cache
- `f8e4m3`: Quantizes the KV cache to 8-bit floating point (E4M3 format)
When using FP8 quantization, the memory usage for KV cache is approximately halved compared to FP16, allowing for longer context lengths with the same GPU memory allocation.
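The sketch below illustrates the general idea behind scaled FP8 (E4M3) storage, echoing the scale collection and scaled conversion mentioned in the commit history above. It is a conceptual CPU sketch only; the helper names are made up, and the real mistral.rs kernels run on the GPU and store actual 8-bit codes.

```rust
// Values are divided by a collected per-tensor scale and clamped to the E4M3
// representable range before being stored in 8 bits; dequantization
// multiplies the scale back in.
const F8E4M3_MAX: f32 = 448.0; // largest finite E4M3 value

fn collect_scale(values: &[f32]) -> f32 {
    // Pick a scale so the largest magnitude maps onto the E4M3 range.
    let absmax = values.iter().fold(0.0_f32, |m, v| m.max(v.abs()));
    (absmax / F8E4M3_MAX).max(f32::MIN_POSITIVE)
}

fn to_e4m3_range(values: &[f32], scale: f32) -> Vec<f32> {
    // A real kernel would round these to 8-bit E4M3 codes; this only shows
    // the scaling and clamping that makes the values representable.
    values
        .iter()
        .map(|v| (v / scale).clamp(-F8E4M3_MAX, F8E4M3_MAX))
        .collect()
}

fn main() {
    let keys = [0.02_f32, -1.5, 3.7, 900.0];
    let scale = collect_scale(&keys);
    let quantized = to_e4m3_range(&keys, scale);
    println!("scale = {scale}, quantized (pre-rounding) = {quantized:?}");
}
```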
> Note: If not specified, the default block size is 32.
> Note: If an OOM error occurs (this can be caused by a variety of factors, including adapter activation, re-ISQ, and others), it is likely because the PagedAttention KV cache has already been allocated. To counter this, either set the KV cache memory to a lower amount or usage percentage (recommended), or disable PagedAttention entirely for a dynamically allocated cache.
Add the `--pa-gpu-mem`/`--pa-gpu-mem-usage` and `--pa-blk-size` parameters before the model kind selector. The GPU memory is specified in MB, and the block size is the number of tokens per block. These parameters may be passed for any supported model type.
To enable KV cache quantization, use the `--pa-cache-type` parameter with either `auto` (default) or `f8e4m3`.
```
cargo run --release --features cuda -- -i --pa-gpu-mem 8192 --pa-blk-size 32 --isq Q4K plain -m microsoft/Phi-3-mini-128k-instruct
```
```
cargo run --release --features cuda -- -i --pa-gpu-mem-usage .95 --pa-blk-size 32 gguf -t mistralai/Mistral-7B-Instruct-v0.1 -m TheBloke/Mistral-7B-Instruct-v0.1-GGUF -f mistral-7b-instruct-v0.1.Q4_K_M.gguf
```
Example with FP8 KV cache quantization:
```
cargo run --release --features metal -- -i --pa-gpu-mem 4096 --pa-blk-size 32 --pa-cache-type f8e4m3 plain -m microsoft/Phi-3-mini-128k-instruct
```
## Using the Rust API
You can find this example [here](../mistralrs/examples/paged_attn/main.rs).
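For reference, here is a condensed sketch of what such a setup can look like. It is based on my reading of the `mistralrs` builder API (`TextModelBuilder`, `PagedAttentionMetaBuilder`); the model ID and exact option values are placeholders, and the linked example remains the authoritative version.

```rust
use anyhow::Result;
use mistralrs::{
    IsqType, MemoryGpuConfig, PagedAttentionMetaBuilder, TextMessageRole, TextMessages,
    TextModelBuilder,
};

#[tokio::main]
async fn main() -> Result<()> {
    // Build a model with PagedAttention enabled: 32-token blocks and a cache
    // sized for the given context length (method names per my understanding).
    let model = TextModelBuilder::new("microsoft/Phi-3-mini-128k-instruct")
        .with_isq(IsqType::Q4K)
        .with_logging()
        .with_paged_attn(|| {
            PagedAttentionMetaBuilder::default()
                .with_block_size(32)
                .with_gpu_memory(MemoryGpuConfig::ContextSize(1024))
                .build()
        })?
        .build()
        .await?;

    // Send a simple chat request through the paged-attention-backed model.
    let messages =
        TextMessages::new().add_message(TextMessageRole::User, "Hello! How are you?");
    let response = model.send_chat_request(messages).await?;
    println!("{}", response.choices[0].message.content.as_ref().unwrap());
    Ok(())
}
```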