feat: Step3.5 MTP Support #20981

Closed

forforever73 wants to merge 4 commits into ggml-org:master from stepfun-ai:step3p5-mtp

Conversation

@forforever73 (Contributor):

Summary

This PR introduces an end-to-end implementation of Step3.5 Multi-Token Prediction (MTP) in llama.cpp, covering model conversion, loading, runtime graph execution, speculative decoding, server integration, and quantization compatibility.

The current implementation is intentionally specialized for Step3.5 and focuses on enabling a complete and usable MTP pipeline within the existing architecture. It allows early experimentation with Step3.5-style MTP and provides a concrete reference for evaluating performance and behavior in practice.

This PR does not attempt to define a generalized MTP abstraction. Instead, it provides a concrete baseline that can be used for experimentation, benchmarking, and informing future design directions toward a more general solution.

Implemented

  • Conversion and tensor mapping - export the Step3.5 nextn / shared-head tensors needed by the current runtime path and register them as proper per-layer tensors, currently keeping the first MTP layer only.
  • Model loading - load the Step3.5 MTP weights when -mtp is enabled instead of always pulling the extra tensors into the runtime.
  • Runtime graph - add a dedicated Step3.5 MTP graph path in the iswa builder, while refactoring layer construction so the normal decoder path and the MTP path share the same core logic.
  • Speculation - add an MTP-specific speculative decoding state machine with retained prefix handling, staged first-pass sources, and verifier hidden-state handoff.
  • SWA guard - add SWA / iSWA guard and rollback-related handling to preserve correct MTP behavior across SWA window boundaries, especially when advancing the draft would evict cache state that cannot be recovered after accept.
  • Server path - wire prompt hidden-state collection, MTP draft setup, accept / rollback flow, and prefix reuse into the server implementation.
  • CLI compatibility - expose the feature through CLI options and keep the speculative example usable without a separate draft model.
  • Quantization compatibility - make the quantization path recognize the new MTP tensors and avoid incompatible imatrix requirements for them.

Not Yet Completed

  • Multi-layer MTP - the current Step3.5 runtime only uses the first MTP layer.
  • Generalization - the implementation is still Step3.5-oriented and is not yet shaped into a more general MTP framework.
  • Cache reuse - only continuous prefix reuse is supported for MTP right now; the prompt-cache reuse path is currently disabled, and the more general cache reuse path is not handled yet.
  • Shared draft context - the server still uses one draft context per slot and has not been reworked into a shared / batched draft-context design.

Testing

These numbers are taken from the server log line:

```
eval time = ... ms / ... tokens (... ms per token, ... tokens per second)
```

For these runs, -mtp enables the Step3.5 MTP path introduced in this PR: it loads the MTP weights and routes the server through the MTP speculative decoding flow measured below.

Decode Speed on 8xH200

```shell
nohup env ./llama-server \
  -m ../../../step3p5_flash_fp16.gguf \
  --host 0.0.0.0 --port 8080 \
  --device CUDA0,CUDA1,CUDA2,CUDA3,CUDA4,CUDA5,CUDA6,CUDA7 \
  --split-mode layer --tensor-split 1,1,1,1,1,1,1,1 \
  -c 262144 -ctk f16 -ctv f16 \
  -np 1 -cb -b 4096 -ub 2048 \
  --reasoning-format none \
  -mtp --draft 3 -cram 0 \
  > ./llama-server_mtp.log 2>&1 &
```

Model weights: fp16, KV cache: fp16

| Max tokens | Draft max | No MTP decode tok/s | MTP decode tok/s | Speedup | Acceptance rate |
|---|---|---|---|---|---|
| 32k | 1 | 57.16 | 78.75 | 1.38x | 0.86648 |
| 32k | 2 | 57.16 | 89.35 | 1.56x | 0.75817 |
| 32k | 3 | 57.16 | 93.00 | 1.62x | 0.70032 |
| 64k | 1 | 56.30 | 77.85 | 1.38x | 0.87181 |
| 64k | 2 | 56.30 | 90.19 | 1.60x | 0.76775 |
| 64k | 3 | 56.30 | 92.33 | 1.64x | 0.70326 |

Model weights: fp16, KV cache: q8_0

| Max tokens | Draft max | No MTP decode tok/s | MTP decode tok/s | Speedup | Acceptance rate |
|---|---|---|---|---|---|
| 32k | 1 | 52.07 | 73.04 | 1.40x | 0.89131 |
| 32k | 2 | 52.07 | 87.14 | 1.67x | 0.76893 |
| 32k | 3 | 52.07 | 88.68 | 1.70x | 0.70144 |

Model weights: fp16, KV cache: q4_0

| Max tokens | Draft max | No MTP decode tok/s | MTP decode tok/s | Speedup | Acceptance rate |
|---|---|---|---|---|---|
| 32k | 1 | 50.87 | 69.35 | 1.36x | 0.88316 |
| 32k | 2 | 50.87 | 84.21 | 1.65x | 0.78450 |
| 32k | 3 | 50.87 | 87.93 | 1.72x | 0.71941 |

Model weights: Q4_K_S, KV cache: fp16

