Register allgather/reducescatter buffers with symm memory #8934

Closed
nvcastet wants to merge 1 commit into sgl-project:main from nvcastet:dp_rebased
Conversation

Collaborator

@nvcastet nvcastet commented Aug 7, 2025

Motivation

Speedup AllGather and ReduceScatter for max-throughput configs (DP-attention)

Benchmark & Profiling

E2E Speedup: 4.7%

Baseline

Server:

python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 8 --enable-flashinfer-cutlass-moe --enable-ep-moe --ep-size 8 --dp 8 --enable-dp-attention --chunked-prefill-size 16384 --mem-fraction-static 0.85 --max-running-requests 4096 --stream-interval 5 --enable-dp-lm-head --attention-backend trtllm_mla --cuda-graph-bs 1 2 4 8 16 32 64 128 256 512 1024 2048 4096 8192 --disable-radix-cache

Client:

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1024 --random-input 1024 --random-output 2048 --random-range-ratio 1 --warmup-request 1024
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     1024
Benchmark duration (s):                  115.62
Total input tokens:                      1048576
Total generated tokens:                  2097152
Total generated tokens (retokenized):    2090152
Request throughput (req/s):              8.86
Input token throughput (tok/s):          9068.80
Output token throughput (tok/s):         18137.61
Total token throughput (tok/s):          27206.41
Concurrency:                             1022.52
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   115457.41
Median E2E Latency (ms):                 115462.20
---------------Time to First Token----------------
Mean TTFT (ms):                          10490.22
Median TTFT (ms):                        10378.01
P99 TTFT (ms):                           18484.57
---------------Inter-Token Latency----------------
Mean ITL (ms):                           51.28
Median ITL (ms):                         47.05
P95 ITL (ms):                            52.15
P99 ITL (ms):                            54.71
Max ITL (ms):                            3283.82
==================================================

With this PR

python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 8 --enable-flashinfer-cutlass-moe --enable-ep-moe --ep-size 8 --dp 8 --enable-dp-attention --chunked-prefill-size 16384 --mem-fraction-static 0.85 --max-running-requests 4096 --stream-interval 5 --enable-dp-lm-head --attention-backend trtllm_mla --cuda-graph-bs 1 2 4 8 16 32 64 128 256 512 1024 2048 4096 8192 --disable-radix-cache  --enable-symm-mem
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     1024
Benchmark duration (s):                  110.42
Total input tokens:                      1048576
Total generated tokens:                  2097152
Total generated tokens (retokenized):    2090080
Request throughput (req/s):              9.27
Input token throughput (tok/s):          9496.07
Output token throughput (tok/s):         18992.14
Total token throughput (tok/s):          28488.20
Concurrency:                             1022.55
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   110265.53
Median E2E Latency (ms):                 110274.49
---------------Time to First Token----------------
Mean TTFT (ms):                          10447.13
Median TTFT (ms):                        10271.72
P99 TTFT (ms):                           18478.40
---------------Inter-Token Latency----------------
Mean ITL (ms):                           48.76
Median ITL (ms):                         44.57
P95 ITL (ms):                            49.59
P99 ITL (ms):                            52.35
Max ITL (ms):                            3298.69
==================================================

Checklist

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @nvcastet, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've implemented changes to optimize distributed communication operations, specifically AllGather and ReduceScatter, by registering their buffers with symmetric memory. This optimization is particularly beneficial for maximum-throughput configurations like data-parallel attention, and my benchmarks show a 4.7% end-to-end speedup.

Highlights

  • Enhanced Symmetric Memory Management: I've refined the symmetric memory allocation mechanism by introducing a new set to track registered tensor data pointers and a utility function to check if a tensor resides in symmetric memory. The use_symmetric_memory context manager now offers more granular control with disabled and disable_war flags, allowing for conditional application and specific workarounds for older PyTorch versions.
  • Optimized Distributed Operations: I've updated the core distributed communication functions (all_reduce, reduce_scatter_tensor, _all_gather_into_tensor) to intelligently leverage symmetric memory when available, falling back to standard PyTorch distributed operations otherwise. This ensures that memory-registered tensors benefit from faster communication paths.
  • Strategic Symmetric Memory Integration: I've strategically applied symmetric memory registration across various model layers, including layernorm outputs, gathered_buffer allocations, linear layer outputs, and embedding lookups. This integration is often conditional, enabling symmetric memory only when it provides a performance benefit (e.g., when tensor parallelism is active or specific padding modes are used), and can be explicitly disabled via new parameters.
  • Refined Attention and MoE Logic: I've added checks to attention-related operations (reduce_scatter_tensor, all_gather_into_tensor) to ensure they only execute when tensor parallelism is greater than one, preventing redundant operations. Notably, I've also removed symmetric memory usage from MoE layer forward passes, indicating a targeted application of this optimization.
  • Performance Validation: The changes are backed by benchmark results demonstrating a 4.7% end-to-end speedup, validating the effectiveness of registering allgather/reducescatter buffers with symmetric memory for high-throughput scenarios.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces optimizations for AllGather and ReduceScatter operations by leveraging symmetric memory, which shows a notable performance improvement. The implementation is generally solid, replacing monkey-patching with a cleaner tracking mechanism for symmetric memory tensors and adding flags for better control. However, I've identified a potential issue in python/sglang/srt/models/deepseek_v2.py where a tensor allocated in symmetric memory is not being tagged correctly. This would cause the subsequent collective communication to fall back to a less optimal path, negating the intended performance gain. Addressing this should ensure the full benefit of the optimization is realized.

Comment on lines 513 to 488
Contributor


Severity: high

The final_hidden_states_out tensor is allocated within the use_symmetric_memory context, but it is not tagged using sm.tag(). This means it won't be recognized as a symmetric memory tensor, and subsequent operations like tensor_model_parallel_all_reduce will not use the optimized pynccl path. This seems to undermine the performance benefits of using symmetric memory here.

A similar issue is present in the forward_normal method as well.

Suggested change

Before:

```python
with use_symmetric_memory(
    parallel_state.get_tp_group(), disabled=disable_symmetric_memory
) as sm:
    final_hidden_states_out = torch.empty_like(final_hidden_states)
    torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
    final_hidden_states = final_hidden_states_out
```

After:

```python
with use_symmetric_memory(
    parallel_state.get_tp_group(), disabled=disable_symmetric_memory
) as sm:
    final_hidden_states_out = torch.empty_like(final_hidden_states)
    sm.tag(final_hidden_states_out)
    torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
    final_hidden_states = final_hidden_states_out
```

Collaborator


Do we need to tag it?

Collaborator Author


Yes, it is tagged a few lines below.

Contributor


Why do you tag the tensor outside of the with scope?

Collaborator Author


The with scope makes all PyTorch allocations under it come from the symmetric memory pool.
The .tag(tensor) call is just book-keeping to flag a tensor that has been allocated via symmetric memory, so that when we call a collective on it we select NCCL for best perf instead of the alternative custom kernels.

But I can move it there if it is clearer for people and AIs. :)
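The division of labor described here can be sketched with a toy stand-in. `use_symmetric_memory` and `sm.tag` are the names used in this PR, but everything below is an illustrative mock (no CUDA, no pool allocator), not the sglang implementation:

```python
# Illustrative mock of the pattern: a context manager under which buffers are
# "allocated from the pool", plus tag() book-keeping that records which buffers
# are symmetric-memory resident so a later collective can pick the fast path.
class SymmetricMemoryContext:
    registered_ptrs = set()  # shared registry, analogous to the PR's data-pointer set

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def tag(self, buf):
        # Book-keeping only: record the buffer's identity so a later
        # collective on it can choose the optimized NCCL path.
        SymmetricMemoryContext.registered_ptrs.add(id(buf))


def is_symmetric(buf):
    """Stand-in for the check a collective would do before dispatching."""
    return id(buf) in SymmetricMemoryContext.registered_ptrs


with SymmetricMemoryContext() as sm:
    out = bytearray(16)  # stands in for torch.empty_like(...)
    sm.tag(out)          # without this tag, collectives fall back

print(is_symmetric(out))         # True
print(is_symmetric(bytearray()))  # False
```

This mirrors the reviewers' point: allocation placement is controlled by the `with` scope, while tagging is a separate, explicit step.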

Collaborator

@trevor-m trevor-m left a comment


Nice work @nvcastet!

Collaborator


Does the input need to be in symmetric memory too?

Collaborator Author


Yes, good point, I should probably check for that too.

Collaborator

@trevor-m trevor-m Aug 7, 2025


I think line 363 if get_tensor_model_parallel_world_size() == get_attention_dp_size(): is equivalent to get_attention_tp_size() == 1 since attn_tp_size = tp_size // dp_size. So we shouldn't need this change.

Maybe change line 363 to get_attention_tp_size() == 1 so it's clearer?
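A quick sanity check of the equivalence claimed above (assuming, as in the code, that dp_size divides tp_size; the variable names mirror the discussion and are illustrative only):

```python
# With attn_tp_size = tp_size // dp_size and dp_size dividing tp_size,
# tp_size == dp_size holds exactly when attn_tp_size == 1.
for tp_size in (1, 2, 4, 8, 16):
    for dp_size in (1, 2, 4, 8, 16):
        if tp_size % dp_size != 0:
            continue  # only configurations where dp divides tp are valid
        attn_tp_size = tp_size // dp_size
        assert (tp_size == dp_size) == (attn_tp_size == 1)
print("equivalence holds")
```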

Collaborator


Is disable_war used here because this is outside of the cuda graph?

Collaborator Author


Those allocations are associated with stream zero, which conflicts with the WAR for pre-2.8 PyTorch.
For 2.8 and beyond, it is a no-op.

Collaborator


Not sure if we need to touch the DeepEP path.

Collaborator Author


Yes, probably not, since I did not test it.

Collaborator


It would be nice if we could get this from a helper function to avoid needing to pass it everywhere. We have forward_batch here but we would also need the layer_communicator in order to check should_use_reduce_scatter...

Collaborator Author


I agree. Maybe adding an attribute to forward_batch? But in another PR review, it was suggested to avoid passing big objects like forward_batch, so that what the function accesses stays explicit.

Collaborator


You may refer to our recent refactor in dp_attention.py. We prepared some util functions to avoid tedious coding efforts. A similar util function can be prepared for symmetric_memory.

Collaborator Author


Thanks let me have a look.

@nvcastet nvcastet requested a review from trevor-m August 8, 2025 19:02

Collaborator

@trevor-m trevor-m left a comment


Thanks, looks good!

@nvcastet nvcastet force-pushed the dp_rebased branch 3 times, most recently from a8050e5 to b2bdbcf on August 11, 2025 21:54
Contributor

@merrymercy merrymercy left a comment


This PR changed the model forward code too much. Some changes are not intuitive (e.g., which tensor to tag, which region to wrap with use_symmetric_memory).

Is it possible to make it more transparent so that we get this feature without changing any model forward code? (even logits_processor.py)


@nvcastet
Collaborator Author

nvcastet commented Aug 13, 2025

This PR changed the model forward code too much.

Is it possible to make it more transparent so that we get this feature without change any model forward code? (even logits_processor.py)

Thanks for taking the time to review @merrymercy !
The takeaway of this PR and the previous PR (#8238) is to leverage NCCL's new symmetric kernels, which require the memory used for communication to be allocated via a separate allocator and registration.
We could be less intrusive by making a buffer copy before the comm op instead of finding the source allocation used for communication, but we would leave perf on the table, especially for larger buffer sizes.

Another option is to ignore code sections where the perf gain is small. I believe I could ignore the comms inside logits_processor.py, for example.

@merrymercy Could you provide some guidance on the direction to take?

As a reference, here is the MNNVL NCCL allreduce standalone perf gain using this feature:
[image: MNNVL NCCL allreduce speedup with symmetric-memory kernels]
Source: https://developer.nvidia.com/blog/enabling-fast-inference-and-resilient-training-with-nccl-2-27/#low-latency_kernels_with_symmetric_memory

(diff excerpt)

```python
)
else:
    self.gathered_buffer = torch.empty_like(self.gathered_buffer)
with use_symmetric_memory(get_tp_group()) as sm:
```
Collaborator


When is symmetric memory efficient? I'm a little bit confused here. Why do we not check the padding mode here?

Collaborator Author


It should always be more efficient. The problem is that we don't want to allocate buffers that are not symmetric across GPUs (otherwise we get undefined behavior), which can happen when we use the symmetric context manager across a big code section with DP-attention.
Here, specifically, the only allocation under the context manager is self.gathered_buffer, which always has the same size across GPUs.
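The symmetry constraint described here can be illustrated with a toy check (`symmetric_allocation_ok` and `per_rank_allocs` are hypothetical names, not sglang code): every rank must perform the identical allocation sequence under the context manager, which holds for gathered_buffer but may not hold for variable-size DP-attention activations.

```python
# Illustrative only: symmetric registration assumes every rank allocates the
# same sizes in the same order from the pool. per_rank_allocs is a
# hypothetical trace of allocation sizes, one inner list per rank.
def symmetric_allocation_ok(per_rank_allocs):
    """Return True if all ranks performed identical allocation sequences."""
    first = per_rank_allocs[0]
    return all(allocs == first for allocs in per_rank_allocs[1:])


# gathered_buffer: same size on every rank -> safe to register symmetrically
print(symmetric_allocation_ok([[16384], [16384], [16384], [16384]]))  # True

# variable per-rank activations (e.g. differing local token counts) -> unsafe
print(symmetric_allocation_ok([[1024], [768], [1024], [512]]))  # False
```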


Collaborator

@trevor-m trevor-m left a comment


Thanks, overall looks good just some minor questions.

@nvcastet nvcastet force-pushed the dp_rebased branch 2 times, most recently from ba38114 to f08a67f on August 20, 2025 15:53
@nvcastet nvcastet requested a review from trevor-m August 20, 2025 15:54
Collaborator

@trevor-m trevor-m left a comment


Looks good, thanks!

(excerpt, truncated at the reviewed line)

```python
def is_symmetric_memory_tensor(tensor: torch.Tensor):
    if not is_symmetric_memory_enabled():
        return False
    for segment in get_nccl_mem_pool().snapshot():
```
Collaborator


It's great that we no longer need to tag the tensors, but I'm concerned about the efficiency of this check. Collecting the snapshot seems to do a lot and then we also have to iterate over all segments/blocks.

  1. Do you know if basic caching/memoization in this method would work using tensor.untyped_storage().data_ptr() as the key? Will the storage change from iteration to iteration or remain consistent?
  2. If that doesn't work, could you measure how long this check takes?

Collaborator Author


It would change depending on what gets allocated in the pool, but we could at least cache the snapshot at the exit of the context manager.

Yes, I can time the call.

Collaborator Author


133 µs (not cached) vs 5 µs (cached).
So I made the change to cache it at context-manager exit.
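A minimal sketch of the caching approach, assuming a snapshot API shaped like get_nccl_mem_pool().snapshot(); the `pool_snapshot`, `refresh_cache`, and `is_symmetric_ptr` names below are hypothetical, not the PR's code:

```python
# Instead of walking the pool snapshot on every query, rebuild a set of data
# pointers once when the context manager exits, then do O(1) membership checks.
_cached_ptrs = set()


def pool_snapshot():
    # Placeholder for the real allocator snapshot: a list of segments,
    # each carrying blocks with an address and size.
    return [
        {"blocks": [{"address": 0x1000, "size": 256}]},
        {"blocks": [{"address": 0x2000, "size": 512}]},
    ]


def refresh_cache():
    # Called at context-manager exit, after pool contents may have changed.
    _cached_ptrs.clear()
    for segment in pool_snapshot():
        for block in segment["blocks"]:
            _cached_ptrs.add(block["address"])


def is_symmetric_ptr(ptr):
    # Fast path: set lookup instead of a full snapshot walk per check.
    return ptr in _cached_ptrs


refresh_cache()
print(is_symmetric_ptr(0x1000))  # True
print(is_symmetric_ptr(0x3000))  # False
```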

@nvcastet nvcastet requested a review from trevor-m August 20, 2025 20:33
@nvcastet
Collaborator Author

@merrymercy Do you mind having another look?

@nvcastet
Collaborator Author

nvcastet commented Sep 8, 2025

Closing in favor of #9358

@nvcastet nvcastet closed this Sep 8, 2025