[Bug] Cutlass_MLA backend can't run with tp8 #6096
Description
Checklist
- 1. I have searched related issues but could not get the expected help.
- 2. The bug has not been fixed in the latest version.
- 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- 4. If the issue you raised is not a bug but a question, please open a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
- 5. Please use English, otherwise it will be closed.
Describe the bug
The cutlass_mla attention backend only runs when dp_size equals tp_size.
If DeepSeek-V3 is launched with --tp 8 and DP attention is not enabled, the following error occurs:
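The value 16 in the assertion error is consistent with attention-head sharding. Assuming DeepSeek-V3 has 128 attention heads, --tp 8 alone splits them across 8 GPUs, leaving 128 / 8 = 16 heads per GPU, while cutlass_mla_decode asserts H == 128. With --enable-dp-attention --dp 8, attention is assumed to run with a TP size of tp_size / dp_size = 1, so each GPU keeps all 128 heads. The helper below (heads_per_rank) is hypothetical and not part of SGLang; it only reproduces that arithmetic as a sketch.

```python
# Hypothetical helper illustrating per-GPU head counts; not SGLang code.
# Assumption: DeepSeek-V3 has 128 attention heads (num_attention_heads in its config).
NUM_HEADS = 128
# Head count required by the assert in sgl_kernel's cutlass_mla_decode (from the traceback).
KERNEL_REQUIRED_HEADS = 128


def heads_per_rank(num_heads: int, tp_size: int, dp_size: int = 1,
                   enable_dp_attention: bool = False) -> int:
    """Return how many query heads each GPU processes in attention."""
    # Assumption: with DP attention, attention runs with TP size tp_size // dp_size;
    # without it, heads are sharded across all tp_size ranks.
    attn_tp_size = tp_size // dp_size if enable_dp_attention else tp_size
    if num_heads % attn_tp_size != 0:
        raise ValueError("num_heads must be divisible by the attention TP size")
    return num_heads // attn_tp_size


# --tp 8 only: each GPU sees 16 heads, so the H == 128 assert fails.
tp_only = heads_per_rank(NUM_HEADS, tp_size=8)
# --enable-dp-attention --tp 8 --dp 8: each GPU keeps all 128 heads.
tp_dp = heads_per_rank(NUM_HEADS, tp_size=8, dp_size=8, enable_dp_attention=True)

print(tp_only, tp_only == KERNEL_REQUIRED_HEADS)  # 16 False
print(tp_dp, tp_dp == KERNEL_REQUIRED_HEADS)      # 128 True
```

Under these assumptions, any configuration where the attention TP size is greater than 1 hands the kernel fewer than 128 heads, which matches the observation that only dp_size == tp_size works.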
File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/base_attn_backend.py", line 69, in forward
return self.forward_decode(
^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/cutlass_mla_backend.py", line 270, in forward_decode
o = cutlass_mla_decode(
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/sgl_kernel/attention.py", line 76, in cutlass_mla_decode
assert H == 128, f"H must be 128, but got {H}"
^^^^^^^^
AssertionError: H must be 128, but got 16
Reproduction
When only TP is used, the bug occurs:
python3 -m sglang.bench_one_batch --model-path /dev/shm/DeepSeek-V3 --tp 8 --batch 16 --attention-backend cutlass_mla --page-size 128 --input-len 1024 --output-len 128
When TP and DP attention are used together, the bug disappears:
python3 -m sglang.bench_one_batch --model-path /dev/shm/DeepSeek-V3 --enable-dp-attention --tp 8 --dp 8 --batch 16 --attention-backend cutlass_mla --page-size 128 --input-len 1024 --output-len 128
Environment
8x NVIDIA B200, CUDA 12.8
aiohappyeyeballs 2.6.1
aiohttp 3.11.18
aiosignal 1.3.2
airportsdata 20250224
anaconda-anon-usage 0.7.0
anaconda-cli-base 0.5.2
anaconda-client 1.13.0
annotated-types 0.6.0
anyio 4.9.0
asttokens 3.0.0
attrs 24.3.0
beautifulsoup4 4.12.3
black 25.1.0
blobfile 3.0.0
boltons 23.0.0
brotlipy 0.7.0
certifi 2025.4.26
cffi 1.15.1
cfgv 3.4.0
chardet 4.0.0
charset-normalizer 2.0.4
click 8.1.8
cloudpickle 3.1.1
cmake 3.18.4
colorama 0.4.6
compressed-tensors 0.9.4
conda 23.5.2
conda-build 24.3.0
conda-content-trust 0.1.3
conda_index 0.6.0
conda-libmamba-solver 23.5.0
conda-package-handling 2.1.0
conda_package_streaming 0.8.0
cryptography 39.0.1
cuda-bindings 12.8.0
cuda-python 12.8.0
datasets 3.5.1
decorator 5.2.1
decord 0.6.0
defusedxml 0.7.1
dill 0.3.8
diskcache 5.6.3
distlib 0.3.9
einops 0.8.1
executing 2.2.0
fastapi 0.115.12
fastjsonschema 2.20.0
filelock 3.17.0
flashinfer-python 0.2.5
frozenlist 1.6.0
fsspec 2024.10.0
h11 0.16.0
hf_transfer 0.1.9
huggingface-hub 0.30.2
icdiff 2.0.7
identify 2.6.10
idna 3.4
iniconfig 2.1.0
interegular 0.3.3
ipython 9.2.0
ipython_pygments_lexers 1.1.1
isort 6.0.1
jedi 0.19.2
Jinja2 3.1.6
jsonpatch 1.32
jsonpointer 2.1
jsonschema 4.23.0
jsonschema-specifications 2023.7.1
jupyter_core 5.7.2
lark 1.2.2
libarchive-c 5.1
libmambapy 1.4.1
llguidance 0.7.19
lxml 5.4.0
markdown-it-py 2.2.0
MarkupSafe 3.0.2
matplotlib-inline 0.1.7
mdurl 0.1.0
menuinst 2.2.0
modelscope 1.25.0
mpmath 1.3.0
msgpack 1.0.3
multidict 6.4.3
multiprocess 0.70.16
mypy_extensions 1.1.0
nanobind 2.7.0
nbformat 5.10.4
nest-asyncio 1.6.0
networkx 3.4.2
ninja 1.11.1.4
nodeenv 1.9.1
numpy 2.1.2
nvidia-cublas-cu12 12.8.3.14
nvidia-cuda-cupti-cu12 12.8.57
nvidia-cuda-nvrtc-cu12 12.8.61
nvidia-cuda-runtime-cu12 12.8.57
nvidia-cudnn-cu12 9.8.0.87
nvidia-cufft-cu12 11.3.3.41
nvidia-cufile-cu12 1.13.0.11
nvidia-curand-cu12 10.3.9.55
nvidia-cusolver-cu12 11.7.2.55
nvidia-cusparse-cu12 12.5.7.53
nvidia-cusparselt-cu12 0.6.3
nvidia-ml-py 12.570.86
nvidia-nccl-cu12 2.26.2.post1
nvidia-nvjitlink-cu12 12.8.61
nvidia-nvtx-cu12 12.8.55
orjson 3.10.18
outlines 0.1.11
outlines_core 0.1.26
packaging 25.0
pandas 2.2.3
parso 0.8.4
partial-json-parser 0.2.1.1.post5
pathspec 0.12.1
pexpect 4.9.0
pillow 11.0.0
pip 23.1.2
pkginfo 1.12.0
platformdirs 4.3.7
pluggy 1.5.0
pre_commit 4.2.0
prometheus_client 0.21.1
prompt_toolkit 3.0.51
propcache 0.3.1
psutil 5.9.0
ptyprocess 0.7.0
pure_eval 0.2.3
pyarrow 20.0.0
pycosat 0.6.4
pycountry 24.6.1
pycparser 2.21
pycryptodomex 3.22.0
pydantic 2.10.3
pydantic_core 2.27.1
pydantic-settings 2.6.1
Pygments 2.19.1
pynvml 12.0.0
pyOpenSSL 23.0.0
PySocks 1.7.1
pytest 8.3.5
python-dateutil 2.9.0.post0
python-dotenv 1.1.0
python-multipart 0.0.20
pytorch-triton 3.3.0+git96316ce5
pytz 2024.1
PyYAML 6.0.2
pyzmq 26.4.0
readchar 4.0.5
referencing 0.30.2
regex 2024.11.6
requests 2.32.3
requests-toolbelt 1.0.0
rich 13.9.4
rpds-py 0.22.3
ruamel.yaml 0.17.21
safetensors 0.5.3
scikit_build_core 0.11.2
sentencepiece 0.2.0
setproctitle 1.3.6
setuptools 75.0.0
sgl-kernel 0.1.1
sglang 0.4.6.post2 /sgl-workspace/sglang/python
shellingham 1.5.0
six 1.16.0
sniffio 1.3.1
soundfile 0.13.1
soupsieve 2.5
stack-data 0.6.3
starlette 0.46.2
sympy 1.13.3
tiktoken 0.9.0
tokenizers 0.21.1
tomli 2.0.1
toolz 0.12.0
torch 2.8.0.dev20250501+cu128
torchao 0.10.0
torchaudio 2.6.0.dev20250501+cu128
torchvision 0.22.0.dev20250501+cu128
tqdm 4.67.1
traitlets 5.14.3
transformers 4.51.1
typer 0.9.0
typing_extensions 4.12.2
tzdata 2025.2
urllib3 1.26.16
uv 0.7.2
uvicorn 0.34.2
uvloop 0.21.0
virtualenv 20.30.0
wcwidth 0.2.13
wheel 0.41.0
xgrammar 0.1.17
xxhash 3.5.0
yarl 1.20.0
zstandard 0.19.0