1 change: 1 addition & 0 deletions docs/backend/server_arguments.md
@@ -251,6 +251,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--disable-cuda-graph-padding` | Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed. | False |
| `--enable-profile-cuda-graph` | Enable profiling of cuda graph capture. | False |
| `--enable-nccl-nvls` | Enable NCCL NVLS for prefill heavy requests when available. | False |
| `--enable-symm-mem` | Enable NCCL symmetric memory for fast collectives. | False |
| `--enable-tokenizer-batch-encode` | Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds. | False |
| `--disable-outlines-disk-cache` | Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency. | False |
| `--disable-custom-all-reduce` | Disable the custom all-reduce kernel and fall back to NCCL. | False |
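The new `--enable-symm-mem` server flag corresponds to the `enable_symm_mem` server argument (also forwarded via `global_server_args_dict` below), so the offline engine should be able to opt in with a keyword argument. A hedged sketch; the model path and `tp_size` are placeholders, not taken from this PR:

```python
# Hedged sketch: enable NCCL symmetric memory from the offline engine.
# Equivalent to passing --enable-symm-mem to the server launcher.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    tp_size=2,                                      # the feature targets TP collectives
    enable_symm_mem=True,
)
```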
python/sglang/srt/distributed/device_communicators/pynccl.py
@@ -75,6 +75,7 @@ def __init__(
self.available = True
self.disabled = False

self.nccl_version = self.nccl.ncclGetRawVersion()
if self.rank == 0:
logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())

@@ -259,6 +260,12 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
cudaStream_t(stream.cuda_stream),
)

def register_comm_window_raw(self, ptr: int, size: int):
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)

def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)

@contextmanager
def change_state(
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
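The two wrappers above expose NCCL window registration on the communicator; `register_comm_window_raw` forwards the raw device pointer and a fixed `winFlags` of 1. A minimal sketch of how they might be driven, assuming `comm` is an initialized PyNcclCommunicator and `t` is a tensor allocated from the NCCL mem pool introduced below:

```python
# Hedged sketch (not part of the diff): register a symmetric-memory buffer with the
# communicator, run collectives on it, then release the window.
size_bytes = t.numel() * t.element_size()
window = comm.register_comm_window_raw(t.data_ptr(), size_bytes)
try:
    pass  # collectives on `t` can now use NCCL's symmetric-memory kernels
finally:
    comm.deregister_comm_window(window)
```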
133 changes: 133 additions & 0 deletions python/sglang/srt/distributed/device_communicators/pynccl_allocator.py
@@ -0,0 +1,133 @@
import tempfile

import torch
from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator

from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.managers.schedule_batch import global_server_args_dict

nccl_allocator_source = """
#include <nccl.h>
extern "C" {

void* nccl_alloc_plug(size_t size, int device, void* stream) {
void* ptr;
ncclResult_t err = ncclMemAlloc(&ptr, size);
return ptr;

}

void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
ncclResult_t err = ncclMemFree(ptr);
}

}
"""

_allocator = None
_mem_pool = None
_registered_base_addrs = set()
_graph_pool_id = None


def is_symmetric_memory_enabled():
return global_server_args_dict["enable_symm_mem"]


def set_graph_pool_id(graph_pool_id):
global _graph_pool_id
_graph_pool_id = graph_pool_id


def get_nccl_mem_pool():
global _allocator, _mem_pool
if _mem_pool is None:
out_dir = tempfile.gettempdir()
nccl_allocator_libname = "nccl_allocator"
torch.utils.cpp_extension.load_inline(
name=nccl_allocator_libname,
cpp_sources=nccl_allocator_source,
with_cuda=True,
extra_ldflags=["-lnccl"],
verbose=True,
is_python_module=False,
build_directory=out_dir,
)
_allocator = CUDAPluggableAllocator(
f"{out_dir}/{nccl_allocator_libname}.so",
"nccl_alloc_plug",
"nccl_free_plug",
).allocator()
_mem_pool = torch.cuda.MemPool(_allocator)
return _mem_pool


class use_symmetric_memory:
def __init__(self, group_coordinator: GroupCoordinator):
if not is_symmetric_memory_enabled():
self.group_coordinator = None
self._mem_pool_ctx = None
self.is_graph_capture = None
self.device = None
self.pre_2_8_0 = None
else:
self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.device = torch.cuda.current_device()
self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")

def __enter__(self):
if not is_symmetric_memory_enabled():
return self
assert (
self.group_coordinator.pynccl_comm is not None
), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
assert (
self.group_coordinator.pynccl_comm.nccl_version >= 22703
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
if self.is_graph_capture:
assert (
_graph_pool_id is not None
), "graph_pool_id is not set under graph capture"
# Pause graph memory pool to use symmetric memory with cuda graph
if self.pre_2_8_0:
torch._C._cuda_endAllocateCurrentStreamToPool(
self.device, _graph_pool_id
)
else:
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
self._mem_pool_ctx.__enter__()
return self

def tag(self, tensor: torch.Tensor):
if not is_symmetric_memory_enabled():
return
tensor.symmetric_memory = True

def __exit__(self, exc_type, exc_val, exc_tb):
if not is_symmetric_memory_enabled():
return
global _registered_base_addrs
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
for segment in get_nccl_mem_pool().snapshot():
if segment["address"] not in _registered_base_addrs:
if segment["stream"] == 0 and self.pre_2_8_0:
# PyTorch version < 2.8.0 has a multi-thread MemPool bug
# See https://github.com/pytorch/pytorch/issues/152861
# Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
# WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
continue
self.group_coordinator.pynccl_comm.register_comm_window_raw(
segment["address"], segment["total_size"]
)
_registered_base_addrs.add(segment["address"])

if self.is_graph_capture:
if self.pre_2_8_0:
torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
else:
torch._C._cuda_beginAllocateCurrentThreadToPool(
self.device, _graph_pool_id
)
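The context manager above is the pattern the layer changes below follow: run the compute whose output feeds an all-reduce inside `use_symmetric_memory` so the result is allocated from the NCCL mem pool, then tag it. A minimal usage sketch; `compute_partial_output` and `x` are stand-ins for the layer's GEMM/MoE kernel and its input, and an initialized TP group is assumed:

```python
# Usage sketch mirroring linear.py / fused_moe below (stand-in compute function).
from sglang.srt.distributed import parallel_state, tensor_model_parallel_all_reduce
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)

with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
    output_parallel = compute_partial_output(x)  # allocated from the NCCL mem pool
    sm.tag(output_parallel)                      # mark for the symmetric-memory all-reduce
output = tensor_model_parallel_all_reduce(output_parallel)  # takes the pynccl path below
```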
python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py
@@ -67,6 +67,7 @@ def find_nccl_library() -> str:

ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
ncclWindow_t = ctypes.c_void_p


class ncclUniqueId(ctypes.Structure):
@@ -279,6 +280,23 @@ class NCCLLibrary:
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
]

exported_functions_symm_mem = [
# ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags);
Function(
"ncclCommWindowRegister",
ncclResult_t,
[
ncclComm_t,
buffer_type,
ctypes.c_size_t,
ctypes.POINTER(ncclWindow_t),
ctypes.c_int,
],
),
# ncclResult_t ncclCommWindowDeregister(ncclComm_t comm, ncclWindow_t win);
Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
]

# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
@@ -312,7 +330,10 @@ def __init__(self, so_file: Optional[str] = None):

if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {}
for func in NCCLLibrary.exported_functions:
exported_functions = NCCLLibrary.exported_functions
if hasattr(self.lib, "ncclCommWindowRegister"):
exported_functions.extend(NCCLLibrary.exported_functions_symm_mem)
for func in exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
@@ -328,10 +349,14 @@ def NCCL_CHECK(self, result: ncclResult_t) -> None:
error_str = self.ncclGetErrorString(result)
raise RuntimeError(f"NCCL error: {error_str}")

def ncclGetVersion(self) -> str:
def ncclGetRawVersion(self) -> int:
version = ctypes.c_int()
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
version_str = str(version.value)
# something like 21903
return version.value

def ncclGetVersion(self) -> str:
version_str = str(self.ncclGetRawVersion())
# something like 21903 --> "2.19.3"
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
@@ -460,6 +485,20 @@ def ncclBroadcast(
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

def ncclCommWindowRegister(
self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
) -> ncclWindow_t:
window = ncclWindow_t()
self.NCCL_CHECK(
self._funcs["ncclCommWindowRegister"](
comm, buff, size, ctypes.byref(window), win_flags
)
)
return window

def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))


__all__ = [
"NCCLLibrary",
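For reference, `ncclGetRawVersion` returns NCCL's packed integer version (major * 10000 + minor * 100 + patch), which is what the `>= 22703` assertion in pynccl_allocator compares against. A quick decoding sketch:

```python
# Packed NCCL version: 22703 -> 2.27.3 (the symmetric-memory minimum), 21903 -> 2.19.3.
def decode_nccl_version(raw: int) -> str:
    return f"{raw // 10000}.{(raw % 10000) // 100}.{raw % 100}"

assert decode_nccl_version(22703) == "2.27.3"
assert decode_nccl_version(21903) == "2.19.3"
```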
11 changes: 11 additions & 0 deletions python/sglang/srt/distributed/parallel_state.py
@@ -497,6 +497,17 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)

if (
self.pynccl_comm is not None
and hasattr(input_, "symmetric_memory")
and input_.symmetric_memory
):
with self.pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
):
self.pynccl_comm.all_reduce(input_)
return input_

outplace_all_reduce_method = None
if (
self.qr_comm is not None
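The fast path keys off a plain attribute that `use_symmetric_memory.tag()` sets on the tensor, so untagged tensors continue through the existing quick/custom all-reduce backends. A small illustration of the mechanism:

```python
# Illustration only: tag() sets a Python attribute, and all_reduce checks for it.
import torch

def takes_symm_mem_path(t: torch.Tensor) -> bool:
    return hasattr(t, "symmetric_memory") and t.symmetric_memory

x = torch.empty(16)
assert not takes_symm_mem_path(x)  # regular tensors keep the existing all-reduce path
x.symmetric_memory = True          # what sm.tag(x) does
assert takes_symm_mem_path(x)      # now reduced in place through pynccl
```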
5 changes: 3 additions & 2 deletions python/sglang/srt/entrypoints/engine.py
@@ -623,8 +623,9 @@ async def async_score(
def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["NCCL_CUMEM_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
if not server_args.enable_symm_mem:
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
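Net effect of the change above: `NCCL_CUMEM_ENABLE` now follows `enable_symm_mem` (window registration presumably requires cuMem-backed allocations), and `NCCL_NVLS_ENABLE` is only forced when symmetric memory is off, leaving NCCL's default otherwise. A condensed sketch of the branch:

```python
# Hedged restatement of the env-var logic in _set_envs_and_config.
def nccl_env(enable_symm_mem: bool, enable_nccl_nvls: bool) -> dict:
    env = {"NCCL_CUMEM_ENABLE": str(int(enable_symm_mem))}
    if not enable_symm_mem:
        env["NCCL_NVLS_ENABLE"] = str(int(enable_nccl_nvls))
    return env

assert nccl_env(False, False) == {"NCCL_CUMEM_ENABLE": "0", "NCCL_NVLS_ENABLE": "0"}
assert nccl_env(True, False) == {"NCCL_CUMEM_ENABLE": "1"}  # NVLS left to NCCL's default
```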
8 changes: 7 additions & 1 deletion python/sglang/srt/layers/linear.py
@@ -13,10 +13,14 @@
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.parameter import (
BasevLLMParameter,
BlockQuantScaleParameter,
@@ -1292,7 +1296,9 @@ def forward(self, input_, can_fuse_mlp_allreduce=False):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel)
if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
43 changes: 25 additions & 18 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -14,8 +14,12 @@
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.quantization.base_config import (
@@ -626,24 +630,27 @@ def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
)

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
topk_output=topk_output,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
**(
dict(
tp_rank=self.moe_tp_rank,
tp_size=self.moe_tp_size,
ep_rank=self.moe_ep_rank,
ep_size=self.moe_ep_size,
)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
else {}
),
)
with use_symmetric_memory(get_tp_group()) as sm:
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
topk_output=topk_output,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
**(
dict(
tp_rank=self.moe_tp_rank,
tp_size=self.moe_tp_size,
ep_rank=self.moe_ep_rank,
ep_size=self.moe_ep_size,
)
if self.quant_method.__class__.__name__
== "ModelOptNvFp4FusedMoEMethod"
else {}
),
)
sm.tag(final_hidden_states)

if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
8 changes: 7 additions & 1 deletion python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -11,8 +11,12 @@
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.amx_utils import PackWeightMethod
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.parameter import BasevLLMParameter
@@ -464,7 +468,9 @@ def forward(self, input_):
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self, masked_input.long())
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
output_parallel = self.quant_method.embedding(self, masked_input.long())
sm.tag(output_parallel)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
1 change: 1 addition & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -108,6 +108,7 @@
"weight_loader_disable_mmap",
"enable_triton_kernel_moe",
"enable_multimodal",
"enable_symm_mem",
]

# Put some global args for easy access