1717 get_tp_group ,
1818 tensor_model_parallel_all_reduce ,
1919)
20+ from sglang .srt .distributed .device_communicators .pynccl_allocator import (
21+ use_symmetric_memory ,
22+ )
2023
2124if TYPE_CHECKING :
2225 from sglang .srt .configs .model_config import ModelConfig
@@ -72,6 +75,7 @@ class _DpGatheredBufferWrapper:
7275 _device : torch .device
7376 _global_dp_buffer_len : int
7477 _local_dp_buffer_len : int
78+ _is_max_padding : bool
7579 _global_num_tokens : Optional [List [int ]]
7680
7781 @classmethod
@@ -85,27 +89,37 @@ def set_dp_buffer_len(
8589 cls ,
8690 global_dp_buffer_len : int ,
8791 local_dp_buffer_len : int ,
92+ is_max_padding : bool ,
8893 global_num_tokens : Optional [List [int ]] = None ,
8994 ):
9095 cls ._global_dp_buffer_len = global_dp_buffer_len
9196 cls ._local_dp_buffer_len = local_dp_buffer_len
97+ cls ._is_max_padding = is_max_padding
9298 cls ._global_num_tokens = global_num_tokens
9399
94100 @classmethod
95101 def get_global_dp_buffer (cls ) -> torch .Tensor :
96- return torch .empty (
97- (cls ._global_dp_buffer_len , cls ._hidden_size ),
98- dtype = cls ._dtype ,
99- device = cls ._device ,
100- )
102+ with use_symmetric_memory (get_tp_group ()) as sm :
103+ buffer = torch .empty (
104+ (cls ._global_dp_buffer_len , cls ._hidden_size ),
105+ dtype = cls ._dtype ,
106+ device = cls ._device ,
107+ )
108+ sm .tag (buffer )
109+ return buffer
101110
102111 @classmethod
103112 def get_local_dp_buffer (cls ) -> torch .Tensor :
104- return torch .empty (
105- (cls ._local_dp_buffer_len , cls ._hidden_size ),
106- dtype = cls ._dtype ,
107- device = cls ._device ,
108- )
113+ with use_symmetric_memory (
114+ get_tp_group (), disabled = not cls ._is_max_padding
115+ ) as sm :
116+ buffer = torch .empty (
117+ (cls ._local_dp_buffer_len , cls ._hidden_size ),
118+ dtype = cls ._dtype ,
119+ device = cls ._device ,
120+ )
121+ sm .tag (buffer )
122+ return buffer
109123
110124 @classmethod
111125 def get_global_dp_buffer_len (cls ) -> int :
@@ -119,14 +133,19 @@ def get_local_dp_buffer_len(cls) -> int:
119133 def get_dp_global_num_tokens (cls ) -> List [int ]:
120134 return cls ._global_num_tokens
121135
136+ @classmethod
137+ def is_max_padding (cls ) -> bool :
138+ return cls ._is_max_padding
139+
122140
def set_dp_buffer_len(
    global_dp_buffer_len: int,
    local_dp_buffer_len: int,
    is_max_padding: bool,
    global_num_tokens: Optional[List[int]] = None,
):
    """Module-level convenience wrapper: forward all sizing state to
    ``_DpGatheredBufferWrapper.set_dp_buffer_len``."""
    _DpGatheredBufferWrapper.set_dp_buffer_len(
        global_dp_buffer_len,
        local_dp_buffer_len,
        is_max_padding,
        global_num_tokens,
    )
131150
132151
@@ -150,6 +169,10 @@ def get_dp_global_num_tokens() -> List[int]:
150169 return _DpGatheredBufferWrapper .get_dp_global_num_tokens ()
151170
152171
def is_max_padding() -> bool:
    """Module-level accessor for ``_DpGatheredBufferWrapper``'s
    max-padding flag."""
    return _DpGatheredBufferWrapper.is_max_padding()
175+
153176def compute_dp_attention_world_info (enable_dp_attention , tp_rank , tp_size , dp_size ):
154177 if not enable_dp_attention :
155178 return tp_rank , tp_size , 0
@@ -408,7 +431,10 @@ def _dp_gather_via_all_gather(
408431 scattered_local_tokens = local_tokens .tensor_split (get_attention_tp_size ())[
409432 get_attention_tp_rank ()
410433 ]
411- get_attention_tp_group ().reduce_scatter_tensor (scattered_local_tokens , local_tokens )
434+ if get_attention_tp_size () > 1 :
435+ get_attention_tp_group ().reduce_scatter_tensor (
436+ scattered_local_tokens , local_tokens
437+ )
412438 get_tp_group ().all_gather_into_tensor (global_tokens , scattered_local_tokens )
413439
414440
@@ -467,7 +493,7 @@ def dp_scatter(
467493
468494
469495def dp_reduce_scatter_tensor (output : torch .Tensor , input : torch .Tensor ):
470- if get_tensor_model_parallel_world_size () == get_attention_dp_size () :
496+ if get_attention_tp_size () == 1 :
471497 get_tp_group ().reduce_scatter_tensor (output , input )
472498 else :
473499 scattered_local_tokens = input .tensor_split (
0 commit comments