Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 68 additions & 18 deletions python/sglang/srt/distributed/device_communicators/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,11 @@ def all_reduce(
)

def all_gather(
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
stream=None,
sizes: Optional[list[int]] = None,
):
if self.disabled:
return
Expand All @@ -161,21 +165,41 @@ def all_gather(
)
if stream is None:
stream = self.stream
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
self.comm,
cudaStream_t(stream.cuda_stream),
)

if sizes is not None:
split_offset = 0

self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
dst_slice = output_tensor[split_offset : split_offset + split_size]
self.nccl.ncclBroadcast(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this broadcast or allgather?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It’s equivalent to an all-gather. Each rank performs a broadcast, but we group the calls to avoid per-call overhead. This is done so that each rank can contribute a chunk of a different size.

buffer_type(input_tensor.data_ptr()),
buffer_type(dst_slice.data_ptr()),
dst_slice.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
root,
self.comm,
cudaStream_t(stream.cuda_stream),
)
split_offset += split_size
self.nccl.ncclGroupEnd()
else:
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
self.comm,
cudaStream_t(stream.cuda_stream),
)

def reduce_scatter(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None,
sizes: Optional[list[int]] = None,
):
if self.disabled:
return
Expand All @@ -188,15 +212,35 @@ def reduce_scatter(
)
if stream is None:
stream = self.stream
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)

if sizes is not None:
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
chunk = input_tensor[split_offset : split_offset + split_size, ...]

self.nccl.ncclReduce(
buffer_type(chunk.data_ptr()),
buffer_type(output_tensor.data_ptr()),
chunk.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
root,
self.comm,
cudaStream_t(stream.cuda_stream),
)
split_offset += split_size
self.nccl.ncclGroupEnd()
else:
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)

def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
Expand Down Expand Up @@ -266,6 +310,12 @@ def register_comm_window_raw(self, ptr: int, size: int):
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)

def group_start(self):
    """Open an NCCL group: subsequent NCCL calls are batched until group_end().

    Grouping lets multiple collectives (e.g. per-rank broadcasts) be issued
    as one fused operation, avoiding per-call launch overhead.
    """
    self.nccl.ncclGroupStart()

def group_end(self):
    """Close the NCCL group opened by group_start(), launching the batched calls."""
    self.nccl.ncclGroupEnd()

@contextmanager
def change_state(
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,26 @@ class NCCLLibrary:
cudaStream_t,
],
),
# ncclResult_t ncclReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, int root,
# ncclComm_t comm, cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclReduce",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
Expand Down Expand Up @@ -278,6 +298,10 @@ class NCCLLibrary:
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
# ncclResult_t ncclGroupStart();
Function("ncclGroupStart", ncclResult_t, []),
# ncclResult_t ncclGroupEnd();
Function("ncclGroupEnd", ncclResult_t, []),
]

exported_functions_symm_mem = [
Expand Down Expand Up @@ -400,6 +424,28 @@ def ncclAllReduce(
)
)

def ncclReduce(
    self,
    sendbuff: buffer_type,
    recvbuff: buffer_type,
    count: int,
    datatype: int,
    op: int,
    root: int,
    comm: ncclComm_t,
    stream: cudaStream_t,
) -> None:
    """Reduce ``count`` elements from all ranks into ``recvbuff`` on ``root``.

    Thin ctypes wrapper over ``ncclReduce``; raises via NCCL_CHECK on a
    non-success NCCL status code.
    """
    # `datatype` and `op` are declared as ncclDataType_t / ncclRedOp_t in the
    # function table, but both are aliases of ctypes.c_int, so plain Python
    # ints are coerced automatically by ctypes at call time.
    fn = self._funcs["ncclReduce"]
    status = fn(sendbuff, recvbuff, count, datatype, op, root, comm, stream)
    self.NCCL_CHECK(status)

def ncclReduceScatter(
self,
sendbuff: buffer_type,
Expand Down Expand Up @@ -499,6 +545,12 @@ def ncclCommWindowRegister(
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))

def ncclGroupStart(self) -> None:
    """Call ``ncclGroupStart()`` and raise on a non-success status."""
    self.NCCL_CHECK(self._funcs["ncclGroupStart"]())

def ncclGroupEnd(self) -> None:
    """Call ``ncclGroupEnd()`` and raise on a non-success status."""
    self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())


