
feat: mtp support dp-attention#6081

Merged
zhyncs merged 66 commits into sgl-project:main from u4lr451:feature_mtp_support_dp_attention
Jun 17, 2025

Conversation

@u4lr451
Contributor

@u4lr451 u4lr451 commented May 7, 2025

Motivation

mtp support dp-attention

  • Implemented MTP for DP attention and fixed the related bugs [Bug] DP + MTP init failed with deepseek r1 #4783 and [Bug] DP attention with Eagle worker raises AttributeError #4847.
  • Enabled CUDA graph support for both the target and draft models under DP attention.
  • Performance optimizations: refined gathered_buffer memory allocation for MTP with DP attention, eliminating redundant GPU allocation (previously scaled by the --speculative-num-draft-tokens multiplier) and preventing unnecessary memory usage during inference in both the target and draft models. This benefits DP concurrency by reducing memory contention and decreases all_reduce communication overhead.
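The gathered_buffer saving described above can be illustrated with back-of-the-envelope arithmetic. This is a hypothetical sketch with invented names (`gathered_buffer_tokens`, `scaled_by_draft`), not SGLang's actual allocation code; it only shows how dropping the draft-token multiplier shrinks the per-gather buffer:

```python
# Hypothetical sketch: worst-case token slots the DP all-gather buffer must
# hold, with and without scaling by the number of draft tokens.

def gathered_buffer_tokens(dp_size: int, max_running_requests: int,
                           num_draft_tokens: int, scaled_by_draft: bool) -> int:
    per_rank = max_running_requests
    if scaled_by_draft:
        # Old behavior: every rank also reserves slots for all draft tokens.
        per_rank *= num_draft_tokens
    return dp_size * per_rank

# Settings roughly matching the benchmark commands in this PR:
# --dp-size 8, --max-running-requests 24, --speculative-num-draft-tokens 4.
old = gathered_buffer_tokens(8, 24, 4, scaled_by_draft=True)
new = gathered_buffer_tokens(8, 24, 4, scaled_by_draft=False)
```

With these illustrative settings the refined allocation is 4x smaller (192 vs. 768 token slots) per gather.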

Checklist

Accuracy

python3 bench_sglang.py  --data_dir data --nsub 20
  • Baseline
#two node 
export SGL_ENABLE_JIT_DEEPGEMM=0
python3 -m sglang.launch_server --model-path /sgl-workspace//DeepSeek-V3-0324 --dist-init-addr ${HOST_IP}:20000 --nnodes 2 --node-rank ${RANK}  --trust-remote-code --served-model-name DeepSeek-V3-0324 --context-length 65536 --tensor-parallel-size 16 --stream-output --host 0.0.0.0 --port 30000 --watchdog-timeout 240 --disable-radix-cache --schedule-policy fcfs --chunked-prefill-size 32768 --max-running-requests 24 --disable-overlap-schedule --attention-backend flashinfer --enable-metrics --log-requests

Average accuracy: 0.887

  • mtp with dp-attention
#two node 
export SGL_ENABLE_JIT_DEEPGEMM=0
python3 -m sglang.launch_server --model-path /sgl-workspace//DeepSeek-V3-0324 --dist-init-addr ${HOST_IP}:20000 --nnodes 2 --node-rank ${RANK} --trust-remote-code --served-model-name DeepSeek-V3-0324 --context-length 65536 --tensor-parallel-size 16 --stream-output --host 0.0.0.0 --port 30000 --watchdog-timeout 240 --disable-radix-cache --schedule-policy fcfs --chunked-prefill-size 32768 --max-running-requests 24 --disable-overlap-schedule --attention-backend flashinfer --disable-cuda-graph-padding --mem-fraction-static 0.60 --speculative-algo NEXTN --speculative-draft /sgl-workspace/SGLang/DeepSeek-V3-0324-NextN --speculative-num-steps 4 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --enable-metrics --log-requests --disable-cuda-graph --enable-nan-detection --enable-dp-attention --dp-size 8

Average accuracy: 0.887

@u4lr451 u4lr451 force-pushed the feature_mtp_support_dp_attention branch from 7b9e06d to 9441246 on May 7, 2025 at 17:50
@u4lr451 u4lr451 changed the title from feat: mtp support dp-attention (#6080) to feat: mtp support dp-attention on May 7, 2025
@ch-wan ch-wan linked an issue May 7, 2025 that may be closed by this pull request
2 tasks
@lambert0312
Contributor

lambert0312 commented May 8, 2025

After testing, the error is as follows:

Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 314, in __init__
    self.capture()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 405, in capture
    ) = self.capture_one_batch_size(bs, forward)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 523, in capture_one_batch_size
    torch.cuda.synchronize()
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 985, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2266, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 272, in __init__
    self.tp_worker = TpWorkerClass(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 85, in __init__
    self.model_runner = ModelRunner(
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 190, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 239, in initialize
    self.init_cuda_graphs()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1025, in init_cuda_graphs
    self.cuda_graph_runner = CudaGraphRunner(self)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 316, in __init__
    raise Exception(
Exception: Capture cuda graph failed: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Possible solutions:
1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)
2. set --cuda-graph-max-bs to a smaller value (e.g., 16)
3. disable torch compile by not using --enable-torch-compile
4. disable cuda graph by --disable-cuda-graph. (Not recommonded. Huge perf loss)
Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose

When --disable-cuda-graph is set, the server starts normally.

@u4lr451
Contributor Author

u4lr451 commented May 8, 2025

> After testing, the error is as follows:
> Scheduler hit an exception: Traceback (most recent call last): …
> RuntimeError: CUDA error: an illegal memory access was encountered

@lambert0312 The latest commit (5256543) has fixed this bug. Thanks!

@lambert0312
Contributor

lambert0312 commented May 8, 2025

> @lambert0312 The latest commit (5256543) has fixed this bug. Thanks!

@u4lr451 Great, it has been verified to work properly, but the speed is much slower than when dp-attention is not enabled. Why is this?

@u4lr451
Contributor Author

u4lr451 commented May 9, 2025

> @lambert0312 The latest commit (5256543) has fixed this bug. Thanks!
>
> @u4lr451 Great, it has been verified to work properly, but the speed is much slower than when dp-attention is not enabled. Why is this?

@lambert0312 The choice between pure TP and DP attention depends on multiple factors, such as GPU model, request batch size/concurrency, DP parameters, and business SLA requirements.
Additionally:

  • MTP itself adds memory and compute overhead. The speedup of DP attention + MTP correlates with several factors: a) the MTP acceptance rate, b) the workload and batch size/concurrency, and c) request balancing across DP workers.
  • When using DP attention + MTP, the optimal capture_bs for CUDA graphs may differ.
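To give intuition for factor (a), here is a crude, illustrative speedup model (the function name, the cost accounting, and the draft_cost_ratio value are all assumptions, not measured SGLang numbers): each verify cycle accepts accept_len tokens on average but costs roughly one target forward plus num_steps cheaper draft forwards.

```python
# Illustrative model only: tokens per unit cost, relative to plain decoding
# (which produces 1 token per target-model forward).

def mtp_speedup(accept_len: float, num_steps: int,
                draft_cost_ratio: float = 0.1) -> float:
    cost_per_cycle = 1.0 + num_steps * draft_cost_ratio  # target + drafts
    return accept_len / cost_per_cycle

healthy = mtp_speedup(accept_len=3.6, num_steps=4)   # good acceptance
degraded = mtp_speedup(accept_len=1.5, num_steps=4)  # poor acceptance
```

Under this toy model, a drop in acceptance length from ~3.6 to ~1.5 erases most of the speculative speedup, matching the intuition that MTP gains hinge on the acceptance rate.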

@u4lr451 u4lr451 force-pushed the feature_mtp_support_dp_attention branch 3 times, most recently from adcb787 to b61e3a6 on May 12, 2025 at 15:55
@u4lr451
Contributor Author

u4lr451 commented May 12, 2025

@ch-wan @fzyzcjy @merrymercy @zhyncs hi, would someone mind checking if this is ready to merge? Thanks!

@zhangxiaolei123456
Contributor

zhangxiaolei123456 commented May 13, 2025

With DP attention, MTP, and CUDA graph enabled, we found that performance dropped sharply; analysis showed that the acceptance rate dropped a lot, which lowered throughput.
With CUDA graph disabled:
accept len: 3.69, gen throughput (token/s): 137.83

[2025-05-13 02:32:12 DP5 TP5] Decode batch. #running-req: 8, #token: 35785, token usage: 0.27, accept len: 3.64, gen throughput (token/s): 135.85, #queue-req: 0
[2025-05-13 02:32:12 DP6 TP6] Decode batch. #running-req: 8, #token: 35896, token usage: 0.27, accept len: 3.62, gen throughput (token/s): 135.27, #queue-req: 0
[2025-05-13 02:32:12 DP7 TP7] Decode batch. #running-req: 8, #token: 36092, token usage: 0.27, accept len: 3.69, gen throughput (token/s): 137.83, #queue-req: 0

With CUDA graph enabled:
accept len: 2.12, gen throughput (token/s): 83.78

[2025-05-13 01:57:33 DP1 TP1] Decode batch. #running-req: 8, #token: 29915, token usage: 0.22, accept len: 1.55, gen throughput (token/s): 61.31, #queue-req: 0
[2025-05-13 01:57:37 DP0 TP0] Decode batch. #running-req: 8, #token: 30629, token usage: 0.23, accept len: 2.12, gen throughput (token/s): 83.78, #queue-req: 0
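The numbers in these logs are consistent with the slowdown being driven almost entirely by the acceptance drop: dividing reported throughput by accept length gives a roughly constant per-verify-cycle rate across both runs. A quick consistency check on the figures above:

```python
# Sanity check on the log lines above: tok/s divided by accept len should be
# roughly constant if only the acceptance rate changed between the two runs.

runs = {
    "cuda_graph_off": (135.85, 3.64),  # DP5 line: tok/s, accept len
    "cuda_graph_on": (61.31, 1.55),    # DP1 line: tok/s, accept len
}
per_cycle = {name: tps / al for name, (tps, al) in runs.items()}
rel_gap = abs(per_cycle["cuda_graph_off"] - per_cycle["cuda_graph_on"]) \
    / per_cycle["cuda_graph_off"]
```

The two per-cycle rates agree to within about 6%, which points the regression at whatever degrades the draft/verify acceptance under CUDA graph capture rather than at raw kernel speed.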

@u4lr451

@miter6
Contributor

miter6 commented Jun 16, 2025

Bug:
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [208,0,0], thread: [64,0,0] Assertion srcIndex < srcSelectDimSize failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [208,0,0], thread: [65,0,0] Assertion srcIndex < srcSelectDimSize failed.
… (the same assertion repeats for threads [66,0,0] through [72,0,0])

Collaborator

@ch-wan ch-wan left a comment


Thank you for this excellent contribution. It represents a major optimization for boosting the throughput of DeepSeek-V3/R1, with its correctness and effectiveness verified by many contributors and users from the community. The current implementation looks solid to me.

For future PRs, consider these remaining optimizations:

  • Enabling CUDA graphs for idle batches during verify or draft_after_decode. This was previously implemented but reverted by me to unblock merging this PR.
  • Migrating DP attention support to #6995. The current setup requires capturing 3 CUDA graphs and creating 3 gathered_buffers, which consumes unnecessary memory.
  • Reducing scheduling overhead. The current approach may invoke all_gather_into_tensor twice to check for idle batches, potentially lowering end-to-end throughput in some scenarios.
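On the last point, one way to avoid the second collective is to pack the idle flag into the payload of the gather the ranks already perform. This is a toy single-process sketch with invented names, using plain Python lists in place of torch.distributed, purely to illustrate the packing idea (it is not SGLang's scheduler code):

```python
# Toy sketch (invented names; no real collectives): fold the per-rank idle
# flag into the same gathered payload as the local batch size, so a single
# all_gather answers both "is any rank idle?" and "what is the max batch?".

def pack(local_batch_size: int) -> tuple[int, int]:
    return (1 if local_batch_size == 0 else 0, local_batch_size)

def decide(gathered: list[tuple[int, int]]) -> tuple[bool, int]:
    any_idle = any(flag == 1 for flag, _ in gathered)
    max_bs = max(bs for _, bs in gathered)
    return any_idle, max_bs

# Simulated payloads from 4 DP ranks, one of them idle.
gathered = [pack(bs) for bs in (8, 0, 5, 8)]
```

Here `decide(gathered)` reports both facts from one pass over a single gathered tensor-equivalent, instead of issuing two all_gather_into_tensor calls.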

@Xuweijia-buaa

Xuweijia-buaa commented Jun 17, 2025

When I use the following args for the DeepSeek-R1 model, without a separate draft model, as in
https://docs.sglang.ai/references/deepseek.html#multi-token-prediction

--speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2

it raises this error:
AttributeError: 'DeepseekModelNextN' object has no attribute 'layers'

Do you know why and how to fix it?

The complete logs are:
  File "sglang/python/sglang/srt/managers/scheduler.py", line 2576, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
  File "sglang/python/sglang/srt/managers/scheduler.py", line 325, in __init__
    self.draft_worker = EAGLEWorker(
  File "sglang/python/sglang/srt/speculative/eagle_worker.py", line 124, in __init__
    super().__init__(
  File "sglang/python/sglang/srt/managers/tp_worker.py", line 78, in __init__
    self.model_runner = ModelRunner(
  File "sglang/python/sglang/srt/model_executor/model_runner.py", line 215, in __init__
    self.initialize(min_per_gpu_memory)
  File "sglang/python/sglang/srt/model_executor/model_runner.py", line 256, in initialize
    self.load_model()
  File "sglang/python/sglang/srt/model_executor/model_runner.py", line 550, in load_model
    self.model = get_model(
  File "sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
    return loader.load_model(
  File "sglang/python/sglang/srt/model_loader/loader.py", line 516, in load_model
    model.post_load_weights()
  File "sglang/python/sglang/srt/models/deepseek_v2.py", line 1784, in post_load_weights
    self.model.layers[layer_id].self_attn
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1940, in __getattr__
    raise AttributeError(
AttributeError: 'DeepseekModelNextN' object has no attribute 'layers'

I use --load-format dummy


@ch-wan
Collaborator

ch-wan commented Jul 8, 2025

@Xuweijia-buaa see this: #7506



Development

Successfully merging this pull request may close these issues.

[Feature] mtp support dp-attention