
Commit 1ad6488

Authored by Eric Buehler (EricLBuehler), with co-authors sempervictus, polarathene, and guoqingbao
FP8 KV-cache quantization for PagedAttention (#1400)
* Add most of paged attn kv quant
* It builds a bit
* All the functionality at least
* Small fix
* Add a scale
* Fix bf16 usage
* Make k_v_scale optional
* Collector
* Tweak collection
* Refactor
* Add to apis
* Add cuda impl
* Fix compilation
* Fixes
* Handle ENABLE_FP8
* Format
* Tweak
* Fix scaled_convert usage
* Fix cache_t size
* Fixed scale collection
* Actual fix
* Fix fp8 for CC<8
* Fix the usual String != &str bit (#1483)
  Co-authored-by: RageLtMan <rageltman [at] sempervictus>
* chore: `Dockerfile` - Drop runtime rayon thread ENV (#1465)
* chore: Dockerfile - Remove rayon threads env
* chore: Dockerfile - Improve formatting for `apt-get`
* Remove duplicate calls for api_dir_list (#1474)
* Remove duplicate calls for api_dir_list
* Support local cache for api_dir_list
* Fix home folder for metal
* Capitalized
* Fix transient pyo3 dep (#1478)
  Co-authored-by: Eric Buehler <[email protected]>
* Fix objc dep with non macos (#1480)
* Fix phi 3/4 + nccl issue (#1481)
* Fix log
* Fix n kv heads
* Fix phi3.5 moe (#1482)
* Fix phi3.5 moe accum device
* Fix again
* Fix again
* Support GLM4 model! (#1437)
* Support GLM4 model
* Mention GLM4 model in ReadMe
* glm4 type hint
* Typo fix
* Fix unsupported chat_template function
* Clippy fix
* Refactor distributed backend (#1484)
* Refactor distributed backend, check power of 2
* Fix compilation
* Cap metal paged attn kv allocation (#1485)
* Better paged attn metal cap (#1486)
* Better paged attn metal cap
* Small fix
* Comment
* Small fix
* Refactor
* Server core: consolidate and unify route handlers and API surface (#1423)
* Start working on consolidating completion and chat_completion underlying implementations
* Move response channel to util mod for now (since it's used with streaming and non streaming)
* More work on consolidating completions and chat completions
* More WIP consolidation of server core handlers
* More WIP consolidation of server core handlers
* More WIP consolidation of server core handlers
* Update docs and restrict completion core visibility
* CodeRabbit feedback: remove logprobs warn from route handler since parse request also checks this
* Use consistent var name for completions mod
* Make route handler modules public API consistent (same fn names, etc.) and provide proxy fn that wrap core fns so core mod doesn't have to be pub
  Make lib.rs example compile checked and update example
* Code formatting
* Typo
* Sync fork
* Sync fork
* Docs example fix
* Support qwen3 gguf (#1488)
* Add qwen3 gguf
* Template fixup
* Make bos/eos token IDs optional (#1493)
* Remove python deps from CUDA dockerfiles (#1487)
* Handle USE_FP8 for cuda
* Fix cuda warn
* Add readme
* Saturating sub in sequence state

---------

Co-authored-by: Eric Buehler <[email protected]>
Co-authored-by: RageLtMan <[email protected]>
Co-authored-by: Brennan Kinney <[email protected]>
Co-authored-by: Guoqing Bao <[email protected]>
Co-authored-by: Matthew Haynes <[email protected]>
1 parent c9d0a0e commit 1ad6488

File tree

36 files changed (+1767 additions, -342 deletions)


README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -205,6 +205,7 @@ Please submit requests for new models [here](https://github.com/EricLBuehler/mis
 - [GGML & GGUF support](docs/QUANTS.md): 2–8 bit
 - [GPTQ](docs/QUANTS.md), [AWQ](scripts/convert_awq_marlin.py), [AFQ](docs/QUANTS.md#afq), [HQQ](docs/QUANTS.md#hqq), [FP8](docs/QUANTS.md), [BNB](https://github.com/TimDettmers/bitsandbytes) (int8/fp4/nf4)
 - ⭐ Auto-select the fastest quant method
+- [KV cache quantization](docs/PAGED_ATTENTION.md#kv-cache-quantization)
 4. **Flexibility**
 - [LoRA](docs/ADAPTER_MODELS.md) & [X-LoRA](docs/ADAPTER_MODELS.md) adapters with weight merging
```

docs/PAGED_ATTENTION.md

Lines changed: 61 additions & 0 deletions
````diff
@@ -6,6 +6,16 @@ Mistral.rs supports PagedAttention ([paper here](https://arxiv.org/abs/2309.0618
 
 Our PagedAttention implementation has 2 inputs: GPU KV cache memory size, and block size. This enables you to have fine-tuned control over the available context length, by configuring the available memory for KV cache. When using a CUDA device, PagedAttention is activated by default but can be disabled with `no_paged_attn` for Python or `no-paged-attn` for the CLI tools.
 
+## KV Cache Quantization
+
+PagedAttention now supports KV cache quantization to reduce memory usage and potentially improve performance. The KV cache can be quantized to FP8 (F8E4M3 format) instead of using the model's native dtype, significantly reducing memory requirements while maintaining model quality.
+
+**Available cache types:**
+- `auto` (default): Uses the model's native dtype for KV cache
+- `f8e4m3`: Quantizes KV cache to 8-bit floating point (E4M3 format)
+
+When using FP8 quantization, the memory usage for KV cache is approximately halved compared to FP16, allowing for longer context lengths with the same GPU memory allocation.
+
 > Note: The default block size if not specified is 32.
 
 > Note: if OOM occurs (this can be caused by a variety of factors including adapter activation, re-ISQ, and others), it is likely because the PagedAttention KV cache has already been allocated. To counter this, either set the KV cache memory to a lower amount or usage percentage (recommended) or disable paged attention entirely for a dynamically allocated cache.
@@ -40,6 +50,8 @@ the prefill phase.
 
 Add the `--pa-gpu-mem`/`--pa-gpu-mem-usage` and `--pa-blk-size` parameters before the model kind selector. The GPU memory is in MBs and the block size means the number of tokens per block. These parameters may be passed on any supported model type.
 
+To enable KV cache quantization, use the `--pa-cache-type` parameter with either `auto` (default) or `f8e4m3`.
+
 ```
 cargo run --release --features cuda -- -i --pa-gpu-mem 8192 --pa-blk-size 32 --isq Q4K plain -m microsoft/Phi-3-mini-128k-instruct
 ```
@@ -48,6 +60,11 @@ cargo run --release --features cuda -- -i --pa-gpu-mem 8192 --pa-blk-size 32 --i
 cargo run --release --features cuda -- -i --pa-gpu-mem-usage .95 --pa-blk-size 32 gguf -t mistralai/Mistral-7B-Instruct-v0.1 -m TheBloke/Mistral-7B-Instruct-v0.1-GGUF -f mistral-7b-instruct-v0.1.Q4_K_M.gguf
 ```
 
+Example with FP8 KV cache quantization:
+```
+cargo run --release --features metal -- -i --pa-gpu-mem 4096 --pa-blk-size 32 --pa-cache-type f8e4m3 plain -m microsoft/Phi-3-mini-128k-instruct
+```
+
 ## Using the Rust API
 You can find this example [here](../mistralrs/examples/paged_attn/main.rs).
 
@@ -94,6 +111,33 @@ async fn main() -> Result<()> {
 }
 ```
 
+Example with FP8 KV cache quantization:
+```rust
+use anyhow::Result;
+use mistralrs::{
+    IsqType, MemoryGpuConfig, PagedAttentionMetaBuilder, PagedCacheType,
+    TextMessageRole, TextMessages, TextModelBuilder,
+};
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let model = TextModelBuilder::new("microsoft/Phi-3.5-mini-instruct")
+        .with_isq(IsqType::Q8_0)
+        .with_logging()
+        .with_paged_attn(|| {
+            PagedAttentionMetaBuilder::default()
+                .with_block_size(32)
+                .with_gpu_memory(MemoryGpuConfig::ContextSize(1024))
+                .with_cache_type(PagedCacheType::F8E4M3)
+                .build()
+        })?
+        .build()
+        .await?;
+
+    // ... rest of the code remains the same
+}
+```
+
 ## Using the Python API
 ```py
 from mistralrs import Runner, Which, ChatCompletionRequest, Architecture
@@ -121,4 +165,21 @@ res = runner.send_chat_completion_request(
 )
 print(res.choices[0].message.content)
 print(res.usage)
+```
+
+Example with FP8 KV cache quantization:
+```py
+from mistralrs import Runner, Which, ChatCompletionRequest, Architecture, PagedCacheType
+
+runner = Runner(
+    which=Which.Plain(
+        model_id="mistralai/Mistral-7B-Instruct-v0.1",
+        arch=Architecture.Mistral,
+    ),
+    pa_gpu_mem = 4096,
+    pa_blk_size = 32,
+    pa_cache_type = PagedCacheType.F8E4M3,
+)
+
+# ... rest of the code remains the same
 ```
````
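The doc text above states that `f8e4m3` roughly halves KV-cache memory relative to FP16. A minimal back-of-the-envelope sketch of that arithmetic (the layer/head/dimension numbers below are illustrative assumptions, not values taken from this commit, and paged-block overhead is ignored):

```rust
// Rough per-token KV-cache footprint: one K and one V tensor per layer,
// each num_kv_heads * head_dim elements, times the element size in bytes.
fn kv_bytes_per_token(layers: usize, kv_heads: usize, head_dim: usize, dtype_bytes: usize) -> usize {
    2 * layers * kv_heads * head_dim * dtype_bytes
}

fn main() {
    // Illustrative Mistral-7B-like shape (assumed for this sketch).
    let (layers, kv_heads, head_dim) = (32, 8, 128);

    let f16 = kv_bytes_per_token(layers, kv_heads, head_dim, 2); // `auto` with an FP16/BF16 model
    let f8 = kv_bytes_per_token(layers, kv_heads, head_dim, 1); // `f8e4m3`

    println!("per-token KV cache: fp16 = {} KiB, f8e4m3 = {} KiB", f16 / 1024, f8 / 1024);

    // Same GPU budget => roughly twice the cacheable tokens with FP8.
    let budget = 4096usize * 1024 * 1024; // e.g. --pa-gpu-mem 4096
    println!("tokens per {budget}-byte budget: fp16 ~ {}, f8e4m3 ~ {}", budget / f16, budget / f8);
}
```

With these numbers the per-token cache drops from 128 KiB to 64 KiB, which is where the "approximately halved" figure comes from.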

mistralrs-bench/src/main.rs

Lines changed: 18 additions & 2 deletions
```diff
@@ -5,8 +5,8 @@ use mistralrs_core::{
     get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
     parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
     DeviceMapSetting, DrySamplingParams, Loader, LoaderBuilder, MemoryGpuConfig, MistralRs,
-    MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, Request, RequestMessage,
-    Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
+    MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, PagedCacheType, Request,
+    RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
 };
 use std::sync::Arc;
 use std::{fmt::Display, num::NonZeroUsize};
@@ -265,6 +265,10 @@ fn warmup_run(mistralrs: Arc<MistralRs>) {
     let _ = rx.blocking_recv();
 }
 
+fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
+    s.parse()
+}
+
 #[derive(Parser)]
 #[command(version, about, long_about = None)]
 struct Args {
@@ -323,6 +327,11 @@ struct Args {
     #[arg(long = "pa-ctxt-len")]
     paged_ctxt_len: Option<usize>,
 
+    /// PagedAttention KV cache type (auto or f8e4m3).
+    /// Defaults to `auto`.
+    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
+    cache_type: Option<PagedCacheType>,
+
     /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
     /// PagedAttention is only supported on CUDA and is always automatically activated.
     #[arg(long = "pa-blk-size")]
@@ -448,28 +457,33 @@ async fn main() -> anyhow::Result<()> {
             block_size,
             512,
             MemoryGpuConfig::ContextSize(max_seq_len),
+            args.cache_type.unwrap_or_default(),
         )?),
         (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
             block_size,
             512,
             MemoryGpuConfig::ContextSize(ctxt),
+            args.cache_type.unwrap_or_default(),
         )?),
         (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
             block_size,
             512,
             MemoryGpuConfig::Utilization(f),
+            args.cache_type.unwrap_or_default(),
         )?),
         (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
             block_size,
             512,
             MemoryGpuConfig::MbAmount(m),
+            args.cache_type.unwrap_or_default(),
         )?),
         (block_size, Some(_m), Some(f), None, true, false) => {
             info!("Both memory size, and usage were specified, defaulting to the usage value.");
             Some(PagedAttentionConfig::new(
                 block_size,
                 512,
                 MemoryGpuConfig::Utilization(f),
+                args.cache_type.unwrap_or_default(),
             )?)
         }
         (block_size, Some(_m), None, Some(ctxt), true, false) => {
@@ -478,6 +492,7 @@ async fn main() -> anyhow::Result<()> {
                 block_size,
                 512,
                 MemoryGpuConfig::ContextSize(ctxt),
+                args.cache_type.unwrap_or_default(),
             )?)
         }
         (block_size, None, Some(f), Some(_ctxt), true, false) => {
@@ -486,6 +501,7 @@ async fn main() -> anyhow::Result<()> {
                 block_size,
                 512,
                 MemoryGpuConfig::Utilization(f),
+                args.cache_type.unwrap_or_default(),
             )?)
         }
         (_, _, _, _, _, _) => None,
```
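The new `--pa-cache-type` flag is parsed through `PagedCacheType`'s `FromStr` impl and falls back to the derived default (`Auto`) when omitted. A small sketch of that behavior, assuming a binary that depends on the `mistralrs-core` crate (which re-exports `PagedCacheType` in this commit):

```rust
use mistralrs_core::PagedCacheType;

fn main() {
    // clap's `value_parser = parse_cache_type` is a thin wrapper over `str::parse`.
    assert_eq!("auto".parse::<PagedCacheType>(), Ok(PagedCacheType::Auto));
    assert_eq!("f8e4m3".parse::<PagedCacheType>(), Ok(PagedCacheType::F8E4M3));
    assert!("fp8".parse::<PagedCacheType>().is_err()); // anything else is rejected

    // When --pa-cache-type is not passed, `cache_type` is None and the config
    // receives the default, i.e. the model's activation dtype.
    let cache_type: Option<PagedCacheType> = None;
    assert_eq!(cache_type.unwrap_or_default(), PagedCacheType::Auto);
}
```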

mistralrs-core/src/dummy_paged_attention/block_engine.rs

Lines changed: 4 additions & 3 deletions
```diff
@@ -51,6 +51,10 @@ impl LogicalTokenBlock {
         self.tokens.pop();
         self.num_tokens -= 1;
     }
+
+    pub fn toks(&self) -> &[usize] {
+        &self.tokens
+    }
 }
 
 impl Hash for LogicalTokenBlock {
@@ -272,9 +276,6 @@ impl BlockEngine {
         // If there are prefill physical blocks, use those here.
         if let Some(physical_blocks_prefill) = seq.take_physical_blocks_prefill() {
             let mut block_table = physical_blocks_prefill.clone();
-            for block in &mut block_table {
-                block.deref_mut().refcount = 1;
-            }
             let n_extra_blocks = seq.logical_token_blocks().len() - block_table.len();
             for _ in 0..n_extra_blocks {
                 block_table.push(self.gpu_allocator.allocate());
```

mistralrs-core/src/dummy_paged_attention/cache_engine.rs

Lines changed: 33 additions & 0 deletions
```diff
@@ -1,17 +1,50 @@
 use std::{
     collections::HashMap,
+    str::FromStr,
     sync::{Arc, Mutex, MutexGuard},
 };
 
 use candle_core::{DType, Device, Result, Tensor};
+use serde::{Deserialize, Serialize};
 
 use super::config::ModelConfigLike;
 
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
+#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
+pub enum PagedCacheType {
+    #[default]
+    Auto,
+    F8E4M3,
+}
+
+impl PagedCacheType {
+    pub fn to_dtype(&self, act_dtype: DType) -> DType {
+        match self {
+            PagedCacheType::F8E4M3 => DType::F8E4M3,
+            PagedCacheType::Auto => act_dtype,
+        }
+    }
+}
+
+impl FromStr for PagedCacheType {
+    type Err = String;
+    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+        match s {
+            "auto" => Ok(Self::Auto),
+            "f8e4m3" => Ok(Self::F8E4M3),
+            other => Err(format!(
+                "Unexpected `PagedCacheType`, got `{other}` but expected `auto` and `f8e4m3`."
+            )),
+        }
+    }
+}
+
 #[derive(Clone, Debug)]
 pub struct CacheConfig {
     pub block_size: usize,
     pub num_gpu_blocks: usize,
     pub num_cpu_blocks: usize,
+    pub cache_type: PagedCacheType,
 }
 
 pub type KVCache = (Tensor, Tensor);
```
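`to_dtype` is what the rest of the crate uses to turn the configured cache type into a concrete `candle_core::DType` before sizing and allocating the cache. A minimal sketch of its behavior, assuming a crate that depends on `mistralrs-core` and on the candle fork used by mistral.rs (which provides `DType::F8E4M3`):

```rust
use candle_core::DType;
use mistralrs_core::PagedCacheType;

fn main() {
    // `Auto` passes the activation dtype through unchanged.
    assert_eq!(PagedCacheType::Auto.to_dtype(DType::BF16), DType::BF16);
    // `F8E4M3` forces the 8-bit floating point cache dtype.
    assert_eq!(PagedCacheType::F8E4M3.to_dtype(DType::BF16), DType::F8E4M3);

    // The element size is what drives the memory saving: 2 bytes -> 1 byte.
    assert_eq!(DType::BF16.size_in_bytes(), 2);
    assert_eq!(DType::F8E4M3.size_in_bytes(), 1);
}
```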

mistralrs-core/src/dummy_paged_attention/mod.rs

Lines changed: 7 additions & 1 deletion
```diff
@@ -13,7 +13,7 @@ pub const _PAD_SLOT_ID: i64 = -1;
 
 pub use block_engine::{BlockEngine, BlockTables, LogicalTokenBlock, PhysicalTokenBlock};
 pub use block_engine_sequence::BlockEngineSequence;
-pub use cache_engine::{CacheConfig, CacheEngine};
+pub use cache_engine::{CacheConfig, CacheEngine, PagedCacheType};
 use candle_core::{DType, Device};
 pub use config::{ModelConfigLike, ModelConfigMetadata};
 pub use layers::PagedAttention;
@@ -32,18 +32,21 @@ pub struct PagedAttentionConfig {
     pub(crate) block_size: Option<usize>,
     pub(crate) mem_cpu: usize,
     pub(crate) mem_gpu: MemoryGpuConfig,
+    pub(crate) cache_type: PagedCacheType,
 }
 
 impl PagedAttentionConfig {
     pub fn new(
         block_size: Option<usize>,
         mem_cpu: usize,
         mem_gpu: MemoryGpuConfig,
+        cache_type: PagedCacheType,
     ) -> anyhow::Result<Self> {
         Ok(Self {
             block_size,
             mem_cpu,
             mem_gpu,
+            cache_type,
         })
     }
 }
@@ -97,6 +100,7 @@ pub fn calculate_cache_config(
     mem_cpu: usize,
     block_size: Option<usize>,
     dtype: DType,
+    cache_type: PagedCacheType,
     config: &dyn ModelConfigLike,
     device: &Device,
     layer_devices: &[Option<Device>],
@@ -106,6 +110,7 @@
     if !SUPPORTED_BLOCK_SIZE.contains(&block_size) {
         anyhow::bail!("Block size must be in {SUPPORTED_BLOCK_SIZE:?}, got {block_size}");
     }
+    let dtype = cache_type.to_dtype(dtype);
     let dtype_size = dtype.size_in_bytes();
 
     let mut min_mem_gpu = usize::MAX;
@@ -148,5 +153,6 @@
         block_size,
         num_gpu_blocks,
         num_cpu_blocks,
+        cache_type,
     })
 }
```
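With the extra `cache_type` parameter, callers of `PagedAttentionConfig::new` (such as mistralrs-bench above) pick the cache dtype at configuration time, and `calculate_cache_config` converts the dtype before computing block counts, so an FP8 cache yields roughly twice the GPU blocks for the same memory budget. A minimal construction sketch, assuming a binary depending on `mistralrs-core` and `anyhow` (the 512 CPU-cache value simply mirrors what the benchmark passes; the context size is arbitrary):

```rust
use mistralrs_core::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};

fn main() -> anyhow::Result<()> {
    // 32-token blocks, the benchmark's CPU-cache value, a context-size GPU
    // budget, and an FP8 (E4M3) KV cache.
    let _cfg = PagedAttentionConfig::new(
        Some(32),
        512,
        MemoryGpuConfig::ContextSize(4096),
        PagedCacheType::F8E4M3,
    )?;
    Ok(())
}
```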

mistralrs-core/src/lib.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -89,7 +89,7 @@ pub use mistralrs_mcp::{
     McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
 };
 pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
-pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig};
+pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
 pub use pipeline::{
     chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
     AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
```

mistralrs-core/src/paged_attention/cache_engine.rs

Lines changed: 34 additions & 0 deletions
```diff
@@ -1,20 +1,53 @@
 use std::{
     collections::HashMap,
+    str::FromStr,
     sync::{Arc, Mutex, MutexGuard},
 };
 
 use candle_core::{
     from_storage_no_op, DType, Device, MetalStorage, Result, Shape, Storage, Tensor,
 };
 use mistralrs_paged_attn::{copy_blocks, swap_blocks};
+use serde::{Deserialize, Serialize};
 
 use super::config::ModelConfigLike;
 
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
+#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
+pub enum PagedCacheType {
+    #[default]
+    Auto,
+    F8E4M3,
+}
+
+impl PagedCacheType {
+    pub fn to_dtype(&self, act_dtype: DType) -> DType {
+        match self {
+            PagedCacheType::F8E4M3 => DType::F8E4M3,
+            PagedCacheType::Auto => act_dtype,
+        }
+    }
+}
+
+impl FromStr for PagedCacheType {
+    type Err = String;
+    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+        match s {
+            "auto" => Ok(Self::Auto),
+            "f8e4m3" => Ok(Self::F8E4M3),
+            other => Err(format!(
+                "Unexpected `PagedCacheType`, got `{other}` but expected `auto` and `f8e4m3`."
+            )),
+        }
+    }
+}
+
 #[derive(Clone, Debug)]
 pub struct CacheConfig {
     pub block_size: usize,
     pub num_gpu_blocks: usize,
     pub num_cpu_blocks: usize,
+    pub cache_type: PagedCacheType,
 }
 
 pub type KVCache = (Tensor, Tensor);
@@ -33,6 +66,7 @@ impl CacheEngine {
         device: &Device,
         layer_devices: Vec<Option<Device>>,
     ) -> Result<Self> {
+        let dtype = cache_config.cache_type.to_dtype(dtype);
         Ok(Self {
             gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache(
                 model_config,
```
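This is the non-dummy twin of the `dummy_paged_attention` module above; the behaviorally new piece here is that `CacheEngine::new` converts the dtype via `cache_config.cache_type.to_dtype(..)` before allocating the GPU/CPU cache tensors. Sketched below as if it were a unit test inside this module (an assumption of this sketch; `CacheConfig` itself is not re-exported at the crate root in this diff):

```rust
#[cfg(test)]
mod cache_type_tests {
    use super::{CacheConfig, PagedCacheType};
    use candle_core::DType;

    #[test]
    fn cache_config_carries_fp8_cache_type() {
        let cfg = CacheConfig {
            block_size: 32,
            num_gpu_blocks: 256,
            num_cpu_blocks: 128,
            cache_type: PagedCacheType::F8E4M3,
        };
        // CacheEngine::new applies this conversion before allocating the KV tensors,
        // so the paged cache is stored as 8-bit floats regardless of the model dtype.
        assert_eq!(cfg.cache_type.to_dtype(DType::BF16), DType::F8E4M3);
    }
}
```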
