
Conversation

@lixiaolx (Contributor) commented on Oct 24, 2025

Motivation

Currently, with DeepSeek-V3.2-DSA, the prefill TTFT (time to first token) for long input sequences is high. Introducing context parallelism (CP) reduces TTFT.
Main design ideas:
[figure: overall design of prefill context parallelism]

Taking TP=EP=4 and DP=2 as an example (CP_SIZE == ATTEN_TP_SIZE):
  • Each DP rank accepts an independent request.
  • Within each DP rank, after embedding, the hidden state of shape (batch*seq_len, H) is split into context-parallel segments, so each ATTEN_TP rank holds a (batch*seq_len/CP_SIZE, H) slice. The MoE part uses DeepEP, and its input also has shape (batch*seq_len/CP_SIZE, H). This is repeated for layer_nums layers. Finally, an all-gather is performed on the results so that every ATTEN_TP rank receives the complete hidden states (a minimal sketch of the split follows this list).
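The splitting step can be illustrated with a minimal sketch (the sizes and variable names below are illustrative assumptions, not the actual sglang code):

import torch

# Illustrative sizes, assumed for this example only.
cp_size = 4                      # equals ATTEN_TP_SIZE in this design
batch, seq_len, H = 1, 8192, 7168

# Hidden states produced after embedding on one DP rank.
hidden_states = torch.randn(batch * seq_len, H)

# Each attention-TP rank keeps a (batch*seq_len / CP_SIZE, H) slice
# of the sequence dimension before entering the decoder layers.
chunks = torch.chunk(hidden_states, cp_size, dim=0)
local_hidden = chunks[2]         # e.g. the slice rank 2 would process
assert local_hidden.shape == (batch * seq_len // cp_size, H)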

Note:

  1. Throughout the MLA computation, the weights used for the related GEMMs (qkv, o_proj, etc.) are not TP-split on each rank (see the sketch after this note).
  2. With moe_dense_tp_size == 1, the relevant parts of prepare_attn, prepare_mlp, and postprocess_layer only perform computations such as layer norm, so no additional communication is needed.
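To make the first note concrete, here is a small sketch (hypothetical shapes and names, not the actual MLA implementation) of one projection GEMM under CP: the weight stays whole on every rank and only the sequence dimension of the activation is sharded.

import torch

cp_size = 4
seq_len, H, q_dim = 8192, 7168, 1536        # assumed example sizes

# The projection weight is replicated (not TP-sharded) on every attention-TP rank.
q_proj_weight = torch.randn(q_dim, H)

# Each rank only processes its 1/CP_SIZE slice of the sequence.
local_hidden = torch.randn(seq_len // cp_size, H)
local_q = local_hidden @ q_proj_weight.T    # shape: (seq_len/CP_SIZE, q_dim)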

Explanation of main changes:
[figure: overview of the main changes]

  • Attention must follow causal-attention principles. If the CP (context parallel) split simply follows rank order, the computational load becomes unbalanced: the first rank attends to few historical key-value tokens and does little work, while the last rank attends to many and does much more. To mitigate this imbalance, the input hidden state is sliced into cp_size*2 blocks, and in the CP split implementation the blocks are reassembled so that the computation time of every rank within a DP (data parallel) group is as similar as possible.
  • As shown in the figure below (DP_ATTEN_TP == CP_SIZE == 4), rank 0 computes blocks 0 and 7, rank 1 computes blocks 1 and 6, rank 2 computes blocks 2 and 5, and rank 3 computes blocks 3 and 4.
[figure: block assignment of the reassembled CP slices across ranks]
  • In the attention computation, the output after the CP split has shape (batch*seq_len/CP_SIZE, H) and each rank holds only a portion of the kv_cache, but the full kv_cache is needed for attention, so an all-gather is required. Furthermore, because blocks are reassembled during the CP split, a rerange step is introduced to restore the kv_cache to its original order.
  • In the actual attention computation, each rank's q covers 1/CP of the tokens while the kv_cache spans the entire request (see the sketch after this list).
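The balanced block assignment and the gather/rerange step can be sketched as follows. This is a minimal illustration under assumed names and shapes; the actual helpers introduced in this PR (for example calculate_cp_seq_idx and attn_tp_all_gather_reorgan_into_tensor) may differ in their details:

import torch
import torch.distributed as dist

def cp_block_ids(cp_rank: int, cp_size: int) -> tuple:
    # The sequence is cut into 2*cp_size blocks; rank r takes block r and
    # block (2*cp_size - 1 - r), pairing a "cheap" early block with an
    # "expensive" late block so causal-attention work per rank is balanced.
    return cp_rank, 2 * cp_size - 1 - cp_rank

def restore_order_index(cp_size: int, block_len: int) -> torch.Tensor:
    # Builds the index that maps the gathered, rank-ordered blocks back to
    # the original sequence order (the "rerange" step described above).
    gathered_order = []
    for rank in range(cp_size):
        gathered_order.extend(cp_block_ids(rank, cp_size))
    inverse = sorted(range(2 * cp_size), key=lambda i: gathered_order[i])
    idx = torch.arange(2 * cp_size * block_len).view(2 * cp_size, block_len)
    return idx[inverse].reshape(-1)

def gather_full_kv(local_kv: torch.Tensor, cp_size: int, block_len: int,
                   group=None) -> torch.Tensor:
    # All-gather every rank's kv shard, then restore the original token
    # order so attention on the local q slice can see the complete kv.
    shards = [torch.empty_like(local_kv) for _ in range(cp_size)]
    dist.all_gather(shards, local_kv, group=group)
    full = torch.cat(shards, dim=0)
    order = restore_order_index(cp_size, block_len).to(full.device)
    return full.index_select(0, order)

For cp_size == 4 this reproduces the assignment in the figure above: rank 0 computes blocks 0 and 7, rank 1 blocks 1 and 6, rank 2 blocks 2 and 5, and rank 3 blocks 3 and 4.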

Current status:
Currently only single-batch, single-machine prefill is supported.
Feature switch: --enable-nsa-prefill-context-parallel (default: off)
Test commands
Launch:
python3 -m sglang.launch_server --model-path $MODEL_PATH --dp 8 --nnodes 1 --enable-dp-attention --node-rank 0 --trust-remote-code \
  --dist-init-addr 0.0.0.0:6432 --port 8000 --host 0.0.0.0 --attention-backend nsa --nsa-prefill flashmla_sparse --nsa-decode flashmla_sparse \
  --max-total-tokens 128000 --enable-metrics --mem-fraction-static 0.8 --max-running-requests 8 --enable-cache-report --page-size 64 \
  --tp-size 8 --ep-size 8 --skip-server-warmup --disable-overlap-schedule --decode-log-interval 1 --moe-a2a-backend deepep \
  --speculative-algorithm EAGLE --speculative-draft-model-path $DRAFT_MODEL_PATH \
  --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2

curl:
curl http://127.0.0.1:8000/v1/completions -H "Content-Type: application/json" -d '{"model": "ds32-model", "prompt": "Write an ad copy for a new product, a digital photo frame that connects to your social media accounts and displays your photos. Respond with at most 150 words.", "max_tokens": 300, "temperature": 0, "stream": false}'

Output with enable-nsa-prefill-context-parallel ON:
The ad copy should be targeted at young adults and should highlight the product's unique features.\n\nCapture your life's best moments, not just on your phone, but in your space. Introducing the SocialFrame, the digital photo frame that brings your social media to life.\n\nIt automatically syncs with your Instagram and Facebook, creating a living gallery of your adventures, friends, and family. No more tedious uploading! Watch as new posts from your favorite people appear, keeping you connected to their lives in a beautiful, tangible way.\n\nThe sleek, modern design fits any decor, and the high-resolution display makes every memory shine. Give your photos the spotlight they deserve.\n\nTurn your feed into your frame. Get the SocialFrame today

Output with enable-nsa-prefill-context-parallel OFF:
The ad copy should be targeted at young adults and should highlight the product's unique features.\n\nCapture your life's best moments, not just on your phone, but in your home. Introducing the SocialFrame, the digital photo frame that brings your social media to life.\n\nIt automatically syncs with your Instagram and Facebook albums, creating a constantly evolving gallery of your favorite memories. No more manual uploads! See your latest adventures, group shots, and everyday joys displayed in stunning HD.\n\nPerfect for your desk, your nightstand, or your living room, the SocialFrame is more than a frame—it's a live stream of your story. Share smiles, relive laughs, and keep your cherished connections close.\n\nDon't just store your photos. Celebrate them. Get your SocialFrame today

Accuracy
Benchmark command:
python3 benchmark/gsm8k/bench_sglang.py --host http://127.0.0.1 --port 8000 --num-questions 200
[figure: GSM8K benchmark results]

Modifications

Accuracy Tests

Checklist

  • Support p/d disaggregation
  • Support multi-batch prefill
  • Enable context parallel across multiple machines

@gemini-code-assist (Contributor) commented:
Summary of Changes

Hello @lixiaolx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces foundational support for context parallelism (CP) within the DeepSeekV3.2-DSA model architecture. The primary goal is to enhance performance by reducing the time to first token (prefill-ttft) for extended input sequences. This is achieved by enabling a new environment variable, SGLANG_USE_DP_CP_AG_AFTER_DSA, which triggers specialized handling of distributed communication and tensor operations. The changes involve adapting data and tensor parallelism mechanisms, introducing new utilities for tensor reorganization, and modifying the model's forward pass to process segments of long sequences in parallel. While this PR lays the groundwork, it acknowledges current limitations, such as single-batch processing and specific hardware configurations.

Highlights

  • Context Parallelism for DeepSeekV3.2-DSA: Introduced initial support for context parallelism (CP) specifically for the DeepSeekV3.2-DSA model, aiming to reduce prefill-ttft (time to first token) for long text sequences.
  • New Environment Variable: Added a new boolean environment variable, SGLANG_USE_DP_CP_AG_AFTER_DSA, to enable and control the context parallel optimization.
  • Distributed Communication Adjustments: Modified existing data parallelism (DP) and tensor parallelism (TP) communication patterns, such as all-gather and scatter operations, to conditionally adapt or bypass when context parallelism is active.
  • Tensor Reorganization Utility: Implemented a new utility function, attn_tp_all_gather_reorgan_into_tensor, to efficiently gather and reorder tensors across attention tensor parallel ranks for CP.
  • Model Forward Pass Integration: Integrated context parallel logic into the model's forward pass, affecting attention and MLP layers, to handle segmented input data and ensure correct processing across parallel ranks.
  • Input Preparation Utilities: Added new functions like calculate_cp_seq_idx and prepare_input_dp_with_cp_dsa to manage the complex indexing, splitting, and preparation of input data for context parallel execution.
  • Current Limitations: Noted that the current implementation of context parallelism primarily supports single-batch prefill and requires a tp_size of 8, with multi-batch support still a pending item.

@gemini-code-assist bot left a comment:
Code Review

This pull request introduces context parallelism for deepseek-v3.2-DSA models to reduce the time-to-first-token for long sequences. The changes are controlled by a new environment variable and involve modifications across the attention, communication, and model-specific layers. The implementation is quite extensive and seems to correctly follow the context parallelism pattern. However, it currently has some limitations, such as only supporting single-batch prefill and being tied to a specific 8-GPU configuration. My review includes suggestions for improving code clarity, fixing a critical typo, and cleaning up some leftover development code. Overall, this is a valuable performance enhancement.

@Fridge003 (Collaborator) commented:

@lixiaolx Please fix lint with instructions here https://docs.sglang.ai/developer_guide/contribution_guide.html#format-code-with-pre-commit

lixiaolx force-pushed the dsa_cp_ver3 branch 2 times, most recently from a1b3e9e to fd8cb2d on October 27, 2025 at 10:03
@lixiaolx (Contributor, Author) replied:

@lixiaolx Please fix lint with instructions here https://docs.sglang.ai/developer_guide/contribution_guide.html#format-code-with-pre-commit

@Fridge003 I have fixed the lint issues following those instructions. Please let me know what else needs to be changed.

lixiaolx force-pushed the dsa_cp_ver3 branch 3 times, most recently from 388d7b5 to 21eba8d on October 28, 2025 at 09:49
@whybeyoung (Collaborator) commented:

Does the current context parallel only support a single machine (tp_size == 8)?

@lixiaolx (Contributor, Author) replied:

Does the current context parallel only support a single machine (tp_size == 8)?

Yes, the multi-machine approach is still under testing and verification; we plan to submit a separate PR for it later.

@whybeyoung (Collaborator) replied:

Does the current context parallel only support a single machine (tp_size == 8)?

Yes, the multi-machine approach is still under testing and verification; we plan to submit a separate PR for it later.

It also does not support the p/d disaggregation case; I have just tested it.

lixiaolx requested a review from yizhang2077 as a code owner on October 30, 2025 at 12:50
lixiaolx force-pushed the dsa_cp_ver3 branch 2 times, most recently from 277102f to fa091fe on October 30, 2025 at 13:09
@lixiaolx (Contributor, Author) commented on Oct 30, 2025:

Does the current context parallel only support a single machine (tp_size == 8)?

Yes, the multi-machine approach is still under testing and verification; we plan to submit a separate PR for it later.

It also does not support the p/d disaggregation case; I have just tested it.

Yes, future updates will support these cases:

  1. Currently, the latest version supports single-machine dp1 and dp2.
  2. The synchronization logic will be removed from the all-gather.

@Fridge003 (Collaborator) commented on Oct 31, 2025:

Can you please test the accuracy of GPQA with this PR:

python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 120000 --repeat 8 --thinking-mode deepseek-v3

The result should be about 0.80

Or another long-context benchmark, since this PR targets optimization under long context.

weights_prev, weights_next = torch.split(
    weights, (weights.shape[0] + 1) // 2, dim=0
)
topk_result_prev = self._get_topk_ragged_with_cp(
A contributor commented on this snippet:

How much performance benefit comes from the ragged treatment, which splits hidden_states into prev and next parts?

lixiaolx force-pushed the dsa_cp_ver3 branch 2 times, most recently from 18226c3 to 1ff644a on November 4, 2025 at 08:10
@sglang-bot (Member) left a comment:

Can you also add a test case (to show the launch command)? It does not need to run in per-commit CI.

@FENP commented on Nov 10, 2025:

@lixiaolx Nice work! CP is highly effective in reducing TTFT for long sequences, as it shards the input sequence across multiple devices. IIUC, this PR is designed for deepseekv3.2-DSA. It looks promising — have you thought about making it more extensible to benefit other models as well?
For example:

  1. Use a standalone --context-parallel-size parameter instead of --enable-nsa-prefill-context-parallel for better flexibility in configuring CP size
  2. Consider moving the hidden_states splitting logic to a more generic location instead of keeping it in model-specific files

FYI, our team is working on CP support in vLLM, with current efforts centered on supporting GQA-based models (vllm-project/vllm#26864). We'd love to collaborate or help align the designs if that would be helpful!

@lixiaolx (Contributor, Author) replied:

  1. Use a standalone --context-parallel-size parameter instead of --enable-nsa-prefill-context-parallel for better flexibility in configuring CP size

Our cp_size reuses atten_tp_size, so adjusting DP_size should meet your needs.

  2. Consider moving the hidden_states splitting logic to a more generic location instead of keeping it in model-specific files

The migration of the split function is underway and will be submitted soon.

github-actions bot added the documentation (Improvements or additions to documentation) and deepseek labels on Nov 10, 2025
@FENP commented on Nov 12, 2025:

Our cp_size reuses atten_tp_size.

Does this imply that tensor parallelism (TP) and context parallelism (CP) cannot coexist?
Would it be feasible to design CP as a separate parallelism config, orthogonal to both data parallelism (DP) and tensor parallelism (TP)?

lixiaolx force-pushed the dsa_cp_ver3 branch 3 times, most recently from 66143da to ae6626f on November 16, 2025 at 11:45
Fridge003 merged commit d368c74 into sgl-project:main on Nov 17, 2025
123 of 129 checks passed
    self,
    hidden_states,
    gemm_output_zero_allocator: BumpAllocator = None,
    forward_batch: ForwardBatch = None,
A collaborator commented on this snippet:

We'd better specify the exact information we need rather than passing the entire forward_batch.

@lixiaolx (Contributor, Author) replied:

OK, this issue will be fixed in an upcoming PR.


Labels: deepseek, documentation (Improvements or additions to documentation), run-ci
