51 changes: 48 additions & 3 deletions python/sglang/srt/distributed/parallel_state.py
@@ -23,6 +23,7 @@
"""
import contextlib
import gc
import importlib.util
import logging
import os
import pickle
@@ -44,7 +45,7 @@
get_bool_env_var,
is_cuda_alike,
is_npu,
supports_custom_op,
run_once, supports_custom_op,
)


@@ -1142,10 +1143,49 @@ def init_distributed_environment(
_WORLD.world_size == torch.distributed.get_world_size()
), "world group already initialized with a different world size"

# Adapted from https://github.com/vllm-project/vllm/blob/f9c069c85e029830094ff9abb926ffbf37b7c7e7/vllm/distributed/parallel_state.py#L940
PPLX_DID_INIT: bool = False

# Adapted from https://github.com/vllm-project/vllm/blob/f9c069c85e029830094ff9abb926ffbf37b7c7e7/vllm/distributed/parallel_state.py#L944
@run_once
def pplx_init(rank, world_size):
has_pplx = importlib.util.find_spec("pplx_kernels") is not None

if has_pplx and world_size > 1:
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id, nvshmem_init)
try:
global PPLX_DID_INIT
logger.debug(
"Initialize NVSHMEM for PPLX kernels: rank=%d, "
"world size=%d", rank, world_size)
uid = nvshmem_get_unique_id(
) if rank == 0 else nvshmem_alloc_empty_unique_id()
uid_gpu = uid.cuda()
get_world_group().broadcast(uid_gpu, src=0)
uid = uid_gpu.to(device='cpu')
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, rank, world_size)
PPLX_DID_INIT = True
except Exception as ex:
logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)
Contributor review comment (medium):

The pplx_init function correctly uses @run_once and checks for pplx_kernels. However, the error handling for NVSHMEM initialization could be more specific. Catching a generic Exception might hide specific issues that could be handled or logged differently.

Consider catching more specific exceptions if known (e.g., RuntimeError from pplx_kernels or torch.cuda.CudaError) or at least logging the type of exception in the error message for better diagnostics.
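
As a minimal sketch of the second suggestion, assuming a hypothetical _try_nvshmem_init wrapper around the init call; the (RuntimeError, OSError) tuple is an assumption about pplx_kernels' failure modes, not documented behavior:

```python
import logging

logger = logging.getLogger(__name__)


def _try_nvshmem_init(init_fn, uid, rank: int, world_size: int) -> bool:
    """Attempt NVSHMEM init; catch narrower errors and log the exception type."""
    try:
        init_fn(uid, rank, world_size)
        return True
    except (RuntimeError, OSError) as ex:  # assumed failure modes; adjust as needed
        logger.error(
            "Failed to initialize NVSHMEM for PPLX (%s): %s",
            type(ex).__name__,
            ex,
        )
        return False
```

Even if the broad catch is kept, logging type(ex).__name__ as above makes the failure class visible in the logs.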


# Adapted from https://github.com/vllm-project/vllm/blob/f9c069c85e029830094ff9abb926ffbf37b7c7e7/vllm/distributed/parallel_state.py#L968
@run_once
def pplx_finalize():
global PPLX_DID_INIT
if PPLX_DID_INIT:
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
from sglang.srt.layers.moe.ep_moe.layer import (
_all_to_all_cache)
_all_to_all_cache.destroy()
nvshmem_finalize()
Contributor review comment (medium):

The statement from sglang.srt.layers.moe.ep_moe.layer import _all_to_all_cache inside pplx_finalize is a local import. While this works, imports are generally preferred at the top of the file for clarity and to avoid potential circular import issues; in this specific @run_once context, though, it may be acceptable to delay the import until first use.

If _all_to_all_cache is lightweight to import or frequently used, consider moving the import to the top. If it is heavy or has initialization dependencies tied to PPLX, keeping it local may be justified, but add a comment explaining why.
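
A short sketch of the module-level-import alternative, assuming the import can be made optional at load time; the ImportError guard, the None sentinel, and the pplx_finalize_sketch name are illustrative, and the sketch only addresses import placement, not the PPLX_DID_INIT bookkeeping:

```python
import logging

logger = logging.getLogger(__name__)

# Guarded module-level import: the cache is only usable when the EP-MoE layer
# (and its optional dependencies) can be imported at load time.
try:
    from sglang.srt.layers.moe.ep_moe.layer import _all_to_all_cache
except ImportError:
    _all_to_all_cache = None


def pplx_finalize_sketch() -> None:
    """Illustrative finalize path relying on the module-level import above."""
    if _all_to_all_cache is None:
        return
    from pplx_kernels.nvshmem import nvshmem_finalize

    logger.debug("PPLX NVSHMEM finalize")
    _all_to_all_cache.destroy()
    nvshmem_finalize()
```

If the layer module has import-time side effects or requires pplx_kernels to be installed, the PR's local import (with an explanatory comment) remains the safer choice.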


def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
enable_pplx_moe: bool = False,
backend: Optional[str] = None,
) -> None:
"""
@@ -1221,10 +1261,13 @@ def initialize_model_parallel(
group_name="pp",
)

if enable_pplx_moe:
pplx_init(get_world_group().local_rank, world_size)

def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
enable_pplx_moe: bool = False,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
@@ -1234,7 +1277,7 @@ def ensure_model_parallel_initialized(
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(
tensor_model_parallel_size, pipeline_model_parallel_size, backend
tensor_model_parallel_size, pipeline_model_parallel_size, enable_pplx_moe, backend
)
return

@@ -1293,10 +1336,12 @@ def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return get_tp_group().rank_in_group


def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP

pplx_finalize()

if _TP:
_TP.destroy()
_TP = None
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/communicator.py
@@ -103,7 +103,7 @@ def _compute_mlp_mode(cls, context: _LayerModeComputationContext):
if context.is_layer_sparse:
return (
ScatterMode.SCATTERED
if global_server_args_dict["enable_deepep_moe"]
if global_server_args_dict["enable_deepep_moe"] or global_server_args_dict["enable_pplx_moe"]
else ScatterMode.FULL
)
else: