Skip to content

Commit 2dbf6c6

Browse files
trevor-m and MahmoudAshraf97
authored and committed
Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (sgl-project#7667)
1 parent 4b8c4d9 commit 2dbf6c6

File tree

16 files changed

+360
-52
lines changed

16 files changed

+360
-52
lines changed

python/sglang/srt/distributed/device_communicators/pynccl.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,11 @@ def all_reduce(
148148
)
149149

150150
def all_gather(
151-
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
151+
self,
152+
output_tensor: torch.Tensor,
153+
input_tensor: torch.Tensor,
154+
stream=None,
155+
sizes: Optional[list[int]] = None,
152156
):
153157
if self.disabled:
154158
return
@@ -161,21 +165,41 @@ def all_gather(
161165
)
162166
if stream is None:
163167
stream = self.stream
164-
self.nccl.ncclAllGather(
165-
buffer_type(input_tensor.data_ptr()),
166-
buffer_type(output_tensor.data_ptr()),
167-
input_tensor.numel(),
168-
ncclDataTypeEnum.from_torch(input_tensor.dtype),
169-
self.comm,
170-
cudaStream_t(stream.cuda_stream),
171-
)
168+
169+
if sizes is not None:
170+
split_offset = 0
171+
172+
self.nccl.ncclGroupStart()
173+
for root, split_size in enumerate(sizes):
174+
dst_slice = output_tensor[split_offset : split_offset + split_size]
175+
self.nccl.ncclBroadcast(
176+
buffer_type(input_tensor.data_ptr()),
177+
buffer_type(dst_slice.data_ptr()),
178+
dst_slice.numel(),
179+
ncclDataTypeEnum.from_torch(input_tensor.dtype),
180+
root,
181+
self.comm,
182+
cudaStream_t(stream.cuda_stream),
183+
)
184+
split_offset += split_size
185+
self.nccl.ncclGroupEnd()
186+
else:
187+
self.nccl.ncclAllGather(
188+
buffer_type(input_tensor.data_ptr()),
189+
buffer_type(output_tensor.data_ptr()),
190+
input_tensor.numel(),
191+
ncclDataTypeEnum.from_torch(input_tensor.dtype),
192+
self.comm,
193+
cudaStream_t(stream.cuda_stream),
194+
)
172195

173196
def reduce_scatter(
174197
self,
175198
output_tensor: torch.Tensor,
176199
input_tensor: torch.Tensor,
177200
op: ReduceOp = ReduceOp.SUM,
178201
stream=None,
202+
sizes: Optional[list[int]] = None,
179203
):
180204
if self.disabled:
181205
return
@@ -188,15 +212,35 @@ def reduce_scatter(
188212
)
189213
if stream is None:
190214
stream = self.stream
191-
self.nccl.ncclReduceScatter(
192-
buffer_type(input_tensor.data_ptr()),
193-
buffer_type(output_tensor.data_ptr()),
194-
output_tensor.numel(),
195-
ncclDataTypeEnum.from_torch(input_tensor.dtype),
196-
ncclRedOpTypeEnum.from_torch(op),
197-
self.comm,
198-
cudaStream_t(stream.cuda_stream),
199-
)
215+
216+
if sizes is not None:
217+
split_offset = 0
218+
self.nccl.ncclGroupStart()
219+
for root, split_size in enumerate(sizes):
220+
chunk = input_tensor[split_offset : split_offset + split_size, ...]
221+
222+
self.nccl.ncclReduce(
223+
buffer_type(chunk.data_ptr()),
224+
buffer_type(output_tensor.data_ptr()),
225+
chunk.numel(),
226+
ncclDataTypeEnum.from_torch(input_tensor.dtype),
227+
ncclRedOpTypeEnum.from_torch(op),
228+
root,
229+
self.comm,
230+
cudaStream_t(stream.cuda_stream),
231+
)
232+
split_offset += split_size
233+
self.nccl.ncclGroupEnd()
234+
else:
235+
self.nccl.ncclReduceScatter(
236+
buffer_type(input_tensor.data_ptr()),
237+
buffer_type(output_tensor.data_ptr()),
238+
output_tensor.numel(),
239+
ncclDataTypeEnum.from_torch(input_tensor.dtype),
240+
ncclRedOpTypeEnum.from_torch(op),
241+
self.comm,
242+
cudaStream_t(stream.cuda_stream),
243+
)
200244

201245
def send(self, tensor: torch.Tensor, dst: int, stream=None):
202246
if self.disabled:
@@ -266,6 +310,12 @@ def register_comm_window_raw(self, ptr: int, size: int):
266310
def deregister_comm_window(self, window):
267311
return self.nccl.ncclCommWindowDeregister(self.comm, window)
268312

313+
def group_start(self):
314+
self.nccl.ncclGroupStart()
315+
316+
def group_end(self):
317+
self.nccl.ncclGroupEnd()
318+
269319
@contextmanager
270320
def change_state(
271321
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None

python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,26 @@ class NCCLLibrary:
206206
cudaStream_t,
207207
],
208208
),
209+
# ncclResult_t ncclReduce(
210+
# const void* sendbuff, void* recvbuff, size_t count,
211+
# ncclDataType_t datatype, ncclRedOp_t op, int root,
212+
# ncclComm_t comm, cudaStream_t stream);
213+
# note that cudaStream_t is a pointer type, so the last argument
214+
# is a pointer
215+
Function(
216+
"ncclReduce",
217+
ncclResult_t,
218+
[
219+
buffer_type,
220+
buffer_type,
221+
ctypes.c_size_t,
222+
ncclDataType_t,
223+
ncclRedOp_t,
224+
ctypes.c_int,
225+
ncclComm_t,
226+
cudaStream_t,
227+
],
228+
),
209229
# ncclResult_t ncclReduceScatter(
210230
# const void* sendbuff, void* recvbuff, size_t count,
211231
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
@@ -278,6 +298,10 @@ class NCCLLibrary:
278298
# it is better not to call it at all.
279299
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
280300
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
301+
# ncclResult_t ncclGroupStart();
302+
Function("ncclGroupStart", ncclResult_t, []),
303+
# ncclResult_t ncclGroupEnd();
304+
Function("ncclGroupEnd", ncclResult_t, []),
281305
]
282306

