llama : fix KV cache quantization for hybrid Mamba/attention models #1548

jnovy wants to merge 1 commit into `ikawrakow:main`
Conversation
Hybrid architectures (Qwen3.5, Qwen3-Next) fail with NaN sampling when using KV cache quantization below `q8_0`. Low-bit quantization (`q4_0`, `q4_1`, `q5_0`, `q5_1`, `iq4_nl`) causes precision loss that accumulates through the SSM state or attention-recurrent layer interactions, producing NaN logits at sampling time.

This patch automatically upgrades q4/q5/q6/iq4 KV cache types to `q8_0` for hybrid models, with a warning message. `q8_0` has been tested and works correctly (including with Hadamard transforms). Tested on the Qwen3.5-35B-A3B hybrid Mamba/MoE model (40 layers, 10 attention + 30 SSM, `full_attention_interval=4`).

Signed-off-by: Jindrich Novy <jnovy@redhat.com>
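A minimal sketch of the guard this patch describes. The enum and function names below are hypothetical stand-ins for illustration, not the actual `ik_llama.cpp` code:

```cpp
#include <cstdio>

// Hypothetical enum standing in for ggml's type enum; only the members
// relevant to the KV cache guard are listed here.
enum cache_type { F16, BF16, F32, Q4_0, Q4_1, Q5_0, Q5_1, Q6_0, IQ4_NL, Q8_0 };

// Sketch of the described behavior: for hybrid Mamba/attention models,
// upgrade sub-q8_0 quantized KV cache types to q8_0 with a warning;
// non-quantized types and q8 types pass through unchanged.
cache_type sanitize_kv_cache_type(cache_type requested, bool is_hybrid) {
    if (!is_hybrid) return requested;
    switch (requested) {
        case Q4_0: case Q4_1: case Q5_0: case Q5_1: case Q6_0: case IQ4_NL:
            fprintf(stderr,
                "warning: KV cache type below q8_0 is unstable for hybrid "
                "Mamba/attention models; upgrading to q8_0\n");
            return Q8_0;
        default:
            return requested; // f16/bf16/f32/q8_0 etc. are left as-is
    }
}
```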
Not allowing q4 kv cache is a bit much tho, isn't it? Maybe just post a warning without force-upgrading q4 to q8?
@MrHills-rs The issue isn't reduced quality, it's a hard crash (`Failed to sample token` followed by an abort). Regarding `fa-offset`: that is a CUDA-only parameter.
This is beyond my area of expertise, but aren't the two heavily connected? I thought the first makes it more likely for the second to occur. Also yes, `fa-offset` is a CUDA parameter; I just thought there was an equivalent parameter for CPU. At any rate, blocking all q4 kv cache seems overkill. Whenever you're using a classification model, you often do so together with a whole stack of other models, so you need it to be as fast and as small as possible to leave more room for the rest of the stack. Besides, I see a lot of people using q4 for simple use cases such as RP, and 640MB might not be much for a CPU user, but it can be a lot for a GPU user on common Nvidia cards such as a 4060 (8GB, I think), 4070 (12GB, I think), laptop cards, or similar.

Edit: ultimately I think these types of problems should be solved via better documentation rather than forced guardrails. One is expected to tinker a bit when using something like llama.cpp. Run tests. It's mostly backend stuff after all.
@jnovy I think the issue is that you did not build with […]. But given that you are not the first to run into this issue, I guess I'll change the default.
@jnovy If you pull the latest main branch, I'm pretty confident you will not observe NaNs or crashes. Remember that in […]
@ikawrakow confirming no NaN crashes after making the change in the default settings in main. Works for me, thanks! |
Hybrid Mamba/attention models such as Qwen3.5 and Qwen3-Next crash with NaN logits when using KV cache quantization below `q8_0` on the CPU backend. The failure is immediate and catastrophic (`Failed to sample token` followed by an abort), not a gradual perplexity degradation.

The root cause is that 4/5/6-bit quantization of the KV cache in the attention layers introduces errors that accumulate through the interleaved SSM layers. In Qwen3.5-35B-A3B (40 layers: 10 attention + 30 SSM, `full_attention_interval=4`), the SSM recurrent state is always F32, but it consumes attention outputs that went through the quantized KV cache, and those errors compound across the attention/SSM cycle.

I initially tried porting the `sumqx`/`sumq2` scale adjustment from PR #1547 to the CPU `quantize_row_q4_0_ref()` path, hoping better scale selection would be sufficient. It was not: NaN still occurs. PR #1547's scale adjustment is CUDA-only (`cpy-utils.cuh`), and even the same technique on CPU does not prevent the catastrophic precision loss in hybrid architectures.

This PR auto-upgrades sub-`q8_0` quantized KV cache types to `q8_0` for hybrid models, with a warning. Non-quantized types (`f16`, `bf16`, `f32`) and `q8_0`/`q8_1`/`q8_kv` pass through unchanged.

### Reproducing the `Failed to sample token` / NaN crash

Tested on Qwen3.5-35B-A3B (Q4_K_M, 40 layers, 10 attention + 30 SSM, `full_attention_interval=4`), CPU-only build, AVX2. The `q8_0` KV cache works correctly, including with Hadamard transforms (`-khad -vhad`).

### Note on scope
The guard applies to all backends, not just CPU. On CUDA, `q4_0` KV cache may work for short contexts (PR #1547 shows acceptable PPL at 8k tokens), but NaN is a tail event that PPL averages can mask, and it is likely to surface at longer contexts. Since `llama_init_from_model()` sets the cache type globally (not per-backend), there is no clean way to scope this to CPU-only without refactoring the KV cache initialization to support per-layer or per-backend type selection.
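To see why a single "tail event" NaN is catastrophic at sampling time even when averaged metrics look fine, a toy softmax illustrates the propagation (illustrative code, not taken from the patch):

```cpp
#include <cmath>
#include <algorithm>
#include <vector>

// One NaN logit poisons the entire distribution: exp(NaN) is NaN, the
// normalizing sum becomes NaN, and every probability ends up NaN, so the
// sampler has no valid distribution left to draw from.
std::vector<float> softmax(const std::vector<float> & logits) {
    float max_l = -INFINITY;
    for (float l : logits) max_l = std::max(max_l, l);
    std::vector<float> p(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        p[i] = std::exp(logits[i] - max_l);
        sum += p[i];
    }
    for (float & x : p) x /= sum; // sum is NaN if any logit was NaN
    return p;
}
```

With logits like `{1, 2, NAN, 0}` every returned probability is NaN, which is consistent with a hard `Failed to sample token` abort rather than a gradual quality degradation that a PPL average would surface.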