
Fix paged-attention KV cache dtype + size accounting (issue #119)#125

Merged
ericcurtin merged 7 commits into vllm-project:main from LxYuan0420:fix/issue-119-paged-parity
Mar 4, 2026

Conversation

@LxYuan0420
Collaborator

This PR is:

Notes:

  • tests/test_metal_kernel_paged.py::test_batched_decode_matches now passes.
  • tests/test_metal_kernel_paged.py::test_greedy_output_matches remains xfailed (tracked in #119: Metal paged-attention parity mismatch vs standard path). This is a remaining single-request greedy parity mismatch between the paged-kernel path and the standard path; fixing it likely requires deeper kernel/offset-semantics work, so I'm keeping it out of this PR to keep scope tight.
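For reference, the remaining parity gap can be kept visible without failing CI via a non-strict xfail marker. This is an illustrative sketch only: the test body and reason string are placeholders, not the repo's actual test code.

```python
import pytest

# Placeholder sketch of tracking a known mismatch with a non-strict xfail;
# the real test compares greedy decode tokens from both attention paths.
@pytest.mark.xfail(
    reason="paged vs standard greedy parity mismatch (#119)",
    strict=False,  # an unexpected pass will not fail the suite
)
def test_greedy_output_matches():
    raise AssertionError("paths diverge")  # stand-in for the real comparison
```

With `strict=False`, the test reports XFAIL while the mismatch persists and XPASS once it is fixed, which is when the marker should be removed.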

Quick manual smoke test:

Terminal 1:

vllm serve Qwen/Qwen3-0.6B --host 127.0.0.1 --port 8000 --max-model-len 2048

Terminal 2 (single request):

curl -fsS http://127.0.0.1:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{"model":"Qwen/Qwen3-0.6B","messages":[{"role":"user","content":"Write a 2-sentence apple story."}],"max_tokens":512,"temperature":0.8}' \
| jq -r '.choices[0].message.content'

Terminal 2 (concurrent 4 requests):

for i in 1 2 3 4; do
  (
    echo "===== req $i ====="
    curl -fsS http://127.0.0.1:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d "{\"model\":\"Qwen/Qwen3-0.6B\",\"messages\":[{\"role\":\"user\",\"content\":\"Write a 2-sentence apple story (${i}).\"}],\"max_tokens\":256,\"temperature\":0.8}" \
    | jq -r '.choices[0].message.content'
    echo
  ) &
done
wait

Related: #119

@LxYuan0420 LxYuan0420 requested a review from ericcurtin March 1, 2026 08:06
@LxYuan0420 LxYuan0420 self-assigned this Mar 1, 2026
@WindChimeRan
Collaborator

@LxYuan0420
Thanks for tracking this down and fixing it!

The dtype fix looks good to me. The root cause is clear (hardcoded float16 when the model's actual dtype is bfloat16). The kernel already supports all three float types natively, so this was purely a Python-side plumbing issue.

One suggestion: we now have a diagnostic tool in #127 (tools/avg_gen_length.py) that runs offline inference on ShareGPT prompts and reports response length statistics (mean/std). It would be great to run it before and after this fix to quantify the improvement, specifically comparing the paged path (VLLM_METAL_USE_PAGED_ATTENTION=1) against the non-paged baseline. If the distributions align more closely after the fix, that's a strong quantitative signal beyond the existing test assertions.

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
Keeps the fallback path backward-compatible with the historical paged-attention default and avoids silently changing behavior when dtype inference fails (e.g., unexpected model structure or quantized weights).

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
Why: avoids allocating temporary tensors just to compute element size; clearer and cheaper.
Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
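The element-size change described in the commit note above can be sketched from dtype itemsize alone, with no throwaway tensor allocation. The function name, shape parameters, and itemsize table here are illustrative assumptions, not the repo's actual code.

```python
# Bytes per dtype element; a plain lookup replaces allocating a temporary
# tensor just to read its element size.
ITEMSIZE_BYTES = {"float16": 2, "bfloat16": 2, "float32": 4}

def kv_cache_bytes_per_block(block_size, num_kv_heads, head_dim, dtype):
    """Hypothetical size-accounting helper for one paged KV cache block."""
    # Factor of 2 covers the separate K and V halves of the cache.
    return 2 * block_size * num_kv_heads * head_dim * ITEMSIZE_BYTES[dtype]
```

Getting the dtype right matters here too: with a bfloat16 model accounted as float32, the block budget would be off by a factor of two.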
@LxYuan0420 LxYuan0420 force-pushed the fix/issue-119-paged-parity branch from a75ea44 to 8b578cb Compare March 2, 2026 16:40
@LxYuan0420
Collaborator Author

Good point. I ran a small end-to-end sanity check with tools/avg_gen_length.py (ShareGPT, fixed seed) and it completes without crashing on my side. For meaningful numbers, could you paste the exact command + summary table you’re using so we can standardize on that as the reference?
@WindChimeRan

@WindChimeRan
Collaborator

WindChimeRan commented Mar 3, 2026

Findings

  • --max-num-seqs appears not to take effect on the mlx_lm path. In my runs, the script seems to ignore this flag and proceeds with its own batching behavior.
  • The response-length distributions for main-branch mlx_lm and paged KV are currently very similar under this setup.
  • With the current experimental setting, the results are inconclusive as a signal for quality/regression detection.

This may indicate we need a more discriminative setup (e.g., a different dataset, longer decoding lengths, or settings that amplify divergence if it exists).

As a simple short-term check, we could also manually inspect a small sample of responses side by side. If the response quality improves after the fix, or if the paged and non-paged paths look qualitatively similar, that would still be useful supporting evidence.

e.g.,

what's the capital of France?
write a quicksort in python
...

Experiment

Without this patch, on the main branch:

# mlx_lm path
python tools/avg_gen_length.py
============================================================
  max_num_seqs      N   Mean tokens        Std
------------------------------------------------------------
             1    100         244.8       33.4
             8    100         243.6       36.5
============================================================
# paged kv  path
# batch-size-1 run omitted: it takes too long
VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.4 python tools/avg_gen_length.py --max-num-seqs 8
============================================================
  max_num_seqs      N   Mean tokens        Std
------------------------------------------------------------
             8    100         243.5       36.6
============================================================

@LxYuan0420
Collaborator Author

@WindChimeRan Was the "before" run using the same settings (e.g., VLLM_METAL_MEMORY_FRACTION=0.4)?

@LxYuan0420
Collaborator Author

I think the current result is good enough to show stability under pressure. This PR is a bug fix, not a performance improvement: it fixes the block-exhaustion failure (RuntimeError: Not enough free blocks) by aligning allocation/preemption behavior with scheduler-driven recompute.

@ericcurtin PTAL

@ericcurtin
Collaborator

ericcurtin commented Mar 4, 2026

LGTM, the dtype fix is clean and the root cause (hardcoded float16 vs actual model dtype) is clear. The KvCacheDtypeInference abstraction is well-scoped and the fallback behavior is sensible. Merging.
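A hedged sketch of the fallback behavior praised above, in the spirit of the KvCacheDtypeInference abstraction: the function name and exact inference rule here are illustrative assumptions, not the repo's actual API.

```python
FALLBACK_KV_DTYPE = "float16"  # historical paged-attention default

def infer_kv_cache_dtype(param_dtypes):
    """Infer the KV cache dtype from the model's parameter dtype names.

    Hypothetical helper: falls back to float16 when the parameters are
    missing, mixed, or not a supported float type (e.g., quantized
    weights), so failure cases keep the historical default instead of
    silently changing behavior.
    """
    supported = {"float16", "bfloat16", "float32"}
    seen = {d for d in param_dtypes if d in supported}
    return seen.pop() if len(seen) == 1 else FALLBACK_KV_DTYPE
```

Under this rule a bfloat16 model gets a bfloat16 KV cache (the bug fixed here), while anything ambiguous degrades to the old hardcoded behavior.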

@ericcurtin ericcurtin merged commit 59b9be4 into vllm-project:main Mar 4, 2026
5 checks passed
LxYuan0420 added a commit that referenced this pull request Mar 11, 2026
This PR is:
- To remove a stale `xfail` on `test_greedy_output_matches` that was
originally added for issue #119.
- To align test expectation with current `main` behavior after
paged-path fixes already merged.
- To keep parity tracking accurate while leaving batched behavior to its
own tracking path.

## Context

Issue #119 reported token mismatch parity failures between:
- standard MLX KV cache path, and
- Metal paged-attention path.

Since then, two key fixes landed:
- #125 corrected paged KV cache dtype inference/fallback behavior and KV
cache size accounting used by paged memory/block calculations.
- #136 replaced the HF/PyTorch kernel-bridge path with native MLX +
inline Metal JIT dispatch (`get_ops`/nanobind), removing cross-framework
bridge behavior from paged execution.

With those changes, the old greedy mismatch from #119 no longer
reproduces on `main`, so the greedy `xfail` is stale.

## Verification

```bash
pytest -q tests/test_metal_kernel_paged.py::TestMetalKernelPagedVsStandard::test_greedy_output_matches -s
pytest -m slow -q tests/test_metal_kernel_paged.py
```

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>