
kv-cache : do not quantize SWA KV cache #21277

Merged
ggerganov merged 1 commit into master from gg/kv-cache-no-swa-quant on Apr 2, 2026
Conversation

@ggerganov
Member

Overview

cont #21038

We don't need to quantize the SWA part of the cache for iSWA models because it is relatively small, so we keep it in F16.
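To put "relatively small" in perspective: F16 stores 2 bytes per element, while ggml's q8_0 packs a 32-element block into 34 bytes (32 int8 values plus one F16 scale), i.e. about 1.06 bytes per element. A minimal sketch of that arithmetic (the cell/layer/width numbers below are illustrative, not taken from a specific model):

```python
# Rough KV-cache size arithmetic (illustrative numbers, not a specific model).
# ggml's q8_0 stores a 32-element block as 32 int8 values + one f16 scale = 34 bytes.
F16_BYTES_PER_ELEM = 2.0
Q8_0_BYTES_PER_ELEM = 34 / 32  # 1.0625

def kv_cache_mib(cells: int, layers: int, n_embd_kv: int, bytes_per_elem: float) -> float:
    """Total K+V size in MiB: 2 tensors (K and V), each cells x n_embd_kv, per layer."""
    return 2 * cells * layers * n_embd_kv * bytes_per_elem / (1024 ** 2)

# Example: a small SWA cache (2048 cells, 10 SWA layers, 4096-wide KV per layer)
f16 = kv_cache_mib(2048, 10, 4096, F16_BYTES_PER_ELEM)
q8  = kv_cache_mib(2048, 10, 4096, Q8_0_BYTES_PER_ELEM)
print(f"F16: {f16:.0f} MiB, q8_0: {q8:.0f} MiB")  # F16: 320 MiB, q8_0: 170 MiB
```

So keeping a SWA cache of this shape in F16 costs roughly 2x the q8_0 footprint, which is a small absolute amount when the SWA window itself is small.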


@ggerganov ggerganov merged commit 17193cc into master Apr 2, 2026
45 of 46 checks passed
@ggerganov ggerganov deleted the gg/kv-cache-no-swa-quant branch April 2, 2026 08:54
@NovNovikov

Gemma 4's SWA KV cache is hilariously large, so this commit really made it unusable.

@ggerganov
Member Author

> Gemma 4's SWA KV cache is hilariously large, so this commit really made it unusable.

The SWA cache of Gemma 4 26B A4B is 300 MiB in F16.

[screenshot]

@NovNovikov

NovNovikov commented Apr 2, 2026

It is 4000 MiB for Gemma 4 31B with a context of 8192.
[screenshot]

@ddh0
Contributor

ddh0 commented Apr 2, 2026

> 5120 cells

Hmm, shouldn't it only be using 1024 for the sliding window?

@NovNovikov

Well, it looks like it comes from -ub 4096 -b 4096.
When returning to the defaults, it drops to 1200 MiB for a context of 8192, which is still twice the size of the non-SWA cache, the only one that is now being quantized.
[screenshot]

@aviallon
Contributor

aviallon commented Apr 2, 2026

@ggerganov would it be feasible to have yet another flag to control how the SWA cache is quantized?
I may want to use q8 for the SWA cache and q4 for the rest, for instance.
I may also want to use F32.

On my CDNA2 card, I'm developing an optimized q8 KV path (computing small batches directly without dequantizing), as it is much faster on this card for several reasons.

@ggerganov
Member Author

I think it's best to revert the change from this PR. Will do it now.

@ddh0
Contributor

ddh0 commented Apr 3, 2026

I still think this might be worth a look: the sliding window size n_swa is correctly 1024, but in my case it allocates 2048 cells for the SWA portion of the KV cache:

```
0.00.673.378 I print_info: arch                  = gemma4
0.00.673.391 I print_info: n_swa                 = 1024
0.00.673.407 I print_info: n_embd_head_k_swa     = 256
0.00.673.407 I print_info: n_embd_head_v_swa     = 256
0.00.673.407 I print_info: n_rot_swa             = 256
0.17.396.495 I llama_context: n_seq_max     = 1
0.17.396.495 I llama_context: n_ctx         = 32768
0.17.396.495 I llama_context: n_ctx_seq     = 32768
0.17.396.495 I llama_context: n_batch       = 1024
0.17.396.496 I llama_context: n_ubatch      = 1024
0.17.397.462 I llama_kv_cache_iswa: creating non-SWA KV cache, size = 32768 cells
0.17.397.716 I llama_kv_cache:      CUDA0 KV buffer size =  2560.00 MiB
0.17.407.555 I llama_kv_cache: size = 2560.00 MiB ( 32768 cells,  10 layers,  1/1 seqs), K (f16): 1280.00 MiB, V (f16): 1280.00 MiB
0.17.407.556 I llama_kv_cache: attn_rot_k = 0
0.17.407.556 I llama_kv_cache: attn_rot_v = 0
0.17.407.557 I llama_kv_cache_iswa: creating     SWA KV cache, size = 2048 cells
0.17.407.693 I llama_kv_cache:      CUDA0 KV buffer size =  1600.00 MiB
0.17.413.833 I llama_kv_cache: size = 1600.00 MiB (  2048 cells,  50 layers,  1/1 seqs), K (f16):  800.00 MiB, V (f16):  800.00 MiB
0.17.413.834 I llama_kv_cache: attn_rot_k = 0
0.17.413.834 I llama_kv_cache: attn_rot_v = 0
0.17.413.837 I sched_reserve: reserving ...
0.17.414.562 I sched_reserve: Flash Attention was auto, set to enabled
0.17.414.563 I sched_reserve: resolving fused Gated Delta Net support:
0.17.414.937 I sched_reserve: fused Gated Delta Net (autoregressive) enabled
0.17.415.273 I sched_reserve: fused Gated Delta Net (chunked) enabled
0.17.478.123 I sched_reserve:      CUDA0 compute buffer size =  1045.00 MiB
0.17.478.127 I sched_reserve:  CUDA_Host compute buffer size =   178.04 MiB
0.17.478.128 I sched_reserve: graph nodes  = 2462
0.17.478.128 I sched_reserve: graph splits = 182 (with bs=1024), 122 (with bs=1)
0.17.478.129 I sched_reserve: reserve took 64.29 ms, sched copies = 1
```

Unless there is a reason for this? Maybe I misunderstand something, but I thought I would mention it here.

@ggerganov
Copy link
Copy Markdown
Member Author

This is expected: we need extra space beyond n_swa. Here is the exact formula:

```cpp
// note: the SWA cache is always padded to 256 for performance
// https://github.com/ggml-org/llama.cpp/issues/17037
uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
```
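Plugging the numbers reported in this thread into that expression reproduces both observed cell counts: 2048 cells for n_swa = 1024 with the default n_ubatch = 1024, and 5120 cells with -ub 4096. A quick sketch mirroring the C++ (names follow the snippet above; GGML_PAD rounds up to a multiple):

```python
# Python mirror of the SWA cache sizing expression quoted above.
def ggml_pad(x: int, n: int) -> int:
    """Round x up to the next multiple of n (like GGML_PAD)."""
    return (x + n - 1) // n * n

def swa_cache_cells(size_base: int, n_swa: int, n_seq_max: int,
                    n_ubatch: int, unified: bool = True) -> int:
    seqs = n_seq_max if unified else 1
    return ggml_pad(min(size_base, n_swa * seqs + n_ubatch), 256)

# ddh0's log: n_ctx = 32768, n_swa = 1024, n_ubatch = 1024
print(swa_cache_cells(32768, 1024, 1, 1024))  # 2048
# NovNovikov's run: n_ctx = 8192, n_swa = 1024, -ub 4096
print(swa_cache_cells(8192, 1024, 1, 4096))   # 5120
```

The n_ubatch term is what makes the cache grow with the batch size: the window must hold n_swa past tokens plus an entire in-flight micro-batch.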

luvwinnie added a commit to luvwinnie/llama.cpp that referenced this pull request Apr 4, 2026
Merge 59 upstream commits including:
- model: support gemma 4 (vision + moe, no audio) (ggml-org#21309)
- kv-cache: do not quantize SWA KV cache (ggml-org#21277)
- Preserve RotorQuant exclusion from Hadamard rotation
icex added a commit to icex/llama.cpp that referenced this pull request Apr 5, 2026
Includes:
- server: Fix undefined timing measurement errors (ggml-org#21201)
- server: save and clear idle slots on new task --clear-idle (ggml-org#20993)
- common: fix tool call type detection for nullable/enum schemas (ggml-org#21327)
- CUDA: fix FA kernel selection logic (ggml-org#21271)
- kv-cache: do not quantize SWA KV cache (ggml-org#21277) + revert (ggml-org#21332)
- common/parser: fix call ID detection + atomicity (ggml-org#21230)
- jinja: coerce input for string-specific filters (ggml-org#21370)
- Various CI, HIP, WebGPU, and documentation fixes

4 participants