__all__ = [
"NCCLLibrary",
Expand Down
81 changes: 81 additions & 0 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,39 @@ def reduce_scatter(
torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
return output

def reduce_scatterv(
    self,
    input_: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    sizes: Optional[List[int]] = None,
) -> torch.Tensor:
    """Reduce-scatter supporting a different chunk size per rank.

    Args:
        input_: Tensor whose first dimension is partitioned across ranks.
        output: Optional preallocated destination; allocated when ``None``.
        sizes: Per-rank row counts (``len(sizes) == world_size``). When
            ``None``, ``input_.shape[0]`` must divide evenly by world size.

    Returns:
        This rank's reduced chunk of shape ``(chunk_size, *input_.shape[1:])``.
    """
    world_size = self.world_size
    pynccl_comm = self.pynccl_comm

    # Validate the communicator BEFORE dereferencing it: the original code
    # entered `pynccl_comm.change_state(...)` first, so a missing communicator
    # raised AttributeError on None instead of this assertion message.
    assert (
        pynccl_comm is not None and not pynccl_comm.disabled
    ), "pynccl is required for reduce_scatterv"

    with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
        if sizes is not None:
            assert len(sizes) == world_size
            assert input_.shape[0] == sum(sizes)
            chunk_size = sizes[self.rank_in_group]
        else:
            assert input_.shape[0] % world_size == 0
            chunk_size = input_.shape[0] // world_size
        output_shape = (chunk_size,) + input_.shape[1:]

        if output is None:
            output = torch.empty(
                output_shape, dtype=input_.dtype, device=input_.device
            )
        else:
            # torch.Size compares equal to a plain tuple of the same dims.
            assert output.shape == output_shape

        pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
    return output

def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
Expand Down Expand Up @@ -673,6 +706,54 @@ def all_gather(
)
return output_tensor

def all_gatherv(
    self,
    input_: Union[torch.Tensor, List[torch.Tensor]],
    sizes: Optional[List[int]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """All-gather supporting varying sizes per rank and tensor-list input.

    Args:
        input_: A tensor, or a list of tensors each gathered with the same
            ``sizes`` (the list case is fused into one NCCL group).
        sizes: Per-rank row counts (``len(sizes) == world_size``) of the
            gathered dimension; ``None`` means equal shapes on all ranks.

    Returns:
        The gathered tensor, or a list of gathered tensors matching ``input_``.
    """
    world_size = self.world_size
    pynccl_comm = self.pynccl_comm

    # Validate the communicator BEFORE dereferencing it: the original code
    # entered `pynccl_comm.change_state(...)` first, so a missing communicator
    # raised AttributeError on None instead of this assertion message.
    assert (
        pynccl_comm is not None and not pynccl_comm.disabled
    ), "pynccl is required for all_gatherv"

    with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):

        def _all_gather_single(
            input_: torch.Tensor, sizes: Optional[List[int]] = None
        ):
            input_size = input_.size()
            if sizes is not None:
                assert len(sizes) == world_size
                assert input_.shape[0] == sizes[self.rank_in_group]
                output_size = (sum(sizes),) + input_size[1:]
                # 'sizes' is not needed if all inputs in the same group
                # have the same shape; the plain all-gather path is cheaper.
                if all(s == sizes[0] for s in sizes):
                    sizes = None
            else:
                output_size = (input_size[0] * world_size,) + input_size[1:]
            # Allocate output tensor.
            output_tensor = torch.empty(
                output_size, dtype=input_.dtype, device=input_.device
            )
            pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
            return output_tensor

        if isinstance(input_, torch.Tensor):
            return _all_gather_single(input_, sizes)

        # List input: fuse the per-tensor gathers into a single NCCL group
        # to avoid per-call launch overhead.
        output_list = []
        pynccl_comm.group_start()
        for inp in input_:
            output_list.append(_all_gather_single(inp, sizes=sizes))
        pynccl_comm.group_end()

        return output_list

def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
Expand Down
11 changes: 9 additions & 2 deletions python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
get_global_dp_buffer,
get_local_dp_buffer,
)
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe import (
get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
Expand Down Expand Up @@ -112,7 +115,11 @@ def _compute_mlp_mode(cls, context: _LayerModeComputationContext):
if context.is_layer_sparse:
return (
ScatterMode.SCATTERED
if not get_moe_a2a_backend().is_none()
if (
# Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
not get_moe_a2a_backend().is_none()
or should_use_flashinfer_cutlass_moe_fp4_allgather()
)
else ScatterMode.FULL
)
else:
Expand Down
25 changes: 22 additions & 3 deletions python/sglang/srt/layers/dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
_device: torch.device
_global_dp_buffer_len: int
_local_dp_buffer_len: int
_global_num_tokens: Optional[List[int]]

@classmethod
def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
Expand All @@ -80,9 +81,15 @@ def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device
cls._device = device

@classmethod
def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
def set_dp_buffer_len(
cls,
global_dp_buffer_len: int,
local_dp_buffer_len: int,
global_num_tokens: Optional[List[int]] = None,
):
cls._global_dp_buffer_len = global_dp_buffer_len
cls._local_dp_buffer_len = local_dp_buffer_len
cls._global_num_tokens = global_num_tokens

@classmethod
def get_global_dp_buffer(cls) -> torch.Tensor:
Expand All @@ -108,10 +115,18 @@ def get_global_dp_buffer_len(cls) -> int:
def get_local_dp_buffer_len(cls) -> int:
return cls._local_dp_buffer_len

@classmethod
def get_dp_global_num_tokens(cls) -> Optional[List[int]]:
    # The backing attribute is declared Optional[List[int]] and defaults to
    # None when set_dp_buffer_len is called without `global_num_tokens`,
    # so the return annotation must be Optional as well.
    return cls._global_num_tokens


def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
def set_dp_buffer_len(
global_dp_buffer_len: int,
local_dp_buffer_len: int,
global_num_tokens: Optional[List[int]] = None,
):
_DpGatheredBufferWrapper.set_dp_buffer_len(
global_dp_buffer_len, local_dp_buffer_len
global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
)


Expand All @@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
return _DpGatheredBufferWrapper.get_local_dp_buffer_len()


def get_dp_global_num_tokens() -> Optional[List[int]]:
    # Module-level convenience accessor; may return None when no per-rank
    # token counts were recorded via set_dp_buffer_len.
    return _DpGatheredBufferWrapper.get_dp_global_num_tokens()


def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
if not enable_dp_attention:
return tp_rank, tp_size, 0
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,11 @@ def compute_dp_attention_metadata(self):
else:
self.global_dp_buffer_len = self.global_dp_buffer_len

set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
set_dp_buffer_len(
self.global_dp_buffer_len,
self.dp_local_num_tokens,
self.global_num_tokens_for_logprob_cpu,
)


class LogitsProcessor(nn.Module):
Expand Down
Loading
Loading