
fix: Add glm4_moe_lite to MLA detection#32614

Merged
vllm-bot merged 10 commits into vllm-project:main from marksverdhei:fix/glm4-moe-mla-detection
Jan 23, 2026

Conversation

@marksverdhei
Contributor

@marksverdhei marksverdhei commented Jan 19, 2026

Summary

  • Add glm4_moe_lite and glm4_moe_lite_mtp to is_deepseek_mla() check in model_arch_config_convertor.py

GLM-4.7-Flash (glm4_moe_lite) uses Multi-head Latent Attention (MLA) via Glm4MoeLiteMLAAttention (which inherits from DeepseekV2MLAAttention) but was missing from the MLA detection.

Without this fix, vLLM falls back to standard KV caching instead of efficient MLA caching, resulting in ~4x higher KV cache memory usage.

Note: glm4_moe is intentionally NOT included as it uses standard attention (Glm4MoeAttention with vllm.attention.layer.Attention).
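The detection change described above can be sketched roughly as follows. This is a hedged illustration, not vLLM's actual code: the real helper lives in model_arch_config_convertor.py, and its exact signature and the set of recognized model types may differ.

```python
# Illustrative sketch of the is_deepseek_mla() change; names and the
# exact set of model types are assumptions, not vLLM's real source.
MLA_MODEL_TYPES = {
    "deepseek_v2",
    "deepseek_v3",
    "glm4_moe_lite",      # added by this PR: uses Glm4MoeLiteMLAAttention
    "glm4_moe_lite_mtp",  # added by this PR: MTP variant of the above
    # "glm4_moe" is deliberately absent: it uses standard attention
    # (Glm4MoeAttention with vllm.attention.layer.Attention).
}

def is_deepseek_mla(model_type: str, kv_lora_rank) -> bool:
    """Return True if the model should use MLA-style KV caching."""
    # MLA requires both a recognized model type and a kv_lora_rank in
    # the config (the latent dimension of the compressed KV cache).
    return model_type in MLA_MODEL_TYPES and kv_lora_rank is not None
```

Without the two added entries, a `glm4_moe_lite` config falls through this check and vLLM allocates a full per-head KV cache instead of the compressed latent cache.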

Co-authored with @mgoin

NOTE: SM100 has issues supporting this model with various MLA decode and prefill kernels, so the following changes were made to enable default inference there:

  • Disable TRT-LLM prefill and FlashInfer prefill if DeepSeek-R1-compatible MLA dimensions are not found
  • Disable FlashInfer MLA if DeepSeek-R1-compatible MLA dimensions are not found
  • Explicitly enable CUTLASS MLA so that block_size=128 is enforced
  • As a result, SM100 runs with CUTLASS_MLA decode and FA2 prefill by default for this model. I also tested that TRITON_MLA works.
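The gating logic above can be sketched as a small dispatch function. This is a hedged illustration: the function name and return shape are invented for this example, not vLLM's API. DeepSeek-R1's MLA dimensions are qk_nope_head_dim=128, qk_rope_head_dim=64, kv_lora_rank=512; GLM-4.7-Flash has qk_nope_head_dim=192, which the TRT-LLM and FlashInfer MLA kernels reject (the 192-vs-128 error reported later in this thread).

```python
# Illustrative sketch of the SM100 backend gating described above.
# Function and backend names are assumptions for this example only.
def select_sm100_mla_backends(qk_nope_head_dim: int,
                              qk_rope_head_dim: int,
                              kv_lora_rank: int) -> dict:
    # "DeepSeek-R1-compatible" dims that the FlashInfer/TRT-LLM MLA
    # kernels were written for.
    r1_compatible = (qk_nope_head_dim == 128
                     and qk_rope_head_dim == 64
                     and kv_lora_rank == 512)
    if r1_compatible:
        # R1-shaped models can keep the FlashInfer/TRT-LLM MLA kernels.
        return {"decode": "FLASHINFER_MLA", "prefill": "TRTLLM"}
    # Otherwise fall back to CUTLASS MLA decode (which enforces
    # block_size=128) and FA2 prefill.
    return {"decode": "CUTLASS_MLA", "prefill": "FA2", "block_size": 128}
```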

Test Plan

  • Tested with marksverdhei/GLM-4.7-Flash-fp8 on 2x RTX 3090
  • Verified MLA is detected and efficient KV caching is used
  • Model runs with 14.7 GB VRAM per GPU

GLM-4.7-Flash (glm4_moe_lite) and GLM-4.6 (glm4_moe) use the same
Multi-head Latent Attention (MLA) architecture as DeepSeek models
but were not included in the is_deepseek_mla() check.

This caused vLLM to fall back to standard KV caching instead of
efficient MLA caching, resulting in significantly higher memory
usage (4x more KV cache than necessary).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: marksverdhei <marksverdhei@hotmail.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request aims to enable Multi-head Latent Attention (MLA) for GLM-4 MoE models. The change correctly adds glm4_moe_lite to the MLA detection logic, which is consistent with its implementation that supports MLA. However, adding glm4_moe is problematic as its current implementation in vLLM does not seem to support MLA, which could lead to runtime issues. I've provided a critical comment with a suggestion to only include glm4_moe_lite for now.

GLM-4.7-Flash (glm4_moe_lite) uses the same Multi-head Latent Attention
(MLA) architecture as DeepSeek models but was not included in the
is_deepseek_mla() check.

This caused vLLM to fall back to standard KV caching instead of
efficient MLA caching, resulting in significantly higher memory
usage (4x more KV cache than necessary).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: marksverdhei <marksverdhei@hotmail.com>
@marksverdhei marksverdhei force-pushed the fix/glm4-moe-mla-detection branch from 09521b2 to edaf4f7 Compare January 19, 2026 20:04
@marksverdhei marksverdhei changed the title from "fix: Add GLM-4 MoE models to MLA detection" to "fix: Add glm4_moe_lite to MLA detection" Jan 19, 2026
Collaborator

@MatthewBonanni MatthewBonanni left a comment

LGTM, thanks for catching this!

Collaborator

@LucasWilkinson LucasWilkinson left a comment

LGTM

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) January 20, 2026 15:49
@github-actions github-actions bot added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) Jan 20, 2026
@mgoin
Member

mgoin commented Jan 20, 2026

@LucasWilkinson @MatthewBonanni When I run with this PR on B200, I get this error. I think we need to change the kernel support registration.

(EngineCore_DP0 pid=1261690)   File "/home/mgoin/code/vllm/.venv/lib/python3.12/site-packages/flashinfer/decode.py", line 2491, in _check_trtllm_gen_mla_shape
(EngineCore_DP0 pid=1261690)     raise ValueError(f"Expected qk_nope_head_dim == 128, got {qk_nope_head_dim}")
(EngineCore_DP0 pid=1261690) ValueError: Expected qk_nope_head_dim == 128, got 192

Member

@mgoin mgoin left a comment

Blocking merge while we investigate which MLA backends are/can be supported for this model. For instance forcing TRITON_MLA on B200 results in 0% on GSM8k

Signed-off-by: Markus / Mark <46672778+marksverdhei@users.noreply.github.com>
auto-merge was automatically disabled January 21, 2026 15:17

Head branch was pushed to by a user without write access

@marksverdhei
Contributor Author

Blocking merge while we investigate which MLA backends are/can be supported for this model. For instance forcing TRITON_MLA on B200 results in 0% on GSM8k

Could it be because of the known looping issues that llama.cpp also had? I.e., 0% because the model never arrives at an answer and only generates think tokens, or did you already control for that? I have never contributed to vLLM before, so I'm not completely familiar with the testing procedures.

@mdierolf

LGTM!

Works well on RTX 6000 Blackwell, model is unusable without this.

Nobody is going to run this small model on B200 anyway, merge it and fix B200 mañana!

Signed-off-by: mgoin <mgoin64@gmail.com>
@mgoin mgoin mentioned this pull request Jan 23, 2026
Member

@mgoin mgoin left a comment

Should be good to go now with the Blackwell fixes, thanks for kicking this off @marksverdhei !

@github-project-automation github-project-automation bot moved this from In review to Ready in NVIDIA Jan 23, 2026
@vllm-bot vllm-bot merged commit 586a57a into vllm-project:main Jan 23, 2026
55 of 57 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Jan 23, 2026
@marksverdhei
Contributor Author

marksverdhei commented Jan 24, 2026

Thank you for merging! I can now call myself a proud contributor to vLLM, with those two lines of strings! 😆 Way to kick off my 2026 New Year's resolutions!

cwazai pushed a commit to cwazai/vllm that referenced this pull request Jan 25, 2026
Signed-off-by: marksverdhei <marksverdhei@hotmail.com>
Signed-off-by: Markus / Mark <46672778+marksverdhei@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: 陈建华 <1647430658@qq.com>
@Xiaojinhua

Is there a new version released with this fix, or can it only be installed from the latest source code for now?

lapy pushed a commit to lapy/vllm that referenced this pull request Jan 27, 2026
Signed-off-by: marksverdhei <marksverdhei@hotmail.com>
Signed-off-by: Markus / Mark <46672778+marksverdhei@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
@esmeetu
Member

esmeetu commented Jan 28, 2026

@Xiaojinhua You can install the nightly version:

uv pip install -U vllm --pre \
  --extra-index-url https://wheels.vllm.ai/nightly/cu129 \
  --extra-index-url https://download.pytorch.org/whl/cu129 \
  --index-strategy unsafe-best-match

@gaby

gaby commented Feb 18, 2026

@esmeetu @marksverdhei even with these changes merged weeks ago, the glm-4.7-flash model does not start because transformers is pinned to <5.0.0.

TL;DR: using the nightly Docker image.

@mgoin
Member

mgoin commented Feb 18, 2026

@gaby yes we still recommend in the recipe for the model to install transformers from source https://docs.vllm.ai/projects/recipes/en/latest/GLM/GLM.html

@gaby

gaby commented Feb 18, 2026

@mgoin Sadly that doesn't work when using vLLM with Docker, since we rely on the image provided by the vLLM team.

@mgoin
Member

mgoin commented Feb 18, 2026

@gaby you can extend the existing image as shown in this example:

https://x.com/thezachmueller/status/2014354173492432942?s=46&t=jLcDgQXDbYe6HgFmTNYgpg

@gaby

gaby commented Feb 18, 2026

@mgoin Will give that a try, thanks!

@esmeetu
Member

esmeetu commented Feb 18, 2026

@gaby You can try this image: vllm/vllm-openai:glm5

@gaby

gaby commented Feb 18, 2026

@esmeetu thanks!

ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Signed-off-by: marksverdhei <marksverdhei@hotmail.com>
Signed-off-by: Markus / Mark <46672778+marksverdhei@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>

Labels

nvidia · ready (ONLY add when PR is ready to merge/full CI is needed) · v1

Projects

Status: Done


10 participants