| Max tokens | Draft max | No MTP decode tok/s | MTP decode tok/s | Speedup | Acceptance rate |
|---|---|---|---|---|---|
| 32k | 1 | 73.20 | 95.07 | 1.29x | 0.88067 |
| 32k | 2 | 73.20 | 101.29 | 1.38x | 0.75524 |
| 32k | 3 | 73.20 | 103.15 | 1.40x | 0.71893 |
| 64k | 1 | 71.76 | 93.96 | 1.31x | 0.88330 |
| 64k | 2 | 71.76 | 101.13 | 1.40x | 0.76799 |
| 64k | 3 | 71.76 | 100.35 | 1.40x | 0.70598 |
| 128k | 1 | 71.31 | 92.39 | 1.34x | 0.88361 |
| 128k | 2 | 71.31 | 100.54 | 1.41x | 0.76811 |
| 128k | 3 | 71.31 | 99.57 | 1.40x | 0.70351 |

Model weights: IQ4_XS, KV cache: fp16

| Max tokens | Draft max | No MTP decode tok/s | MTP decode tok/s | Speedup | Acceptance rate |
|---|---|---|---|---|---|
| 32k | 1 | 71.30 | 93.26 | 1.31x | 0.86898 |
| 32k | 2 | 71.30 | 102.55 | 1.44x | 0.76226 |
| 32k | 3 | 71.30 | 102.23 | 1.43x | 0.70322 |
| 64k | 1 | 69.91 | 91.85 | 1.31x | 0.87270 |
| 64k | 2 | 69.91 | 101.54 | 1.45x | 0.76892 |
| 64k | 3 | 69.91 | 101.13 | 1.45x | 0.69951 |

Decode Speed on Mac Studio

Performance on Mac Studio is currently weaker than expected. Profiling so far suggests that the main-model verify path on the Metal backend is relatively expensive; any suggestions or leads here would be very welcome.

```shell
nohup ./llama-server \
  -m ../../../step3p5_flash_IQ4_XS.gguf \
  --host 0.0.0.0 --port 8080 \
  -c 131072 \
  -ctk f16 -ctv f16 \
  -np 1 -cb \
  -b 4096 -ub 2048 \
  --jinja \
  --reasoning-format none \
  -mtp --draft 3 -cram 0 \
  > ./llama-server_mtp.log 2>&1 &
```

Model weights: IQ4_XS, KV cache: fp16

| Max tokens | Draft max | No MTP decode tok/s | MTP decode tok/s | Speedup | Acceptance rate |
|---|---|---|---|---|---|
| 32k | 1 | 35.38 | 45.61 | 1.29x | 0.88881 |
| 32k | 2 | 35.38 | 42.70 | 1.21x | 0.75814 |
| 32k | 3 | 35.38 | 40.03 | 1.13x | 0.68679 |

Model weights: Q4_K_S, KV cache: fp16, MTP KV cache: q4_0

| Max tokens | Draft max | No MTP decode tok/s | MTP decode tok/s | Speedup | Acceptance rate |
|---|---|---|---|---|---|
| 32k | 1 | 42.99 | 48.34 | 1.12x | 0.87733 |
| 32k | 2 | 42.99 | 45.25 | 1.05x | 0.77245 |
| 32k | 3 | 42.99 | 40.31 | 0.94x | 0.68601 |
| 64k | 1 | 39.12 | 44.84 | 1.15x | 0.87963 |
| 64k | 2 | 39.12 | 41.91 | 1.07x | 0.77081 |
| 64k | 3 | 39.12 | 36.03 | 0.92x | 0.68636 |

Decode Speed on DGX Spark

```shell
nohup env GGML_CUDA_ENABLE_UNIFIED_MEMORY=1 ./llama-server \
  --mlock \
  -m ../../../step3p5_flash_mtp_Q4_K_S.gguf \
  --host 0.0.0.0 --port 8080 \
  -c 131072 \
  -ctk q8_0 -ctv q8_0 \
  -ctkd q4_0 -ctvd q4_0 \
  -np 1 -cb \
  -b 2048 -ub 1024 \
  --jinja \
  --reasoning-format none \
  -mtp --draft 3 -cram 0 \
  > ./llama-server_mtp.log 2>&1 &
```

Model weights: IQ4_XS, KV cache: fp16

| Max tokens | Draft max | No MTP decode tok/s | MTP decode tok/s | Speedup | Acceptance rate |
|---|---|---|---|---|---|
| 32k | 1 | 15.04 | 20.76 | 1.38x | 0.87610 |
| 32k | 2 | 15.04 | 21.13 | 1.40x | 0.76194 |
| 32k | 3 | 15.04 | 20.12 | 1.33x | 0.67572 |

Model weights: Q4_K_S, KV cache: q8_0, MTP KV cache: q4_0

| Max tokens | Draft max | No MTP decode tok/s | MTP decode tok/s | Speedup | Acceptance rate |
|---|---|---|---|---|---|
| 32k | 1 | 19.71 | 24.73 | 1.25x | 0.87021 |
| 32k | 2 | 19.71 | 23.69 | 1.20x | 0.75541 |
| 32k | 3 | 19.71 | 23.07 | 1.17x | 0.70029 |

Related Work

forforever73 requested review from a team, CISC and ggerganov as code owners on March 25, 2026 08:41
forforever73 marked this pull request as draft on March 25, 2026 08:41
@pwilkin (Member) commented on Mar 25, 2026:

Please adhere to the Contribution Guidelines when submitting PRs. Submissions without the relevant disclaimers will be auto-rejected.

pwilkin closed this on Mar 25, 2026