
feat: MTP support for dense Qwen 3.5 with FastMTP vocabulary trimming#20700

Open

itigges22 wants to merge 1 commit into ggml-org:master from itigges22:feat/qwen35-dense-mtp

Conversation

@itigges22 commented Mar 17, 2026

Summary

Adds Multi-Token Prediction (MTP) speculative decoding for Qwen3.5 dense models (0.8B-27B). These hybrid DeltaNet/attention models have a built-in MTP head that predicts the next-next token, enabling speculative decoding without a separate draft model.

Key features:

  • Full MTP attention: The MTP head includes self-attention with its own KV cache (not just FFN), matching the architecture described in the Qwen3.5 technical report
  • FastMTP vocabulary trimming: Reduces MTP lm_head from 248K → 32K tokens (3.7x faster draft generation, ~0% quality loss for code/text)
  • Hybrid model support: Two-phase decode option for DeltaNet recurrent state — avoids state corruption from rejected drafts
  • Recurrent state fixes: Correct copy_cell (element count vs byte count in ggml_view_1d), fuzzy seq_rm checkpoint matching

Results (Qwen3.5-9B Q4_K_M, RTX 5060 Ti 16GB VRAM):

Metric            Without MTP   With MTP + FastMTP
Speed             30 tok/s      28.1 tok/s
Acceptance rate   n/a           82% (temp=0)
Draft gen time    n/a           6.1 ms
Output quality    Clean         Clean (short gen)

With two-phase decode (guaranteed correct output):

  • 92% acceptance, 15 tok/s
  • Clean output on all generation lengths

Architecture

Qwen3.5 uses a repeating pattern of 3 DeltaNet (linear attention) + 1 full attention layers. The MTP head is a single full-attention transformer block that:

  1. Combines token embedding + hidden state via eh_proj
  2. Runs self-attention with gated RoPE (same format as main model)
  3. Runs FFN
  4. Projects to vocabulary logits (trimmed to 32K with FastMTP)
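The vocabulary trimming in step 4 works because restricting the projection to the first rows of lm_head produces logits identical to the corresponding prefix of the full projection, so drafted token IDs need no remapping. A minimal toy sketch of that property (plain C++, not the actual ggml_view_2d code; sizes and names here are illustrative only):

```cpp
#include <cassert>
#include <vector>

// Toy model of FastMTP trimming: projecting the hidden state through only
// the first n_rows rows of a row-major lm_head matrix yields logits equal
// to the first n_rows entries of the full projection.
std::vector<float> project(const std::vector<float>& h,        // [n_embd]
                           const std::vector<float>& lm_head,  // [n_vocab][n_embd]
                           int n_embd, int n_rows) {
    std::vector<float> logits(n_rows, 0.0f);
    for (int r = 0; r < n_rows; ++r)
        for (int e = 0; e < n_embd; ++e)
            logits[r] += lm_head[r * n_embd + e] * h[e];
    return logits;
}
```

In the real graph the same effect would come from taking a 2-D view of the first 32K columns of the lm_head tensor before the matmul, which is why token IDs below the cutoff stay valid without any index translation.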

The DeltaNet recurrent state cannot be partially rolled back (unlike KV cache), so rejected drafts corrupt the state. The two-phase decode option handles this by only decoding accepted drafts.
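Because the state folds every token in, the only rollback available is restoring a full copy taken before speculation. A toy sketch of that checkpoint/restore pattern (all names hypothetical; the real DeltaNet state is a large tensor, not an integer):

```cpp
#include <cassert>

// Toy recurrent state: absorbs every decoded token irreversibly,
// standing in for the DeltaNet state tensor.
struct ToyRecurrentState {
    long long s = 0;
    void decode(int tok) { s = s * 31 + tok; }  // state absorbs each token
};

// Verify step: checkpoint, decode the draft, restore on rejection.
bool verify_draft(ToyRecurrentState& st, int draft_tok, int actual_tok) {
    ToyRecurrentState checkpoint = st;  // full state copy before speculation
    st.decode(draft_tok);
    if (draft_tok != actual_tok) {
        st = checkpoint;                // rejected: roll back wholesale
        st.decode(actual_tok);          // re-decode the accepted token
        return false;
    }
    return true;
}
```

After a rejection the restored-and-redecoded state matches a plain sequential decode exactly, which is the invariant the checkpoint machinery has to preserve.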

Files changed:

  • src/models/qwen35.cpp — MTP head graph builder with attention + FastMTP
  • src/llama-memory-recurrent.cpp — copy_cell fix, seq_rm fuzzy matching
  • src/llama-model.cpp — MTP tensor loading, rs_size config
  • src/llama-context.cpp/h — MTP logits extraction, reduced vocab tracking
  • common/speculative.cpp — MTP state machine, FastMTP vocab support
  • tools/server/server-context.cpp — Two-phase decode for hybrid models
  • include/llama.h — llama_get_mtp_n_vocab() API
  • convert_hf_to_gguf.py — Qwen3.5 MTP tensor conversion support

Testing

Tested on:

  • Qwen3.5-9B (Q4_K_M + F16 MTP weights)
  • RTX 5060 Ti 16GB VRAM
  • K3s deployment with --jinja --embeddings --parallel 1

What's needed for merge:

  • Review of recurrent state handling (copy_cell, seq_rm changes)
  • Review of MTP attention graph builder
  • CI tests passing (need Qwen3.5 MTP GGUF test model)
  • Consider making FastMTP vocab size configurable (currently hardcoded 32K)
  • Consider server flag to toggle two-phase decode for hybrid models

Please note: This PR is still very much a WIP — I do not expect it to be merged any time soon. However, what I do hope is that it provides some detail into the direction llama.cpp should be going in terms of adding MTP support.

Since I do not have the resources to fully test, I have no doubt that there are bugs and it needs to be refactored.

What I sincerely ask of the reviewers and the community is to take a look at the work done here, and see if you are able to find a better, hopefully more suitable solution.

This work is motivated by ATLAS — MTP support would enable meaningful speedups for local setups running the Qwen3.5 family of models.

Best, Isaac :)

@github-actions bot added the model, examples, python, and server labels Mar 17, 2026
@itigges22 commented:

As of March 18th 2026, this is just a draft! Still work in progress :)

If you have any questions feel free to lmk!

@itigges22 commented:

March 19th: making headway, but wow... this is not an easy one.

@itigges22 commented:

Investigation Update: Speculative Framework Crash on DeltaNet

After enabling MTP speculative decoding (bypassing the seq_rm compat check), the server initializes correctly but segfaults during the first speculative draft/verify cycle on DeltaNet hybrid models.

What works:

  • Compat fix passes (checkpoint/restore detection)
  • Auto-enable sets COMMON_SPECULATIVE_TYPE_MTP when MTP layers detected
  • Speculative context initializes (speculative decoding context initialized)
  • Model loads and processes prompts normally

What crashes:

  • First call to common_speculative_draft() or subsequent verify loop
  • Segfault (no error message, process death)
  • Happens after prompt processing completes, during first token generation with speculative active

Root cause hypothesis:
The speculative framework's common_speculative_draft() likely calls functions that assume standard transformer KV cache behavior. The DeltaNet recurrent memory has a different API surface that causes null pointer dereference or out-of-bounds access when the speculative state machine tries to manipulate the sequence.

Added fallback (in latest push):

  • seq_rm fallback in the verify loop: when seq_rm returns false, re-evaluates accepted tokens to rebuild correct recurrent state
  • This doesn't fix the crash since the crash happens before reaching the verify loop

Next steps needed:

  1. Debug the segfault in common_speculative_draft() — likely a DeltaNet-incompatible call in the draft proposal path
  2. The common_speculative_init() succeeds but may set up state that references transformer-only APIs
  3. May need DeltaNet-specific draft proposal that uses llama_get_mtp_logits() directly instead of going through the speculative state machine

Tested on: RTX 5060 Ti 16GB, Linux, Qwen3.5-9B-MTP-Q4_K_M.gguf, --parallel 1 --jinja --no-cache-prompt

@itigges22 commented:

Breakthrough: 95% MTP acceptance rate with cooldown fix

Root cause found and fixed. After draft rejection, MTP logits are read from the DRAFT token's position (last in the [sampled, draft] batch). These logits predict what comes after the rejected draft — which is wrong. The next proposal uses these stale logits, producing a cascade of bad drafts (13% acceptance rate → garbled output).

Fix: Added cooldown flag in common_speculative_state_mtp::accept(). When no drafts are accepted (n_accepted == 0), the next draft() call returns empty, forcing a single-token decode. This produces fresh MTP logits from the correct position. Next proposal uses these fresh logits.
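A minimal sketch of that cooldown logic (names invented here, not the actual common_speculative_state_mtp API): a full rejection means the MTP logits were computed at the rejected draft's position, so the next draft is suppressed to force one plain decode that refreshes them.

```cpp
#include <cassert>
#include <vector>

// Hypothetical cooldown state machine for MTP drafting.
struct MtpCooldown {
    bool cooldown = false;

    void accept(int n_accepted) {
        cooldown = (n_accepted == 0);  // full rejection => MTP logits are stale
    }

    // Proposal for the next step; empty while cooling down.
    std::vector<int> draft(int mtp_token) {
        if (cooldown) {
            cooldown = false;          // the forced plain decode refreshes logits
            return {};
        }
        return {mtp_token};
    }
};
```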

Results:

Metric            Before cooldown   After cooldown
Acceptance rate   13% (87% wrong)   95%
Output quality    Garbled           Clean
Server crashes    Yes (seq_rm)      0 restarts
Speed             16.4 tok/s        16.7 tok/s

The speed remaining similar to non-MTP (16.7 vs 16.7 tok/s) is because the cooldown means proposals happen only every OTHER step. The theoretical max with cooldown is ~1.33x (not 2x). Removing the cooldown for ACCEPTED drafts (cooling down only on rejection) should increase the effective speedup.

Output quality is almost correct — minor degradation in docstrings ("List the." truncation), likely from a DeltaNet recurrent-state batch-processing difference. Investigating.

@itigges22 itigges22 force-pushed the feat/qwen35-dense-mtp branch from fef9ada to affba2a Compare March 21, 2026 17:51
@itigges22 itigges22 changed the title feat: MTP support for dense Qwen 3.5 (0.8B-27B) feat: MTP support for dense Qwen 3.5 with FastMTP vocabulary trimming Mar 21, 2026
@itigges22 commented:

Status Update — MTP Attention + FastMTP

Major rework since last update. Squashed all intermediate commits into a single clean commit.

What changed:

  1. Implemented full MTP attention — The MTP head was previously FFN-only (no self-attention). vLLM uses layer_type="full_attention" for MTP, and adding attention improved acceptance from 60% → 82% at temp=0. The MTP layer's KV cache is already allocated by the hybrid memory system (is_recurrent(mtp_layer) = false).

  2. FastMTP vocabulary trimming — The MTP lm_head projection (4096→248K) was the biggest compute bottleneck (~10ms per decode). By using ggml_view_2d to trim to top 32K tokens, draft generation dropped from 22ms → 6ms with no measurable accuracy loss (code tokens are within the top 32K of most tokenizers).

  3. Fixed copy_cell for recurrent state — ggml_view_1d takes element count, not byte count. The old code passed byte counts, causing data_size + view_offs > ggml_nbytes(view_src) assertions. Also fixed buffer assignment for views created with no_alloc=true.

  4. Fuzzy seq_rm checkpoint matching — For DeltaNet with 2-token batches, checkpoints are 2 positions behind the rejection point. Changed from exact pos == p0 - 1 to closest pos < p0 matching.
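To make the copy_cell bug in item 3 concrete, here is a toy version of the bounds check a tensor library performs when a 1-D view is created (simplified, not the actual ggml source): the second argument of ggml_view_1d is an ELEMENT count, so the view spans n_elems * type_size bytes from the byte offset.

```cpp
#include <cassert>
#include <cstddef>

// Simplified 1-D view bounds check: a view of n_elems elements of a given
// type, starting at offs_bytes, must fit inside the source tensor's bytes.
bool view_1d_in_bounds(size_t src_nbytes, size_t type_size,
                       size_t n_elems, size_t offs_bytes) {
    return offs_bytes + n_elems * type_size <= src_nbytes;
}
```

For an F32 cell of 1024 elements (4096 bytes), passing the byte count 4096 where the element count belongs requests 16384 bytes and trips exactly the kind of size assertion described above.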

Key finding: DeltaNet + speculative decode is fundamentally hard

The recurrent state in DeltaNet accumulates all previous tokens. Unlike KV cache, you can't seq_rm a single position — the state must be checkpointed before speculation and restored on rejection. This adds overhead and limits the net speedup from MTP.

The two-phase decode approach (decode sampled → verify → decode draft only if accepted) produces correct output but halves throughput since each accepted step requires 2 decode calls.
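A toy sketch of that two-phase loop (hypothetical shape, not the actual server code): the recurrent state only ever absorbs verified tokens, so no rollback is needed, at the cost of a second decode call whenever a draft is accepted.

```cpp
#include <cassert>

// Toy recurrent state with a decode-call counter.
struct ToyState {
    long long s = 0;
    int n_decodes = 0;
    void decode(int tok) { s = s * 31 + tok; ++n_decodes; }
};

// One two-phase step; returns the number of tokens committed (1 or 2).
int two_phase_step(ToyState& st, int sampled, int draft, int actual_next) {
    st.decode(sampled);            // phase 1: commit only the sampled token
    if (draft == actual_next) {    // verify (abstracted here to a comparison)
        st.decode(draft);          // phase 2: commit the ACCEPTED draft
        return 2;
    }
    return 1;                      // rejected draft never touched the state
}
```

The invariant is that the state always equals a plain sequential decode of the committed tokens, which is what guarantees correct output on hybrid models.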

Open questions for reviewers:

  1. FastMTP vocab size: Currently hardcoded to 32768. Should this be configurable via llama_model_params or a server flag?
  2. Two-phase decode: Should this be the default for hybrid models, or opt-in via a flag?
  3. Unfiltered hidden state: The MTP head needs the unfiltered (pre-inp_out_ids) hidden state for attention KV cache population during prompt processing. Currently stored in a class member (mtp_inp_hidden). Is there a cleaner approach?

Happy to split this into smaller PRs if that helps review (e.g., separate the recurrent state fixes from the MTP graph builder).

Add Multi-Token Prediction (MTP) speculative decoding for Qwen3.5 dense
models (0.8B-27B). The MTP head uses a full transformer block (attention
+ FFN) to predict the next-next token, enabling ~28 tok/s on RTX 5060 Ti.

Key changes:
- Model loading: Qwen3.5 MTP layer tensors (nextn.eh_proj, attention
  weights, FFN) loaded into layers[n_layer-1]
- Graph builder: Full MTP head with self-attention, gated RoPE, FFN,
  and vocabulary projection. Unfiltered hidden state passed for proper
  KV cache population during prompt processing.
- FastMTP: Vocabulary trimming from 248K to 32K tokens via ggml_view_2d
  on the lm_head. Reduces draft generation from 22ms to 6ms (3.7x).
- Speculative framework: MTP auto-detection for hybrid models, fuzzy
  seq_rm checkpoint matching for DeltaNet rollback.
- Server: Two-phase decode option for hybrid/recurrent models to avoid
  DeltaNet state corruption from rejected drafts.
- Recurrent state: Fixed copy_cell (ggml_view_1d takes element count,
  not bytes), buffer assignment for no_alloc views.

Results on Qwen3.5-9B Q4_K_M (RTX 5060 Ti 16GB):
- 28.1 tok/s with 82% acceptance rate (temp=0)
- 92% acceptance with two-phase decode (correct output, 15 tok/s)
- Draft generation: 6.1ms with FastMTP (vs 22.4ms full vocab)
@itigges22 itigges22 force-pushed the feat/qwen35-dense-mtp branch from affba2a to 19fdba5 Compare March 21, 2026 18:19
@itigges22
Copy link
Author

Relationship to #18886 (MTP API) and #15225 (GLM MTP)

This PR takes a different approach from @ngxson's MTP API design in #18886 — we use a single context with the MTP head inline in the main compute graph rather than separate contexts. This was necessary because Qwen3.5's hybrid DeltaNet architecture makes state copying between contexts complex (the recurrent state can't be trivially transferred).

However, the core contributions here are architecture-independent and would benefit any MTP implementation for hybrid/recurrent models:

  1. copy_cell fix — ggml_view_1d takes element count, not bytes. This bug affects all recurrent state checkpoint/restore operations.
  2. Fuzzy seq_rm checkpoint matching — needed because 2-token speculative batches create checkpoints 2 positions behind the rejection point.
  3. Two-phase decode for hybrid models — the only way to avoid DeltaNet state corruption from rejected drafts.
  4. FastMTP vocabulary trimming — ggml_view_2d on lm_head to reduce from 248K→32K tokens (3.7x faster draft generation).
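The fuzzy matching in item 2 can be sketched in a few lines (toy code, not the llama.cpp implementation): instead of requiring a checkpoint at exactly p0 - 1, take the newest checkpoint strictly before the rejection point p0.

```cpp
#include <cassert>
#include <vector>

// Return the newest checkpoint position strictly before p0, or -1 if none.
int find_checkpoint(const std::vector<int>& checkpoint_pos, int p0) {
    int best = -1;
    for (int pos : checkpoint_pos)
        if (pos < p0 && pos > best)
            best = pos;
    return best;
}
```

With 2-token speculative batches, checkpoints land two positions behind the rejection point, so an exact pos == p0 - 1 lookup misses while the closest-match version still finds a usable restore point.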

Happy to refactor to align with the #18886 API once it stabilizes. The DeltaNet-specific handling would need to be integrated regardless of API design — it's a fundamental requirement for speculative decoding on any recurrent/hybrid architecture (Mamba, RWKV, DeltaNet, etc.).

If it helps, I can split the recurrent state fixes into a separate PR that's useful independent of MTP.

@itigges22 itigges22 marked this pull request as ready for review March 21, 2026 18:29
@itigges22 itigges22 requested review from a team, CISC and ggerganov as code owners March 21, 2026 18:29
@itigges22 commented Mar 22, 2026

Please note: This PR is still very much a WIP — I do not expect it to be merged any time soon. However, what I do hope is that it provides some detail into the direction llama.cpp should be going in terms of adding MTP support.

Since I do not have the resources to fully test, I have no doubt that there are bugs, that it needs to be refactored, etc.

What I sincerely ask of the reviewers and the community is to take a look at the work done here and see if you are able to find a better, more suitable solution.

This work is motivated by ATLAS — MTP support would enable meaningful speedups for local setups running the Qwen3.5 family of models.

Best, Isaac :)

@petter-b commented:

Hi @itigges22, thanks for this PR — I've been testing it on Mac M4 (Metal) and RTX 3090 (CUDA) with Qwen3.5-9B Q4_K_M.

The MTP implementation works well — no crashes, deterministic output, correct multi-turn behavior. However, I'm seeing a 63.5% draft acceptance rate at temperature=0, compared to the 82% reported in the PR description.

My measurement (consistent across multiple requests):

draft acceptance rate = 0.63462 (99 accepted / 156 generated)

To help reproduce your 82% number, could you share:

  1. MTP weight precision: Did you use F16 MTP head weights with Q4_K_M base (mixed precision), or was the entire model Q4_K_M including MTP tensors?
  2. Which prompt produced the 82% acceptance rate?
  3. Server flags — specifically --draft-max, --ctx-size, and any other flags beyond --spec-type mtp?

My setup:

  • Model: converted from Qwen/Qwen3.5-9B via convert_hf_to_gguf.py (BF16), then llama-quantize to Q4_K_M (all tensors quantized, including MTP head)
  • Flags: --spec-type mtp --draft-max 1 --flash-attn on -np 1 -c 4096 -ngl 99
  • Prompt: "Write a Python function that implements quicksort with detailed comments explaining each step." (n_predict=256, temperature=0)

Thanks!

@petter-b commented:

Actually, re-reading the PR more carefully I see:

Tested on: Qwen3.5-9B (Q4_K_M + F16 MTP weights)

So you used mixed precision — Q4_K_M for the base model with F16 for the MTP head weights. My conversion quantized everything uniformly to Q4_K_M (including the MTP head), which would explain the lower acceptance rate.

Could you share how you produced the mixed-precision GGUF? Specifically:

  • Did you modify convert_hf_to_gguf.py to keep MTP tensors in F16 while quantizing the rest?
  • Or did you quantize with llama-quantize using a flag/config to exclude MTP tensors?

This would help reproduce the 82% number and would also be useful context for anyone else testing MTP.

@jinkang06 commented:

I compiled this version successfully on an NVIDIA RTX A4500 card.
Qwen3.5 27B increases its performance by 12%, from 25 tok/s to 28 tok/s.

@itigges22 commented Mar 26, 2026

Actually, re-reading the PR more carefully I see:

Tested on: Qwen3.5-9B (Q4_K_M + F16 MTP weights)

So you used mixed precision — Q4_K_M for the base model with F16 for the MTP head weights. My conversion quantized everything uniformly to Q4_K_M (including the MTP head), which would explain the lower acceptance rate.

Could you share how you produced the mixed-precision GGUF? Specifically:

  • Did you modify convert_hf_to_gguf.py to keep MTP tensors in F16 while quantizing the rest?
  • Or did you quantize with llama-quantize using a flag/config to exclude MTP tensors?

This would help reproduce the 82% number and would also be useful context for anyone else testing MTP.

@petter-b Apologies for the delay! To answer your question: I converted the model myself from the raw HuggingFace weights using the standard convert_hf_to_gguf.py script that comes with llama.cpp. This script reads the HF model files and outputs a GGUF, and I then ran llama-quantize on it to get Q4_K_M. No special steps were taken to preserve the MTP head in F16; the entire model, including the MTP layers, was quantized uniformly to Q4_K_M. The f32 tensors you'd see in the GGUF are just norm/embedding layers that llama-quantize leaves in f32 by default.

I should also mention that after more testing, the acceptance rate has been inconsistent — I've seen it range anywhere from 40% to 82%, and replicating the higher numbers consistently has proven difficult. It's been one of the harder things to track down on my end, so I wouldn't treat the 82% as a reliable baseline just yet.

However, I do not have the compute to attempt any of this in FP16, so my methodology and approach were grounded in the limited compute that I have!
