
Add MTP decoding support for GLM-4.x MoE#1270

Merged
ikawrakow merged 20 commits into ikawrakow:main from SamuelOliveirads:feat-glm-mtp
Feb 22, 2026

Conversation

@SamuelOliveirads
Contributor


This is a follow-up to discussion #1228 that allows using MTP for the GLM 4.5/4.6/4.7 family. This PR focuses primarily on the implementation logic, as performance currently has some regressions due to the necessity of running an additional small layer as a secondary model.

Currently it supports four arguments:

  • -mtp or --multi-token-prediction
  • -no-mtp or --no-multi-token-prediction
  • --draft-max and --draft-p-min (same as used for speculative decoding)
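For reference, a hypothetical server invocation using these flags (the model path and binary location are placeholders, not from this PR; only `-mtp`, `--draft-max`, and `--draft-p-min` are the flags it adds):

```shell
# Placeholder model path; -mtp enables MTP drafting, --draft-max caps
# drafted tokens per pass, --draft-p-min sets the probability cutoff.
./build/bin/llama-server -m GLM-4.5-Air-IQ4_XS.gguf \
    -mtp --draft-max 2 --draft-p-min 0.85
```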

It took some time to understand the differences between upstream and ik_llama; most of the effort went into adapting the KV cache logic for MTP. The most notable difference is the acceptance rate, which ranges from 30-50% here, compared to 50-70% upstream. I don't know exactly what is causing the degradation, so I need to investigate further.

As small examples, this is what I get using hybrid spec with layer mode in creative tasks, where MTP has its worst-case scenario:

1) GLM 4.5 Air IQ4_XS

a) Without mtp

prompt eval time =   38433.44 ms / 15144 tokens (    2.54 ms per token,   394.03 tokens per second)
       eval time =  100920.95 ms /  1122 tokens (   89.95 ms per token,    11.12 tokens per second)
      total time =  139354.39 ms / 16266 tokens

b) -mtp --draft-max 1 --draft-p-min 0.85

prompt eval time =   36291.64 ms / 15144 tokens (    2.40 ms per token,   417.29 tokens per second)
       eval time =  197110.12 ms /   929 tokens (  212.17 ms per token,     4.71 tokens per second)
      total time =  233401.76 ms / 16073 tokens
draft acceptance rate = 0.30704 (  218 accepted /   710 generated)

c) -mtp --draft-max 2 --draft-p-min 0.85

prompt eval time =   48187.73 ms / 15145 tokens (    3.18 ms per token,   314.29 tokens per second)
       eval time =  265758.77 ms /  1157 tokens (  229.70 ms per token,     4.35 tokens per second)
      total time =  313946.50 ms / 16302 tokens
draft acceptance rate = 0.26601 (  299 accepted /  1124 generated)

2) GLM 4.7 IQ2_KS

a) Without mtp

prompt eval time =  139098.70 ms / 15144 tokens (    9.19 ms per token,   108.87 tokens per second)
       eval time =   97702.43 ms /   453 tokens (  215.68 ms per token,     4.64 tokens per second)
      total time =  236801.14 ms / 15597 tokens

b) -mtp --draft-max 1 --draft-p-min 0.85

prompt eval time =  182407.79 ms / 15144 tokens (   12.04 ms per token,    83.02 tokens per second)
       eval time =  193361.09 ms /   511 tokens (  378.40 ms per token,     2.64 tokens per second)
      total time =  375768.88 ms / 15655 tokens
draft acceptance rate = 0.19438 (   83 accepted /   427 generated)

c) -mtp --draft-max 2 --draft-p-min 0.85

prompt eval time =  174532.26 ms / 15144 tokens (   11.52 ms per token,    86.77 tokens per second)
       eval time =  167451.67 ms /   444 tokens (  377.14 ms per token,     2.65 tokens per second)
      total time =  341983.92 ms / 15588 tokens
draft acceptance rate = 0.14881 (   75 accepted /   504 generated)

d) -mtp --draft-max 3 --draft-p-min 0.85

prompt eval time =  182459.62 ms / 15144 tokens (   12.05 ms per token,    83.00 tokens per second)
       eval time =  183348.40 ms /   442 tokens (  414.82 ms per token,     2.41 tokens per second)
      total time =  365808.02 ms / 15586 tokens
draft acceptance rate = 0.09569 (   60 accepted /   627 generated)

A few observations. First, in 4.5 Air I noticed that blk.46.nextn.embed_tokens.weight is constantly being copied between backends. I don't know exactly how to fix this; I suspected cb() could handle it, but that appears incorrect.

Also, is system_prompt_update in server-context.cpp still used? I couldn't trigger it, so I'm unsure if it needs a small update for MTP to run there as well.

@saood06
Collaborator

saood06 commented Feb 14, 2026

Also, is system_prompt_update in server-context.cpp still used? I couldn't trigger it, so I'm unsure if it needs a small update for MTP to run there as well.

Even if it is, it has always been basically useless, see #199.

@ikawrakow
Owner

Sorry for the delay reviewing this. I have been focusing on Qwen3-Next optimizations in the last few days.

Is there a GLM-4.5-Air model I can download for testing that includes MTP tensors that haven't been crippled? (Normally they are either excluded or quantized at very low bpw to reduce model size.)

@SamuelOliveirads
Contributor Author

Is there a GLM-4.5-Air model I can download for testing that includes MTP tensors that haven't been crippled? (Normally they are either excluded or quantized at very low bpw to reduce model size.)

Unsloth and Ubergarm preserve the MTP layer, but it is quantized at the same level as the base model. I typically use IQ4_XS from Unsloth so I can test in both projects (meaning the MTP layer is also at the Q4 level), but I believe using GLM-4.5-Air-Q8_0 would be better for testing.

@ikawrakow
Owner

The PR LGTM. Can you resolve the conflicts? Thanks.

I played a bit with it. It looks like it is working, but performance goes down.
Acceptance rate seems quite low: 25-30% for single token, just 16% for 4 drafted tokens. Is this expected?

@SamuelOliveirads
Contributor Author

Sorry for the delay. The self-speculative logic changed quite a bit, so I took the opportunity to refactor MTP to fit the current common_speculative structure properly.

I played a bit with it. It looks like it is working, but performance goes down. Acceptance rate seems quite low: 25-30% for single token, just 16% for 4 drafted tokens. Is this expected?

Definitely not. Upstream typically sees a 60-70% acceptance rate for a single token. I am also seeing the lower rate here.

I have a strong suspicion about the cause (likely related to embeddings input), but debugging it will take some time. It might require adjustments to how embeddings are extracted during llama_decode.

I doubt anyone would want to use MTP in its current state given the performance hit, but the architectural changes in this PR are solid. Feel free to merge if you want the structural support in place while I investigate the acceptance rate issue in a follow-up PR.

@ikawrakow
Owner

I see this warning when building:

/home/iwan/other/tmp/so_ik_llama.cpp/examples/mtmd/clip.cpp:38:62: warning: bitwise operation between different enumeration types ‘ggml_scale_mode’ and ‘ggml_scale_flag’ is deprecated [-Wdeprecated-enum-enum-conversion]
   38 | #define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)
      |                                     ~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Well, yes. One shouldn't bitwise-or two unrelated things.

@SamuelOliveirads
Contributor Author

Funnily enough, this came from the Kimi port; the logic there is the same, so I'm curious whether the warning is simply being ignored or suppressed upstream.

Let's not complicate this: I pushed a commit to fix the warning, and also to fix another problem with the mtp param in common that would crash the --help arg.

Owner

@ikawrakow left a comment


@firecoperana Do you want to look at it?

@ikawrakow ikawrakow merged commit 09a88c9 into ikawrakow:main Feb 22, 2026
Nexesenex added a commit to Nexesenex/ik_llama.cpp.nxs that referenced this pull request Mar 7, 2026
