Skip to content

Commit 1e3337c

Browse files
committed
Register allgather/reducescatter buffers with symm memory
1 parent a3b810e commit 1e3337c

File tree

15 files changed

+186
-69
lines changed

15 files changed

+186
-69
lines changed

benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import argparse
33
import json
44
import time
5-
from datetime import datetime
65
from contextlib import nullcontext
6+
from datetime import datetime
77
from typing import Any, Dict, List, Tuple, TypedDict
88

99
import ray

python/sglang/srt/distributed/device_communicators/pynccl_allocator.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from torch.cuda.memory import CUDAPluggableAllocator
66

77
from sglang.srt.distributed.parallel_state import GroupCoordinator
8-
from sglang.srt.managers.schedule_batch import global_server_args_dict
98

109
nccl_allocator_source = """
1110
#include <nccl.h>
@@ -28,13 +27,21 @@
2827
_allocator = None
2928
_mem_pool = None
3029
_registered_base_addrs = set()
30+
_registered_tensor_addrs = set()
3131
_graph_pool_id = None
3232

3333

3434
def is_symmetric_memory_enabled():
35+
# Import here to avoid circular import
36+
from sglang.srt.managers.schedule_batch import global_server_args_dict
37+
3538
return global_server_args_dict["enable_symm_mem"]
3639

3740

41+
def is_symmetric_memory_tensor(tensor: torch.Tensor):
42+
return tensor.untyped_storage().data_ptr() in _registered_tensor_addrs
43+
44+
3845
def set_graph_pool_id(graph_pool_id):
3946
global _graph_pool_id
4047
_graph_pool_id = graph_pool_id
@@ -64,8 +71,17 @@ def get_nccl_mem_pool():
6471

6572

6673
class use_symmetric_memory:
67-
def __init__(self, group_coordinator: GroupCoordinator):
68-
if not is_symmetric_memory_enabled():
74+
def __init__(
75+
self,
76+
group_coordinator: GroupCoordinator,
77+
disabled: bool = False,
78+
):
79+
self.disabled = (
80+
disabled
81+
or not is_symmetric_memory_enabled()
82+
or group_coordinator.world_size == 1
83+
)
84+
if self.disabled:
6985
self.group_coordinator = None
7086
self._mem_pool_ctx = None
7187
self.is_graph_capture = None
@@ -79,7 +95,7 @@ def __init__(self, group_coordinator: GroupCoordinator):
7995
self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
8096

8197
def __enter__(self):
82-
if not is_symmetric_memory_enabled():
98+
if self.disabled:
8399
return self
84100
assert (
85101
self.group_coordinator.pynccl_comm is not None
@@ -102,12 +118,13 @@ def __enter__(self):
102118
return self
103119

104120
def tag(self, tensor: torch.Tensor):
105-
if not is_symmetric_memory_enabled():
121+
if self.disabled:
106122
return
107-
tensor.symmetric_memory = True
123+
global _registered_tensor_addrs
124+
_registered_tensor_addrs.add(tensor.untyped_storage().data_ptr())
108125

109126
def __exit__(self, exc_type, exc_val, exc_tb):
110-
if not is_symmetric_memory_enabled():
127+
if self.disabled:
111128
return
112129
global _registered_base_addrs
113130
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)

python/sglang/srt/distributed/parallel_state.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,13 @@ def __init__(
270270
from sglang.srt.distributed.device_communicators.pynccl import (
271271
PyNcclCommunicator,
272272
)
273+
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
274+
is_symmetric_memory_tensor,
275+
use_symmetric_memory,
276+
)
273277

278+
self.is_symmetric_memory_tensor = is_symmetric_memory_tensor
279+
self.use_symmetric_memory = use_symmetric_memory
274280
if is_hip():
275281
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
276282
QuickAllReduce,
@@ -499,11 +505,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
499505
if self.npu_communicator is not None and not self.npu_communicator.disabled:
500506
return self.npu_communicator.all_reduce(input_)
501507

502-
if (
503-
self.pynccl_comm is not None
504-
and hasattr(input_, "symmetric_memory")
505-
and input_.symmetric_memory
506-
):
508+
if self.pynccl_comm is not None and self.is_symmetric_memory_tensor(input_):
507509
with self.pynccl_comm.change_state(
508510
enable=True, stream=torch.cuda.current_stream()
509511
):
@@ -569,9 +571,23 @@ def reduce_scatter_tensor(
569571
self,
570572
output: torch.Tensor,
571573
input: torch.Tensor,
572-
) -> None:
573-
# TODO(ch-wan): support other backends
574-
torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
574+
) -> torch.Tensor:
575+
pynccl_comm = self.pynccl_comm
576+
if pynccl_comm is not None and (
577+
not pynccl_comm.disabled
578+
or (
579+
self.is_symmetric_memory_tensor(output)
580+
and self.is_symmetric_memory_tensor(input)
581+
)
582+
):
583+
with pynccl_comm.change_state(
584+
enable=True, stream=torch.cuda.current_stream()
585+
):
586+
pynccl_comm.reduce_scatter(output, input)
587+
else:
588+
torch.distributed.reduce_scatter_tensor(
589+
output, input, group=self.device_group
590+
)
575591
return output
576592

577593
def reduce_scatter(
@@ -618,8 +634,17 @@ def reduce_scatterv(
618634

619635
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
620636
pynccl_comm = self.pynccl_comm
621-
if pynccl_comm is not None and not pynccl_comm.disabled:
622-
pynccl_comm.all_gather(output, input)
637+
if pynccl_comm is not None and (
638+
not pynccl_comm.disabled
639+
or (
640+
self.is_symmetric_memory_tensor(output)
641+
and self.is_symmetric_memory_tensor(input)
642+
)
643+
):
644+
with pynccl_comm.change_state(
645+
enable=True, stream=torch.cuda.current_stream()
646+
):
647+
pynccl_comm.all_gather(output, input)
623648
else:
624649
torch.distributed.all_gather_into_tensor(
625650
output, input, group=self.device_group
@@ -681,9 +706,11 @@ def all_gather(
681706
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
682707
output_size = (input_size[0] * world_size,) + input_size[1:]
683708
# Allocate output tensor.
684-
output_tensor = torch.empty(
685-
output_size, dtype=input_.dtype, device=input_.device
686-
)
709+
with self.use_symmetric_memory(self) as sm:
710+
output_tensor = torch.empty(
711+
output_size, dtype=input_.dtype, device=input_.device
712+
)
713+
sm.tag(output_tensor)
687714

688715
# All-gather.
689716
if input_.is_cpu and is_shm_available(

python/sglang/srt/layers/communicator.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121

2222
from sglang.srt.distributed import (
2323
get_tensor_model_parallel_world_size,
24+
get_tp_group,
2425
tensor_model_parallel_all_reduce,
2526
)
27+
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
28+
use_symmetric_memory,
29+
)
2630
from sglang.srt.layers.dp_attention import (
2731
attn_tp_all_gather_into_tensor,
2832
attn_tp_reduce_scatter_tensor,
@@ -430,7 +434,13 @@ def _gather_hidden_states_and_residual(
430434
use_layer_norm_before_gather = context.attn_tp_size == 1
431435
if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
432436
residual = hidden_states
433-
hidden_states = layernorm(hidden_states)
437+
with use_symmetric_memory(
438+
get_tp_group(),
439+
disabled=not forward_batch.dp_padding_mode.is_max_len(),
440+
) as sm:
441+
hidden_states = layernorm(hidden_states)
442+
sm.tag(hidden_states)
443+
434444
hidden_states, local_hidden_states = (
435445
get_global_dp_buffer(),
436446
hidden_states,

python/sglang/srt/layers/dp_attention.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
get_tp_group,
1818
tensor_model_parallel_all_reduce,
1919
)
20+
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
21+
use_symmetric_memory,
22+
)
2023

2124
if TYPE_CHECKING:
2225
from sglang.srt.configs.model_config import ModelConfig
@@ -72,6 +75,7 @@ class _DpGatheredBufferWrapper:
7275
_device: torch.device
7376
_global_dp_buffer_len: int
7477
_local_dp_buffer_len: int
78+
_is_max_padding: bool
7579
_global_num_tokens: Optional[List[int]]
7680

7781
@classmethod
@@ -85,27 +89,37 @@ def set_dp_buffer_len(
8589
cls,
8690
global_dp_buffer_len: int,
8791
local_dp_buffer_len: int,
92+
is_max_padding: bool,
8893
global_num_tokens: Optional[List[int]] = None,
8994
):
9095
cls._global_dp_buffer_len = global_dp_buffer_len
9196
cls._local_dp_buffer_len = local_dp_buffer_len
97+
cls._is_max_padding = is_max_padding
9298
cls._global_num_tokens = global_num_tokens
9399

94100
@classmethod
95101
def get_global_dp_buffer(cls) -> torch.Tensor:
96-
return torch.empty(
97-
(cls._global_dp_buffer_len, cls._hidden_size),
98-
dtype=cls._dtype,
99-
device=cls._device,
100-
)
102+
with use_symmetric_memory(get_tp_group()) as sm:
103+
buffer = torch.empty(
104+
(cls._global_dp_buffer_len, cls._hidden_size),
105+
dtype=cls._dtype,
106+
device=cls._device,
107+
)
108+
sm.tag(buffer)
109+
return buffer
101110

102111
@classmethod
103112
def get_local_dp_buffer(cls) -> torch.Tensor:
104-
return torch.empty(
105-
(cls._local_dp_buffer_len, cls._hidden_size),
106-
dtype=cls._dtype,
107-
device=cls._device,
108-
)
113+
with use_symmetric_memory(
114+
get_tp_group(), disabled=not cls._is_max_padding
115+
) as sm:
116+
buffer = torch.empty(
117+
(cls._local_dp_buffer_len, cls._hidden_size),
118+
dtype=cls._dtype,
119+
device=cls._device,
120+
)
121+
sm.tag(buffer)
122+
return buffer
109123

110124
@classmethod
111125
def get_global_dp_buffer_len(cls) -> int:
@@ -119,14 +133,19 @@ def get_local_dp_buffer_len(cls) -> int:
119133
def get_dp_global_num_tokens(cls) -> List[int]:
120134
return cls._global_num_tokens
121135

136+
@classmethod
137+
def is_max_padding(cls) -> bool:
138+
return cls._is_max_padding
139+
122140

123141
def set_dp_buffer_len(
124142
global_dp_buffer_len: int,
125143
local_dp_buffer_len: int,
144+
is_max_padding: bool,
126145
global_num_tokens: Optional[List[int]] = None,
127146
):
128147
_DpGatheredBufferWrapper.set_dp_buffer_len(
129-
global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
148+
global_dp_buffer_len, local_dp_buffer_len, is_max_padding, global_num_tokens
130149
)
131150

132151

@@ -150,6 +169,10 @@ def get_dp_global_num_tokens() -> List[int]:
150169
return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
151170

152171

172+
def is_max_padding() -> bool:
173+
return _DpGatheredBufferWrapper.is_max_padding()
174+
175+
153176
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
154177
if not enable_dp_attention:
155178
return tp_rank, tp_size, 0
@@ -408,7 +431,10 @@ def _dp_gather_via_all_gather(
408431
scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
409432
get_attention_tp_rank()
410433
]
411-
get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
434+
if get_attention_tp_size() > 1:
435+
get_attention_tp_group().reduce_scatter_tensor(
436+
scattered_local_tokens, local_tokens
437+
)
412438
get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
413439

414440

@@ -467,7 +493,7 @@ def dp_scatter(
467493

468494

469495
def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
470-
if get_tensor_model_parallel_world_size() == get_attention_dp_size():
496+
if get_attention_tp_size() == 1:
471497
get_tp_group().reduce_scatter_tensor(output, input)
472498
else:
473499
scattered_local_tokens = input.tensor_split(

python/sglang/srt/layers/linear.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,7 +1301,7 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor
13011301
# It does not support additional parameters.
13021302
param.load_row_parallel_weight(loaded_weight)
13031303

1304-
def forward(self, input_, skip_all_reduce=False):
1304+
def forward(self, input_, skip_all_reduce=False, disable_symmetric_memory=True):
13051305
if self.input_is_parallel:
13061306
input_parallel = input_
13071307
else:
@@ -1315,7 +1315,9 @@ def forward(self, input_, skip_all_reduce=False):
13151315
# Only fuse bias add into GEMM for rank 0 (this ensures that
13161316
# bias will not get added more than once in TP>1 case)
13171317
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
1318-
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
1318+
with use_symmetric_memory(
1319+
parallel_state.get_tp_group(), disabled=disable_symmetric_memory
1320+
) as sm:
13191321
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
13201322
sm.tag(output_parallel)
13211323

python/sglang/srt/layers/logits_processor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def compute_dp_attention_metadata(self):
194194
set_dp_buffer_len(
195195
self.global_dp_buffer_len,
196196
self.dp_local_num_tokens,
197+
False,
197198
self.global_num_tokens_for_logprob_cpu,
198199
)
199200

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,8 @@
1111
get_moe_expert_parallel_world_size,
1212
get_moe_tensor_parallel_rank,
1313
get_moe_tensor_parallel_world_size,
14-
get_tp_group,
1514
tensor_model_parallel_all_reduce,
1615
)
17-
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
18-
use_symmetric_memory,
19-
)
2016
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
2117
from sglang.srt.layers.moe import (
2218
MoeRunnerConfig,
@@ -812,15 +808,12 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
812808
raise NotImplementedError()
813809

814810
# Matrix multiply.
815-
with use_symmetric_memory(get_tp_group()) as sm:
816-
817-
final_hidden_states = self.quant_method.apply(
818-
layer=self,
819-
x=hidden_states,
820-
topk_output=topk_output,
821-
moe_runner_config=self.moe_runner_config,
822-
)
823-
sm.tag(final_hidden_states)
811+
final_hidden_states = self.quant_method.apply(
812+
layer=self,
813+
x=hidden_states,
814+
topk_output=topk_output,
815+
moe_runner_config=self.moe_runner_config,
816+
)
824817

825818
final_hidden_states = final_hidden_states[
826819
..., :origin_hidden_states_dim

0 commit comments

Comments
 (0)