feat: MTP support for dense Qwen 3.5 with FastMTP vocabulary trimming #20700
itigges22 wants to merge 1 commit into ggml-org:master
Conversation
As of March 18th, 2026, this is just a draft! Still a work in progress :) If you have any questions, feel free to let me know!
March 19th: making headway, but wow... this is not an easy one.
Investigation Update: Speculative Framework Crash on DeltaNet

After enabling MTP speculative decoding (bypassing the seq_rm compat check), the server initializes correctly but segfaults during the first speculative draft/verify cycle on DeltaNet hybrid models. What works:
What crashes:
Root cause hypothesis: Added fallback (in latest push):
Next steps needed:
Tested on: RTX 5060 Ti 16GB, Linux, Qwen3.5-9B-MTP-Q4_K_M.gguf,
Breakthrough: 95% MTP acceptance rate with cooldown fix

Root cause found and fixed. After a draft rejection, MTP logits are read from the DRAFT token's position (last in the [sampled, draft] batch). These logits predict what comes after the rejected draft — which is wrong. The next proposal uses these stale logits, producing a cascade of bad drafts (13% acceptance rate → garbled output). Fix: Added cooldown flag in

Results:
The remaining speed being similar to non-MTP (16.7 vs 16.7 tok/s) is because the cooldown means proposals happen only every OTHER step. The theoretical max with cooldown is ~1.33x (not 2x). Removing the cooldown for ACCEPTED drafts (cooling down only on rejection) should increase the effective speedup. Output quality is almost correct — minor degradation in docstrings (
Status Update — MTP Attention + FastMTP

Major rework since last update. Squashed all intermediate commits into a single clean commit. What changed:
Key finding: DeltaNet + speculative decode is fundamentally hard

The recurrent state in DeltaNet accumulates all previous tokens. Unlike a KV cache, you can't partially roll it back to discard a rejected draft. The two-phase decode approach (decode sampled → verify → decode draft only if accepted) produces correct output but halves throughput, since each accepted step requires 2 decode calls. Open questions for reviewers:
Happy to split this into smaller PRs if that helps review (e.g., separate the recurrent state fixes from the MTP graph builder).
Add Multi-Token Prediction (MTP) speculative decoding for Qwen3.5 dense models (0.8B-27B). The MTP head uses a full transformer block (attention + FFN) to predict the next-next token, enabling ~28 tok/s on RTX 5060 Ti.

Key changes:
- Model loading: Qwen3.5 MTP layer tensors (nextn.eh_proj, attention weights, FFN) loaded into layers[n_layer-1]
- Graph builder: Full MTP head with self-attention, gated RoPE, FFN, and vocabulary projection. Unfiltered hidden state passed for proper KV cache population during prompt processing.
- FastMTP: Vocabulary trimming from 248K to 32K tokens via ggml_view_2d on the lm_head. Reduces draft generation from 22ms to 6ms (3.7x).
- Speculative framework: MTP auto-detection for hybrid models, fuzzy seq_rm checkpoint matching for DeltaNet rollback.
- Server: Two-phase decode option for hybrid/recurrent models to avoid DeltaNet state corruption from rejected drafts.
- Recurrent state: Fixed copy_cell (ggml_view_1d takes element count, not bytes), buffer assignment for no_alloc views.

Results on Qwen3.5-9B Q4_K_M (RTX 5060 Ti 16GB):
- 28.1 tok/s with 82% acceptance rate (temp=0)
- 92% acceptance with two-phase decode (correct output, 15 tok/s)
- Draft generation: 6.1ms with FastMTP (vs 22.4ms full vocab)
Relationship to #18886 (MTP API) and #15225 (GLM MTP)

This PR takes a different approach from @ngxson's MTP API design in #18886 — we use a single context with the MTP head inline in the main compute graph, rather than separate contexts. This was necessary because Qwen3.5's hybrid DeltaNet architecture makes state copying between contexts complex (the recurrent state can't be trivially transferred). However, the core contributions here are architecture-independent and would benefit any MTP implementation for hybrid/recurrent models:
Happy to refactor to align with the #18886 API once it stabilizes. The DeltaNet-specific handling would need to be integrated regardless of API design — it's a fundamental requirement for speculative decoding on any recurrent/hybrid architecture (Mamba, RWKV, DeltaNet, etc.). If it helps, I can split the recurrent state fixes into a separate PR that's useful independently of MTP.
Please note: this PR is still very much a WIP; I do not expect it to be merged any time soon. However, I do hope it provides some detail on the direction llama.cpp should take in adding MTP support. Since I do not have the resources to test fully, I have no doubt that there are bugs, that it needs refactoring, etc. What I sincerely ask of the reviewers and the community is to take a look at the work done here and see if you can find a better, hopefully more suitable, solution. This work is motivated by ATLAS — MTP support would enable meaningful speedups for local setups running the Qwen3.5 family of models. Best, Isaac :)
Hi @itigges22, thanks for this PR — I've been testing it on a Mac M4 (Metal) and an RTX 3090 (CUDA) with Qwen3.5-9B Q4_K_M. The MTP implementation works well — no crashes, deterministic output, correct multi-turn behavior. However, I'm seeing a 63.5% draft acceptance rate at temperature=0, compared to the 82% reported in the PR description. My measurement (consistent across multiple requests):

To help reproduce your 82% number, could you share:
My setup:
Thanks!
Actually, re-reading the PR more carefully, I see:
So you used mixed precision — Q4_K_M for the base model with F16 for the MTP head weights. My conversion quantized everything uniformly to Q4_K_M (including the MTP head), which would explain the lower acceptance rate. Could you share how you produced the mixed-precision GGUF? Specifically:
This would help reproduce the 82% number and would also be useful context for anyone else testing MTP. |
I successfully compiled this version on an NVIDIA RTX A4500 card.
@petter-b Apologies for the delay! To answer your question: I converted the model myself from the raw HuggingFace weights using the standard convert_hf_to_gguf.py script that comes with llama.cpp. The script reads the HF model files and outputs a GGUF, and I then ran llama-quantize on it to get Q4_K_M. No special steps were taken to preserve the MTP head in F16; the entire model, including the MTP layers, was quantized uniformly to Q4_K_M. The f32 tensors you'd see in the GGUF are just norm/embedding layers that llama-quantize leaves in f32 by default.

I should also mention that after more testing, the acceptance rate has been inconsistent — I've seen it range anywhere from 40% to 82%, and replicating the higher numbers consistently has proven difficult. It's been one of the harder things to track down on my end, so I wouldn't treat the 82% as a reliable baseline just yet. I also don't have the compute to attempt any of this in FP16, so my methodology and approach were grounded in the limited compute that I have!
Summary
Adds Multi-Token Prediction (MTP) speculative decoding for Qwen3.5 dense models (0.8B-27B). These hybrid DeltaNet/attention models have a built-in MTP head that predicts the next-next token, enabling speculative decoding without a separate draft model.
Key features:
copy_cell (element count vs byte count in ggml_view_1d), fuzzy seq_rm checkpoint matching

Results (Qwen3.5-9B Q4_K_M, RTX 5060 Ti 16GB VRAM):
With two-phase decode (guaranteed correct output):
Architecture
Qwen3.5 uses a repeating pattern of 3 DeltaNet (linear attention) + 1 full attention layers. The MTP head is a single full-attention transformer block that:
eh_proj

The DeltaNet recurrent state cannot be partially rolled back (unlike a KV cache), so rejected drafts corrupt the state. The two-phase decode option handles this by decoding only accepted drafts.
Files changed:
- src/models/qwen35.cpp — MTP head graph builder with attention + FastMTP
- src/llama-memory-recurrent.cpp — copy_cell fix, seq_rm fuzzy matching
- src/llama-model.cpp — MTP tensor loading, rs_size config
- src/llama-context.cpp/h — MTP logits extraction, reduced vocab tracking
- common/speculative.cpp — MTP state machine, FastMTP vocab support
- tools/server/server-context.cpp — Two-phase decode for hybrid models
- include/llama.h — llama_get_mtp_n_vocab() API
- convert_hf_to_gguf.py — Qwen3.5 MTP tensor conversion support

Testing
Tested on:
--jinja --embeddings --parallel 1

What's needed for merge: