Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torch.cuda.memory import CUDAPluggableAllocator

from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.managers.schedule_batch import global_server_args_dict

nccl_allocator_source = """
#include <nccl.h>
Expand All @@ -29,12 +28,26 @@
_mem_pool = None
_registered_base_addrs = set()
_graph_pool_id = None
_cached_pool_snapshot = None


def is_symmetric_memory_enabled():
# Import here to avoid circular import
from sglang.srt.managers.schedule_batch import global_server_args_dict

return global_server_args_dict["enable_symm_mem"]


def is_symmetric_memory_tensor(tensor: torch.Tensor):
if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None:
return False
for segment in _cached_pool_snapshot:
for block in segment["blocks"]:
if block["address"] == tensor.untyped_storage().data_ptr():
return True
return False


def set_graph_pool_id(graph_pool_id):
global _graph_pool_id
_graph_pool_id = graph_pool_id
Expand Down Expand Up @@ -64,8 +77,17 @@ def get_nccl_mem_pool():


class use_symmetric_memory:
def __init__(self, group_coordinator: GroupCoordinator):
if not is_symmetric_memory_enabled():
def __init__(
self,
group_coordinator: GroupCoordinator,
disabled: bool = False,
):
self.disabled = (
disabled
or not is_symmetric_memory_enabled()
or group_coordinator.world_size == 1
)
if self.disabled:
self.group_coordinator = None
self._mem_pool_ctx = None
self.is_graph_capture = None
Expand All @@ -79,7 +101,7 @@ def __init__(self, group_coordinator: GroupCoordinator):
self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")

def __enter__(self):
if not is_symmetric_memory_enabled():
if self.disabled:
return self
assert (
self.group_coordinator.pynccl_comm is not None
Expand All @@ -101,17 +123,14 @@ def __enter__(self):
self._mem_pool_ctx.__enter__()
return self

def tag(self, tensor: torch.Tensor):
if not is_symmetric_memory_enabled():
return
tensor.symmetric_memory = True

def __exit__(self, exc_type, exc_val, exc_tb):
if not is_symmetric_memory_enabled():
if self.disabled:
return
global _cached_pool_snapshot
global _registered_base_addrs
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
for segment in get_nccl_mem_pool().snapshot():
_cached_pool_snapshot = get_nccl_mem_pool().snapshot()
for segment in _cached_pool_snapshot:
if segment["address"] not in _registered_base_addrs:
if segment["stream"] == 0 and self.pre_2_8_0:
# PyTorch version < 2.8.0 has a multi-thread MemPool bug
Expand Down
52 changes: 39 additions & 13 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,13 @@ def __init__(
from sglang.srt.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_tensor,
use_symmetric_memory,
)

self.is_symmetric_memory_tensor = is_symmetric_memory_tensor
self.use_symmetric_memory = use_symmetric_memory
if is_hip():
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce,
Expand Down Expand Up @@ -509,11 +515,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)

if (
self.pynccl_comm is not None
and hasattr(input_, "symmetric_memory")
and input_.symmetric_memory
):
if self.pynccl_comm is not None and self.is_symmetric_memory_tensor(input_):
with self.pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
):
Expand Down Expand Up @@ -579,9 +581,23 @@ def reduce_scatter_tensor(
self,
output: torch.Tensor,
input: torch.Tensor,
) -> None:
# TODO(ch-wan): support other backends
torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
) -> torch.Tensor:
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and (
not pynccl_comm.disabled
or (
self.is_symmetric_memory_tensor(output)
and self.is_symmetric_memory_tensor(input)
)
):
with pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
):
pynccl_comm.reduce_scatter(output, input)
else:
torch.distributed.reduce_scatter_tensor(
output, input, group=self.device_group
)
return output

def reduce_scatter(
Expand Down Expand Up @@ -628,8 +644,17 @@ def reduce_scatterv(

def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.all_gather(output, input)
if pynccl_comm is not None and (
not pynccl_comm.disabled
or (
self.is_symmetric_memory_tensor(output)
and self.is_symmetric_memory_tensor(input)
)
):
with pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
):
pynccl_comm.all_gather(output, input)
else:
torch.distributed.all_gather_into_tensor(
output, input, group=self.device_group
Expand Down Expand Up @@ -691,9 +716,10 @@ def all_gather(
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
with self.use_symmetric_memory(self):
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)

# All-gather.
if input_.is_cpu and is_shm_available(
Expand Down
11 changes: 10 additions & 1 deletion python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@

from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather_into_tensor,
attn_tp_reduce_scatter_tensor,
Expand Down Expand Up @@ -469,7 +473,12 @@ def _gather_hidden_states_and_residual(
use_layer_norm_before_gather = context.attn_tp_size == 1
if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
residual = hidden_states
hidden_states = layernorm(hidden_states)
with use_symmetric_memory(
get_tp_group(),
disabled=not forward_batch.dp_padding_mode.is_max_len(),
):
hidden_states = layernorm(hidden_states)

hidden_states, local_hidden_states = (
get_global_dp_buffer(),
hidden_states,
Expand Down
48 changes: 35 additions & 13 deletions python/sglang/srt/layers/dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)

if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
Expand Down Expand Up @@ -72,6 +75,7 @@ class _DpGatheredBufferWrapper:
_device: torch.device
_global_dp_buffer_len: int
_local_dp_buffer_len: int
_dp_max_padding: bool = True
_global_num_tokens: Optional[List[int]]

@classmethod
Expand All @@ -85,27 +89,33 @@ def set_dp_buffer_len(
cls,
global_dp_buffer_len: int,
local_dp_buffer_len: int,
dp_max_padding: bool,
global_num_tokens: Optional[List[int]] = None,
):
cls._global_dp_buffer_len = global_dp_buffer_len
cls._local_dp_buffer_len = local_dp_buffer_len
cls._dp_max_padding = dp_max_padding
cls._global_num_tokens = global_num_tokens

@classmethod
def get_global_dp_buffer(cls) -> torch.Tensor:
return torch.empty(
(cls._global_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
with use_symmetric_memory(get_tp_group()):
buffer = torch.empty(
(cls._global_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
return buffer

@classmethod
def get_local_dp_buffer(cls) -> torch.Tensor:
return torch.empty(
(cls._local_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding):
buffer = torch.empty(
(cls._local_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
return buffer

@classmethod
def get_global_dp_buffer_len(cls) -> int:
Expand All @@ -119,14 +129,19 @@ def get_local_dp_buffer_len(cls) -> int:
def get_dp_global_num_tokens(cls) -> List[int]:
return cls._global_num_tokens

@classmethod
def is_dp_max_padding(cls) -> bool:
return cls._dp_max_padding


def set_dp_buffer_len(
global_dp_buffer_len: int,
local_dp_buffer_len: int,
dp_max_padding: bool,
global_num_tokens: Optional[List[int]] = None,
):
_DpGatheredBufferWrapper.set_dp_buffer_len(
global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
global_dp_buffer_len, local_dp_buffer_len, dp_max_padding, global_num_tokens
)


Expand All @@ -150,6 +165,10 @@ def get_dp_global_num_tokens() -> List[int]:
return _DpGatheredBufferWrapper.get_dp_global_num_tokens()


def is_dp_max_padding() -> bool:
return _DpGatheredBufferWrapper.is_dp_max_padding()


def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
if not enable_dp_attention:
return tp_rank, tp_size, 0
Expand Down Expand Up @@ -408,7 +427,10 @@ def _dp_gather_via_all_gather(
scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
get_attention_tp_rank()
]
get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
if get_attention_tp_size() > 1:
get_attention_tp_group().reduce_scatter_tensor(
scattered_local_tokens, local_tokens
)
get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)


Expand Down Expand Up @@ -467,7 +489,7 @@ def dp_scatter(


def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
if get_tensor_model_parallel_world_size() == get_attention_dp_size():
if get_attention_tp_size() == 1:
get_tp_group().reduce_scatter_tensor(output, input)
else:
scattered_local_tokens = input.tensor_split(
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
get_tp_group,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import is_dp_max_padding
from sglang.srt.layers.parameter import (
BasevLLMParameter,
BlockQuantScaleParameter,
Expand Down Expand Up @@ -1316,9 +1317,8 @@ def forward(self, input_, skip_all_reduce=False):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
with use_symmetric_memory(get_tp_group(), disabled=not is_dp_max_padding()):
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel)

if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
output = tensor_model_parallel_all_reduce(output_parallel)
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def compute_dp_attention_metadata(self):
set_dp_buffer_len(
self.global_dp_buffer_len,
self.dp_local_num_tokens,
False,
self.global_num_tokens_for_logprob_cpu,
)

Expand Down
11 changes: 7 additions & 4 deletions python/sglang/srt/layers/moe/cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def cutlass_fused_experts_fp8(
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
use_fp8_blockscale: bool = True,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.

Expand Down Expand Up @@ -94,7 +95,7 @@ def cutlass_fused_experts_fp8(
b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
block scaling. Currently, only `True` is supported. Defaults to `True`.

output (torch.Tensor, optional): Output tensor. If not provided, a new tensor will be created.
Returns:
torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.

Expand Down Expand Up @@ -202,9 +203,11 @@ def cutlass_fused_experts_fp8(
workspace,
)

result = torch.empty((m, k), device=device, dtype=out_dtype)
apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype))
return result
if output is None:
output = torch.empty((m, k), device=device, dtype=out_dtype)

apply_shuffle_mul_sum(c2, output, c_map, topk_weights.to(out_dtype))
return output


FLOAT4_E2M1_MAX = 6.0
Expand Down
Loading
Loading