
llama : fix KV cache quantization for hybrid Mamba/attention models#1548

Closed
jnovy wants to merge 1 commit into ikawrakow:main from jnovy:fix-hybrid-kv-cache-quantization

Conversation


@jnovy jnovy commented Mar 29, 2026

Hybrid Mamba/attention models such as Qwen3.5 and Qwen3-Next crash with NaN logits when using KV cache quantization below q8_0 on the CPU backend. The failure is immediate and catastrophic -- `Failed to sample token` followed by abort, not a gradual perplexity degradation.

The root cause is that 4/5/6-bit quantization of the KV cache in the attention layers introduces errors that accumulate through the interleaved SSM layers. In Qwen3.5-35B-A3B (40 layers: 10 attention + 30 SSM, full_attention_interval=4), the SSM recurrent state is always F32, but it consumes attention outputs that went through quantized KV cache -- and those errors compound across the attention/SSM cycle.

I initially tried porting the sumqx/sumq2 scale adjustment from PR #1547 to the CPU quantize_row_q4_0_ref() path, hoping better scale selection would be sufficient. It was not -- NaN still occurs. PR #1547's scale adjustment is CUDA-only (cpy-utils.cuh), and even the same technique on CPU does not prevent the catastrophic precision loss in hybrid architectures.

This PR auto-upgrades sub-q8_0 quantized KV cache types to q8_0 for hybrid models, with a warning. Non-quantized types (f16, bf16, f32) and q8_0/q8_1/q8_kv pass through unchanged.

Reproducing

```shell
llama-cli -m Qwen3.5-35B-A3B-Q4_K_M.gguf -np 1 -fa on -t 16 -ctk q4_0 -ctv q4_0 -khad -vhad -p "Hello world" -n 64
```
| Configuration | KV cache | Result |
|---|---|---|
| main, no fix | q4_0 | `Failed to sample token` / NaN crash |
| scale adjustment only (PR #1547 technique on CPU) | q4_0 | `Failed to sample token` / NaN crash |
| this PR | q4_0 requested, q8_0 actual | coherent output, warning printed |
| main, no fix | q8_0 | coherent output |
| main, no fix | f16 | coherent output |

Tested on Qwen3.5-35B-A3B (Q4_K_M, 40 layers, 10 attention + 30 SSM, full_attention_interval=4), CPU-only build, AVX2. The q8_0 KV cache works correctly including with Hadamard transforms (-khad -vhad).

Note on scope

The guard applies to all backends, not just CPU. On CUDA, q4_0 KV cache may work for short contexts (PR #1547 shows acceptable PPL at 8k tokens), but NaN is a tail event that PPL averages can mask, and it is likely to surface at longer contexts. Since llama_init_from_model() sets the cache type globally (not per-backend), there is no clean way to scope this to CPU-only without refactoring the KV cache initialization to support per-layer or per-backend type selection.

Hybrid architectures (Qwen3.5, Qwen3-Next) fail with NaN sampling when
using KV cache quantization below q8_0. Low-bit quantization (q4_0,
q4_1, q5_0, q5_1, iq4_nl) causes precision loss that accumulates through
the SSM state or attention-recurrent layer interactions, producing NaN
logits at sampling time.

This patch automatically upgrades q4/q5/q6/iq4 KV cache types to q8_0
for hybrid models, with a warning message. q8_0 has been tested and
works correctly (including with Hadamard transforms).

Tested on Qwen3.5-35B-A3B hybrid Mamba/MoE model (40 layers, 10 attention
+ 30 SSM, full_attention_interval=4).

Signed-off-by: Jindrich Novy <jnovy@redhat.com>

MrHills-rs commented Mar 29, 2026

Not allowing q4 KV cache is a bit much though, isn't it?
There are use cases in which long context and high precision are not as relevant, such as classification tasks. Those cases often prioritize a small memory footprint above anything else.

Maybe just post a warning without force-upgrading q4 to q8?

Edit:
I'm actually curious, does a higher fa-offset fix it? If so, maybe just add a suggestion to raise the offset whenever using lower precision cache?

Author

jnovy commented Mar 29, 2026

@MrHills-rs The issue isn't reduced quality - it's a hard crash (Failed to sample token / NaN logits / abort). q4_0 KV cache on hybrid models doesn't degrade gracefully; it produces NaN that kills the process. This happens even on very short prompts (2 tokens: "Hello world" - see the reproducer). A warning-only approach would let users walk into a crash with no workaround except restarting with different parameters. The memory cost of q8_0 vs q4_0 is ~640 MiB for this model (only attention layers use KV cache, SSM state is always F32). Note that this only affects hybrid Mamba/attention architectures -- q4_0 KV cache continues to work as before on standard transformer models.

Regarding fa-offset: that is a CUDA-only softmax offset parameter AFAICS. The NaN crash reproduces on CPU, where fa-offset has no effect. The root cause is the quantization error in the KV cache itself, not numerical instability in the attention softmax.


MrHills-rs commented Mar 29, 2026

> The root cause is the quantization error in the KV cache itself, not numerical instability in the attention softmax.

This is beyond my area of expertise, but aren't the two heavily connected? I thought the first makes it more likely for the second to occur.

Also yes, fa-offset is a cuda parameter. I just thought there was an equivalent parameter for CPU.

At any rate, blocking all q4 kv cache seems overkill. Whenever you're using a classification model, you often do so together with a whole other stack of models, so you need it to be as fast as possible and to be as small as possible to leave more space to the rest of the stack.

Besides, I see a lot of guys using q4 for simple use cases such as RP, and 640 MB might not be much for a CPU user, but it might be a lot for a GPU user with common Nvidia GPUs such as a 4060 (8 GB I think), 4070 (12 GB I think), laptop cards, or similar.

Edit: ultimately I think these types of problems should be solved via better documentation rather than forced guardrails. One is expected to tinker a bit when using something like llama.cpp. Run tests. It's mostly backend stuff after all.

@ikawrakow
Owner

@jnovy I think the issue is that you did not build with `-DGGML_IQK_FA_ALL_QUANTS=ON`. Without this definition, only Q8_0 and Q6_0 are enabled. This is to avoid the relatively long compilation time when all quants are enabled for flash attention.

But given that you are not the first to run into this issue, I guess I'll change the default.

@ikawrakow
Owner

@jnovy If you pull the latest main branch, I'm pretty confident you will not observe NaNs or crashes. Remember that in ik_llama.cpp on the CPU only Q8_0, Q6_0, Q4_1, Q4_0 and IQ4_NL are supported, there is no Q5_0 and Q5_1 support.

Author

jnovy commented Mar 29, 2026

@ikawrakow confirming no NaN crashes after making the change in the default settings in main. Works for me, thanks!

@jnovy jnovy closed this Mar 29, 2026
