Fix paged-attention KV cache dtype + size accounting (issue #119) #125
Conversation
@LxYuan0420 The dtype fix looks good to me. The root cause is clear (hardcoded float16 when the model's actual dtype is bfloat16). The kernel already supports all three float types natively, so this was purely a Python-side plumbing issue. One suggestion: we now have a diagnostic tool in #127.
Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
Keeps the fallback path backward-compatible with the historical paged-attention default and avoids silently changing behavior when dtype inference fails (e.g., unexpected model structure or quantized weights).
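A minimal sketch of the inference-plus-fallback idea described above (illustrative names, not the PR's actual code): use the model's real floating-point dtype when it can be determined, and fall back to the historical float16 default otherwise.

```python
# Hypothetical sketch of KV cache dtype inference with a safe fallback.
# The kernel natively supports all three float types, so a supported
# parameter dtype can be forwarded as-is. All names are illustrative.
SUPPORTED_KV_DTYPES = {"float16", "bfloat16", "float32"}
HISTORICAL_DEFAULT = "float16"  # the previously hardcoded value

def infer_kv_cache_dtype(param_dtypes):
    """Return the first supported parameter dtype; fall back to the
    historical paged-attention default when inference fails (e.g. an
    unexpected model structure or quantized weights)."""
    for dt in param_dtypes:
        if dt in SUPPORTED_KV_DTYPES:
            return dt
    return HISTORICAL_DEFAULT

# bfloat16 models get a bfloat16 KV cache instead of a silent float16 cast
print(infer_kv_cache_dtype(["bfloat16", "bfloat16"]))  # bfloat16
# quantized weights -> inference fails -> backward-compatible default
print(infer_kv_cache_dtype(["int8", "uint8"]))         # float16
```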
Why: avoids allocating temporary tensors just to compute element size; clearer and cheaper.
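To make the size-accounting point concrete, here is a hedged sketch of per-block KV byte accounting driven by the dtype's item size (in PyTorch this is available directly as `torch.dtype.itemsize`, with no temporary tensor). The shape values are illustrative, not the project's real configuration.

```python
# Hedged sketch: KV cache bytes per block, computed from the dtype's item
# size rather than a temporary tensor's element_size(). Shape parameters
# below are illustrative only.
ITEMSIZE = {"float16": 2, "bfloat16": 2, "float32": 4}

def kv_block_bytes(block_size, num_layers, num_kv_heads, head_dim, dtype):
    # 2x for the separate key and value caches
    return 2 * block_size * num_layers * num_kv_heads * head_dim * ITEMSIZE[dtype]

# A bfloat16 cache is half the size float32 accounting would predict, so a
# wrong dtype directly skews the block-count / memory-fraction math.
print(kv_block_bytes(16, 32, 8, 128, "bfloat16"))  # 2097152 (2 MiB)
print(kv_block_bytes(16, 32, 8, 128, "float32"))   # 4194304 (4 MiB)
```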
Force-pushed from a75ea44 to 8b578cb
Good point. I ran a small end-to-end sanity check with tools/avg_gen_length.py (ShareGPT, fixed seed) and it completes without crashing on my side. For meaningful numbers, could you paste the exact command + summary table you're using, so we can standardize on that as the reference?
Findings
This may indicate we need a more discriminative setup (e.g., a different dataset, longer decoding lengths, or settings that amplify divergence if it exists). As a simple short-term check, we could also manually inspect a small sample of responses side by side. If the response quality improves after the fix, or if the paged and non-paged paths look qualitatively similar, that would still be useful supporting evidence.

Experiment (without this patch, on the main branch):

```bash
# mlx_lm path
python tools/avg_gen_length.py

# paged kv path
# no batch size 1, taking too long
VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.4 python tools/avg_gen_length.py --max-num-seqs 8
```
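For standardizing on a reference summary, even a simple mean/median of generated lengths per path would do. A hypothetical sketch of such a summary (not the actual output format of `tools/avg_gen_length.py`; the token counts below are made up for illustration):

```python
import statistics

def summarize(lengths):
    """Hypothetical summary for comparing generation-length distributions
    between the paged and non-paged paths (illustrative, not the real
    tool's report format)."""
    return {
        "n": len(lengths),
        "mean": statistics.mean(lengths),
        "median": statistics.median(lengths),
    }

mlx_lm = [118, 134, 121, 140]  # illustrative token counts per request
paged = [117, 134, 122, 140]

print(summarize(mlx_lm))
print(summarize(paged))
```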
@WindChimeRan Was the "before" run using the same settings, like
I think the current result is good enough to show stability under pressure. This PR is a bug fix, not a performance improvement: it fixes the block-exhaustion failure (RuntimeError: Not enough free blocks) by aligning allocation/preemption behavior with scheduler-driven recompute. @ericcurtin PTAL
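A hedged sketch of the allocation behavior described above (a toy model with illustrative names, not the project's allocator): when the free-block pool runs dry, preempt a running sequence and return its blocks to the pool, leaving the scheduler to recompute that sequence later, instead of raising immediately.

```python
class BlockPool:
    """Toy free-block pool: allocation preempts a running sequence instead of
    raising 'Not enough free blocks'. Illustrative only; victim selection in a
    real scheduler would follow its own priority policy."""

    def __init__(self, num_blocks):
        self.free = num_blocks
        self.running = {}  # seq_id -> number of blocks held

    def allocate(self, seq_id, n):
        # Preempt until the request fits or nothing is left to preempt.
        while self.free < n and self.running:
            victim = min(self.running)             # toy victim policy
            self.free += self.running.pop(victim)  # scheduler will recompute it
        if self.free < n:
            raise RuntimeError("Not enough free blocks")  # truly out of memory
        self.free -= n
        self.running[seq_id] = self.running.get(seq_id, 0) + n

pool = BlockPool(num_blocks=8)
pool.allocate("a", 5)
pool.allocate("b", 5)        # preempts "a" instead of crashing
print(sorted(pool.running))  # ['b']
```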
LGTM, the dtype fix is clean and the root cause (hardcoded float16 vs actual model dtype) is clear. The KvCacheDtypeInference abstraction is well-scoped and the fallback behavior is sensible. Merging.
This PR is:

- To remove a stale `xfail` on `test_greedy_output_matches` that was originally added for issue #119.
- To align test expectation with current `main` behavior after paged-path fixes already merged.
- To keep parity tracking accurate while leaving batched behavior to its own tracking path.

## Context

Issue #119 reported token-mismatch parity failures between:

- the standard MLX KV cache path, and
- the Metal paged-attention path.

Since then, two key fixes landed:

- #125 corrected paged KV cache dtype inference/fallback behavior and the KV cache size accounting used by paged memory/block calculations.
- #136 replaced the HF/PyTorch kernel-bridge path with native MLX + inline Metal JIT dispatch (`get_ops`/nanobind), removing cross-framework bridge behavior from paged execution.

With those changes, the old greedy mismatch from #119 no longer reproduces on `main`, so the greedy `xfail` is stale.

## Verification

```bash
pytest -q tests/test_metal_kernel_paged.py::TestMetalKernelPagedVsStandard::test_greedy_output_matches -s
pytest -m slow -q tests/test_metal_kernel_paged.py
```

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
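When a greedy parity test like this does fail, the useful datum is where the two decode paths first diverge. A small hypothetical helper for that comparison (illustrative, not part of the actual test suite):

```python
def first_divergence(tokens_std, tokens_paged):
    """Return the index of the first mismatching token between the standard
    and paged decode paths, or -1 if greedy parity holds. Illustrative
    helper, not part of the project's tests."""
    for i, (a, b) in enumerate(zip(tokens_std, tokens_paged)):
        if a != b:
            return i
    if len(tokens_std) != len(tokens_paged):
        # One path stopped early: divergence at the shorter length.
        return min(len(tokens_std), len(tokens_paged))
    return -1

print(first_divergence([5, 7, 9], [5, 7, 9]))  # -1: parity holds
print(first_divergence([5, 7, 9], [5, 8, 9]))  # 1: diverges at step 1
```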
This PR is:

- `torch.dtype.itemsize` instead of allocating temporary tensors.

Notes:

- `tests/test_metal_kernel_paged.py::test_batched_decode_matches` now passes.
- `tests/test_metal_kernel_paged.py::test_greedy_output_matches` remains xfailed (tracked in "Metal paged-attention parity mismatch vs standard path", #119). This is a remaining single-request greedy parity mismatch between the paged-kernel path and the standard path; fixing it likely requires deeper kernel/offset semantics work, so I'm keeping it out of this PR to keep scope tight.

Quick manual smoke test:
Terminal 1:
Terminal 2 (single request):
Terminal 2 (concurrent 4 requests):
Related: #119