283307
exported_functions_symm_mem = [
@@ -400,6 +424,28 @@ def ncclAllReduce(
400424
)
401425
)
402426

427+
def ncclReduce(
428+
self,
429+
sendbuff: buffer_type,
430+
recvbuff: buffer_type,
431+
count: int,
432+
datatype: int,
433+
op: int,
434+
root: int,
435+
comm: ncclComm_t,
436+
stream: cudaStream_t,
437+
) -> None:
438+
# `datatype` actually should be `ncclDataType_t`
439+
# and `op` should be `ncclRedOp_t`
440+
# both are aliases of `ctypes.c_int`
441+
# when we pass int to a function, it will be converted to `ctypes.c_int`
442+
# by ctypes automatically
443+
self.NCCL_CHECK(
444+
self._funcs["ncclReduce"](
445+
sendbuff, recvbuff, count, datatype, op, root, comm, stream
446+
)
447+
)
448+
403449
def ncclReduceScatter(
404450
self,
405451
sendbuff: buffer_type,
@@ -499,6 +545,12 @@ def ncclCommWindowRegister(
499545
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
500546
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
501547

548+
def ncclGroupStart(self) -> None:
549+
self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
550+
551+
def ncclGroupEnd(self) -> None:
552+
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
553+
502554

503555
__all__ = [
504556
"NCCLLibrary",

python/sglang/srt/distributed/parallel_state.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,39 @@ def reduce_scatter(
583583
torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
584584
return output
585585

586+
def reduce_scatterv(
587+
self,
588+
input_: torch.Tensor,
589+
output: Optional[torch.Tensor] = None,
590+
sizes: Optional[List[int]] = None,
591+
) -> torch.Tensor:
592+
world_size = self.world_size
593+
pynccl_comm = self.pynccl_comm
594+
595+
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
596+
assert (
597+
pynccl_comm is not None and not pynccl_comm.disabled
598+
), "pynccl is required for reduce_scatterv"
599+
600+
if sizes is not None:
601+
assert len(sizes) == world_size
602+
assert input_.shape[0] == sum(sizes)
603+
chunk_size = sizes[self.rank_in_group]
604+
else:
605+
assert input_.shape[0] % world_size == 0
606+
chunk_size = input_.shape[0] // world_size
607+
output_shape = (chunk_size,) + input_.shape[1:]
608+
609+
if output is None:
610+
output = torch.empty(
611+
output_shape, dtype=input_.dtype, device=input_.device
612+
)
613+
else:
614+
assert output.shape == output_shape
615+
616+
pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
617+
return output
618+
586619
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
587620
pynccl_comm = self.pynccl_comm
588621
if pynccl_comm is not None and not pynccl_comm.disabled:
@@ -673,6 +706,54 @@ def all_gather(
673706
)
674707
return output_tensor
675708

709+
def all_gatherv(
710+
self,
711+
input_: Union[torch.Tensor, List[torch.Tensor]],
712+
sizes: Optional[List[int]] = None,
713+
) -> Union[torch.Tensor, List[torch.Tensor]]:
714+
"""
715+
Supports varying sizes per rank and input tensor list.
716+
`sizes`: a list of len(world_size) with the number of items per rank to gather.
717+
"""
718+
world_size = self.world_size
719+
pynccl_comm = self.pynccl_comm
720+
721+
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
722+
assert (
723+
pynccl_comm is not None and not pynccl_comm.disabled
724+
), "pynccl is required for all_gatherv"
725+
726+
def _all_gather_single(
727+
input_: torch.Tensor, sizes: Optional[List[int]] = None
728+
):
729+
input_size = input_.size()
730+
if sizes is not None:
731+
assert len(sizes) == world_size
732+
assert input_.shape[0] == sizes[self.rank_in_group]
733+
output_size = (sum(sizes),) + input_size[1:]
734+
# 'sizes' is not needed if all inputs in the same group have the same shape
735+
if all(s == sizes[0] for s in sizes):
736+
sizes = None
737+
else:
738+
output_size = (input_size[0] * world_size,) + input_size[1:]
739+
# Allocate output tensor.
740+
output_tensor = torch.empty(
741+
output_size, dtype=input_.dtype, device=input_.device
742+
)
743+
pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
744+
return output_tensor
745+
746+
if isinstance(input_, torch.Tensor):
747+
return _all_gather_single(input_, sizes)
748+
749+
output_list = []
750+
pynccl_comm.group_start()
751+
for inp in input_:
752+
output_list.append(_all_gather_single(inp, sizes=sizes))
753+
pynccl_comm.group_end()
754+
755+
return output_list
756+
676757
def gather(
677758
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
678759
) -> Optional[torch.Tensor]:

python/sglang/srt/layers/communicator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535
get_global_dp_buffer,
3636
get_local_dp_buffer,
3737
)
38-
from sglang.srt.layers.moe import get_moe_a2a_backend
38+
from sglang.srt.layers.moe import (
39+
get_moe_a2a_backend,
40+
should_use_flashinfer_cutlass_moe_fp4_allgather,
41+
)
3942
from sglang.srt.layers.utils import is_sm100_supported
4043
from sglang.srt.managers.schedule_batch import global_server_args_dict
4144
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -112,7 +115,11 @@ def _compute_mlp_mode(cls, context: _LayerModeComputationContext):
112115
if context.is_layer_sparse:
113116
return (
114117
ScatterMode.SCATTERED
115-
if not get_moe_a2a_backend().is_none()
118+
if (
119+
# Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
120+
not get_moe_a2a_backend().is_none()
121+
or should_use_flashinfer_cutlass_moe_fp4_allgather()
122+
)
116123
else ScatterMode.FULL
117124
)
118125
else:

python/sglang/srt/layers/dp_attention.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
7272
_device: torch.device
7373
_global_dp_buffer_len: int
7474
_local_dp_buffer_len: int
75+
_global_num_tokens: Optional[List[int]]
7576

7677
@classmethod
7778
def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -80,9 +81,15 @@ def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device
8081
cls._device = device
8182

8283
@classmethod
83-
def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
84+
def set_dp_buffer_len(
85+
cls,
86+
global_dp_buffer_len: int,
87+
local_dp_buffer_len: int,
88+
global_num_tokens: Optional[List[int]] = None,
89+
):
8490
cls._global_dp_buffer_len = global_dp_buffer_len
8591
cls._local_dp_buffer_len = local_dp_buffer_len
92+
cls._global_num_tokens = global_num_tokens
8693

8794
@classmethod
8895
def get_global_dp_buffer(cls) -> torch.Tensor:
@@ -108,10 +115,18 @@ def get_global_dp_buffer_len(cls) -> int:
108115
def get_local_dp_buffer_len(cls) -> int:
109116
return cls._local_dp_buffer_len
110117

118+
@classmethod
119+
def get_dp_global_num_tokens(cls) -> List[int]:
120+
return cls._global_num_tokens
121+
111122

112-
def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
123+
def set_dp_buffer_len(
124+
global_dp_buffer_len: int,
125+
local_dp_buffer_len: int,
126+
global_num_tokens: Optional[List[int]] = None,
127+
):
113128
_DpGatheredBufferWrapper.set_dp_buffer_len(
114-
global_dp_buffer_len, local_dp_buffer_len
129+
global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
115130
)
116131

117132

@@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
131146
return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
132147

133148

149+
def get_dp_global_num_tokens() -> List[int]:
150+
return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
151+
152+
134153
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
135154
if not enable_dp_attention:
136155
return tp_rank, tp_size, 0

python/sglang/srt/layers/logits_processor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,11 @@ def compute_dp_attention_metadata(self):
191191
else:
192192
self.global_dp_buffer_len = self.global_dp_buffer_len
193193

194-
set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
194+
set_dp_buffer_len(
195+
self.global_dp_buffer_len,
196+
self.dp_local_num_tokens,
197+
self.global_num_tokens_for_logprob_cpu,
198+
)
195199

196200

197201
class LogitsProcessor(nn.Module):

0 commit comments

Comments (0)