From a84e7cde99e8aaea80e7956475e9fcfe14f47be1 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Mon, 16 Jun 2025 17:37:06 +0000 Subject: [PATCH 1/8] Add flashinfer cutlass MoE --- docs/references/environment_variables.md | 1 + python/sglang/srt/layers/moe/ep_moe/layer.py | 3 + .../srt/layers/moe/fused_moe_triton/layer.py | 85 ++++++++++++++++--- .../srt/layers/quantization/modelopt_quant.py | 62 +++++++++++++- python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/models/deepseek_v2.py | 5 ++ python/sglang/srt/server_args.py | 16 +++- 7 files changed, 159 insertions(+), 14 deletions(-) diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index 2ce931b0343c..aeb6406175e6 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -60,6 +60,7 @@ SGLang supports various environment variables that can be used to configure its | `SGLANG_ENABLE_FLASHINFER_GEMM` | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs | `false` | | `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs | `false` | | `SGLANG_CUTLASS_MOE` | Use Cutlass FP8 MoE kernel on Blackwell GPUs | `false` | +| `SGLANG_FLASHINFER_MOE` | Use Flashinfer NVFP4 MoE kernel on Blackwell GPUs, otherwise use cutlass kernel | `true` | ## Distributed Computing diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index ac1a831ac9d0..70bc9cefe0d8 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1295,6 +1295,9 @@ def forward_deepgemm_masked( def get_moe_impl_class(): if global_server_args_dict["enable_deepep_moe"]: return DeepEPMoE + if global_server_args_dict["enable_flashinfer_ep_moe"]: + # Must come before enable_ep_moe + return FusedMoE if global_server_args_dict["enable_ep_moe"]: return EPMoE return FusedMoE diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 7cf8de28a48d..f690628e775a 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -314,6 +314,7 @@ def __init__( inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + enable_ep_moe: Optional[bool] = False, ): super().__init__() @@ -324,12 +325,46 @@ def __init__( self.tp_size = ( tp_size if tp_size is not None else get_tensor_model_parallel_world_size() ) + self.tp_rank = get_tensor_model_parallel_rank() + self.num_experts = num_experts + self.reduce_results = reduce_results + self.expert_map = None + self.enable_flashinfer_moe = get_bool_env_var( + "SGLANG_FLASHINFER_MOE", default="True" + ) + if enable_ep_moe: + assert ( + self.enable_flashinfer_moe + ), "FusedMoE only supports EP with --enable-flashinfer-moe" + self.reduce_results = True # combine needed + self.ep_size = self.tp_size + self.ep_rank = self.tp_rank + self.tp_size = 1 + self.tp_rank = 0 + # Create a tensor of size num_experts filled with -1 + self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32) + # Create a expert map for the local experts + local_num_experts = num_experts // self.ep_size + if self.ep_rank < (self.ep_size - 1): + # Each non-last rank gets local_num_experts experts. + self.expert_map[ + self.ep_rank + * local_num_experts : (self.ep_rank + 1) + * local_num_experts + ] = torch.arange(0, local_num_experts, dtype=torch.int32) + else: + # All remaining experts are assigned to the last rank. + local_num_experts = num_experts - self.ep_rank * local_num_experts + self.expert_map[-local_num_experts:] = torch.arange( + 0, local_num_experts, dtype=torch.int32 + ) + else: + self.ep_size = 1 + self.ep_rank = 0 self.routed_scaling_factor = routed_scaling_factor self.top_k = top_k - self.num_experts = num_experts assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size - self.reduce_results = reduce_results self.renormalize = renormalize self.use_grouped_topk = use_grouped_topk if self.use_grouped_topk: @@ -344,7 +379,11 @@ def __init__( self.use_presharded_weights = use_presharded_weights self.inplace = inplace self.no_combine = no_combine - self.local_num_experts = num_experts + self.local_num_experts = ( + torch.sum(self.expert_map != -1) + if self.expert_map is not None + else num_experts + ) if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -356,7 +395,7 @@ def __init__( self.quant_method.create_weights( layer=self, - num_experts=num_experts, + num_experts=self.local_num_experts, hidden_size=hidden_size, # FIXME: figure out which intermediate_size to use intermediate_size=self.intermediate_size_per_partition, @@ -450,12 +489,15 @@ def _load_w13( # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. - if shard_id == "w1": - expert_data = expert_data.narrow(shard_dim, 0, shard_size) # w3, up_proj: Load into second logical weight of w13. + # trtllm cutlass kernel assumes differently + assert shard_id in ("w1", "w3") + switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False) + if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"): + start = shard_size else: - assert shard_id == "w3" - expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + start = 0 + expert_data = expert_data.narrow(shard_dim, start, shard_size) expert_data.copy_(loaded_weight) def _load_w2( @@ -509,6 +551,11 @@ def _load_g_idx( assert shard_id in ("w1", "w3") expert_data.copy_(loaded_weight) + def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: + if self.expert_map is None: + return expert_id + return self.expert_map[expert_id].item() + def weight_loader( self, param: torch.nn.Parameter, @@ -517,6 +564,13 @@ def weight_loader( shard_id: str, expert_id: int, ) -> None: + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) + if expert_id == -1: + return + + # TP rank is set to 0 if EP is enabled + tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank() + # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -541,7 +595,6 @@ def weight_loader( SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} expert_data = param.data[expert_id] - tp_rank = get_tensor_model_parallel_rank() # is_transposed: if the dim to shard the weight # should be flipped. Required by GPTQ, compressed-tensors @@ -549,7 +602,7 @@ def weight_loader( is_transposed = getattr(param, "is_transposed", False) shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] if is_transposed: - shard_dim = ~shard_dim + shard_dim = int(not shard_dim) # Case input scale: input_scale loading is only supported for fp8 if "input_scale" in weight_name: @@ -690,9 +743,19 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, routed_scaling_factor=self.routed_scaling_factor, + **( + dict( + tp_rank=self.tp_rank, + tp_size=self.tp_size, + ep_rank=self.ep_rank, + ep_size=self.ep_size, + ) + if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod" + else {} + ), ) - if self.reduce_results and self.tp_size > 1: + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index fed4d52dc74c..eae8446b0662 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -29,11 +29,17 @@ requantize_with_max_scale, ) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.utils import is_cuda +from sglang.srt.utils import get_bool_env_var, is_cuda if is_cuda(): from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant +try: + from flashinfer import fp4_quantize as fp4_quantize + from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe +except ImportError: + flashinfer_cutlass_fused_moe = None + # Initialize logger for the module logger = logging.getLogger(__name__) @@ -521,6 +527,9 @@ def __init__(self, quant_config: ModelOptFp4Config): " quantization. Please use Blackwell and" " above." ) + self.enable_flashinfer_moe = get_bool_env_var( + "SGLANG_FLASHINFER_MOE", default="True" + ) def create_weights( self, @@ -727,11 +736,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.cutlass_moe_params = CutlassMoEParams( CutlassMoEType.BlockscaledFP4, device, - num_experts=layer.num_experts, + num_experts=layer.num_experts, # global num experts intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n hidden_size=layer.w13_weight.shape[2] * 2, ) # k + @property + def load_up_proj_weight_first(self) -> bool: + # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 + return self.enable_flashinfer_moe + def apply( self, layer: torch.nn.Module, @@ -750,6 +764,10 @@ def apply( inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + ep_rank: Optional[int] = None, + ep_size: Optional[int] = None, + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." @@ -771,6 +789,46 @@ def apply( routed_scaling_factor=routed_scaling_factor, ) + if self.enable_flashinfer_moe: + assert ( + not apply_router_weight_on_input + ), "apply_router_weight_on_input is not supported for Flashinfer" + a1_gs = torch.min(layer.w13_input_scale_quant) + a2_gs = torch.min(layer.w2_input_scale_quant) + w1_blockscale = layer.w13_blockscale_swizzled + w2_blockscale = layer.w2_blockscale_swizzled + g1_alphas = layer.g1_alphas + g2_alphas = layer.g2_alphas + + quant_scales = [ + a1_gs, + w1_blockscale.view(torch.int32), + g1_alphas, + a2_gs, + w2_blockscale.view(torch.int32), + g2_alphas, + ] + # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision + # and fp4 quantized weights loaded from the checkpoint + out_dtype = x.dtype + output = x if inplace else torch.zeros_like(x) + x, x_sf = fp4_quantize(x, a1_gs) + output = flashinfer_cutlass_fused_moe( + x, + topk_ids.to(torch.int), + topk_weights, + layer.w13_weight.view(torch.long), + layer.w2_weight.view(torch.long), + out_dtype, + quant_scales=quant_scales, + input_sf=x_sf, + ep_size=ep_size, + ep_rank=ep_rank, + tp_size=tp_size, + tp_rank=tp_rank, + ) + return output[0] + from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 return cutlass_moe_fp4( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6143c5575ce2..0c3786b04d38 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -85,6 +85,7 @@ "enable_deepep_moe", "deepep_mode", "enable_ep_moe", + "enable_flashinfer_ep_moe", "moe_dense_tp_size", "ep_dispatch_algorithm", "deepep_config", diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 49c5edc602c7..450c84bc2fe5 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -275,6 +275,11 @@ def __init__( if global_server_args_dict["enable_deepep_moe"] else {} ), + **( + dict(enable_ep_moe=True) + if global_server_args_dict["enable_flashinfer_ep_moe"] + else {} + ), ) if config.n_shared_experts is not None and self.num_fused_shared_experts == 0: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 54e92d0bf781..b757465e3a38 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -152,6 +152,7 @@ class ServerArgs: ep_size: int = 1 enable_ep_moe: bool = False enable_deepep_moe: bool = False + enable_flashinfer_ep_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" ep_num_redundant_experts: int = 0 ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None @@ -244,7 +245,15 @@ def __post_init__(self): logger.warning( f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." ) - + if self.enable_flashinfer_ep_moe: + assert ( + self.quantization == "modelopt_fp4" + ), "modelopt_fp4 quantization is required for Flashinfer EP MOE" + self.enable_ep_moe = True + self.ep_size = self.tp_size + logger.warning( + f"FlashInfer EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." + ) # Set missing default values if self.tokenizer_path is None: self.tokenizer_path = self.model_path @@ -1166,6 +1175,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling expert parallelism for moe. The ep size is equal to the tp size.", ) + parser.add_argument( + "--enable-flashinfer-ep-moe", + action="store_true", + help="Enabling expert parallelism for moe using flashinfer backend. The ep size is equal to the tp size. ", + ) parser.add_argument( "--enable-deepep-moe", action="store_true", From 5f6ac95dceca13ef5c4a6d64ca9897033a691cbb Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 20 Jun 2025 18:53:49 +0000 Subject: [PATCH 2/8] Fix accuracy issue and clean up args --- docs/references/environment_variables.md | 1 - python/sglang/srt/layers/moe/ep_moe/layer.py | 4 ++-- .../srt/layers/moe/fused_moe_triton/layer.py | 10 ++++++---- .../srt/layers/quantization/modelopt_quant.py | 4 +--- python/sglang/srt/managers/schedule_batch.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 8 ++++++-- python/sglang/srt/server_args.py | 15 +++++++-------- 7 files changed, 23 insertions(+), 21 deletions(-) diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index aeb6406175e6..2ce931b0343c 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -60,7 +60,6 @@ SGLang supports various environment variables that can be used to configure its | `SGLANG_ENABLE_FLASHINFER_GEMM` | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs | `false` | | `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs | `false` | | `SGLANG_CUTLASS_MOE` | Use Cutlass FP8 MoE kernel on Blackwell GPUs | `false` | -| `SGLANG_FLASHINFER_MOE` | Use Flashinfer NVFP4 MoE kernel on Blackwell GPUs, otherwise use cutlass kernel | `true` | ## Distributed Computing diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 70bc9cefe0d8..38f123247cb9 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1295,8 +1295,8 @@ def forward_deepgemm_masked( def get_moe_impl_class(): if global_server_args_dict["enable_deepep_moe"]: return DeepEPMoE - if global_server_args_dict["enable_flashinfer_ep_moe"]: - # Must come before enable_ep_moe + if global_server_args_dict["enable_flashinfer_moe"]: + # Must come before EPMoE because FusedMoE also supports enable_ep_moe return FusedMoE if global_server_args_dict["enable_ep_moe"]: return EPMoE diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index f690628e775a..e711d0de4136 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -314,6 +314,7 @@ def __init__( inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + enable_flashinfer_moe: Optional[bool] = False, enable_ep_moe: Optional[bool] = False, ): super().__init__() @@ -329,14 +330,13 @@ def __init__( self.num_experts = num_experts self.reduce_results = reduce_results self.expert_map = None - self.enable_flashinfer_moe = get_bool_env_var( - "SGLANG_FLASHINFER_MOE", default="True" - ) + self.enable_flashinfer_moe = enable_flashinfer_moe + if self.enable_flashinfer_moe: + self.reduce_results = True if enable_ep_moe: assert ( self.enable_flashinfer_moe ), "FusedMoE only supports EP with --enable-flashinfer-moe" - self.reduce_results = True # combine needed self.ep_size = self.tp_size self.ep_rank = self.tp_rank self.tp_size = 1 @@ -391,6 +391,8 @@ def __init__( ) else: self.quant_method = quant_config.get_quant_method(self, prefix) + if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod": + self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe assert self.quant_method is not None self.quant_method.create_weights( diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index eae8446b0662..3efb0d5c17dc 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -527,9 +527,7 @@ def __init__(self, quant_config: ModelOptFp4Config): " quantization. Please use Blackwell and" " above." ) - self.enable_flashinfer_moe = get_bool_env_var( - "SGLANG_FLASHINFER_MOE", default="True" - ) + self.enable_flashinfer_moe = False def create_weights( self, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 0c3786b04d38..615c3bc817db 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -85,7 +85,7 @@ "enable_deepep_moe", "deepep_mode", "enable_ep_moe", - "enable_flashinfer_ep_moe", + "enable_flashinfer_moe", "moe_dense_tp_size", "ep_dispatch_algorithm", "deepep_config", diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 450c84bc2fe5..2a468cad6a44 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -275,9 +275,13 @@ def __init__( if global_server_args_dict["enable_deepep_moe"] else {} ), + # Additional args for FusedMoE **( - dict(enable_ep_moe=True) - if global_server_args_dict["enable_flashinfer_ep_moe"] + dict( + enable_flashinfer_moe=True, + enable_ep_moe=global_server_args_dict["enable_ep_moe"], + ) + if global_server_args_dict["enable_flashinfer_moe"] else {} ), ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b757465e3a38..2fcdeb296656 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -152,7 +152,7 @@ class ServerArgs: ep_size: int = 1 enable_ep_moe: bool = False enable_deepep_moe: bool = False - enable_flashinfer_ep_moe: bool = False + enable_flashinfer_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" ep_num_redundant_experts: int = 0 ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None @@ -245,14 +245,13 @@ def __post_init__(self): logger.warning( f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." ) - if self.enable_flashinfer_ep_moe: + if self.enable_flashinfer_moe: assert ( self.quantization == "modelopt_fp4" - ), "modelopt_fp4 quantization is required for Flashinfer EP MOE" - self.enable_ep_moe = True - self.ep_size = self.tp_size + ), "modelopt_fp4 quantization is required for Flashinfer MOE" + self.disable_shared_experts_fusion = True logger.warning( - f"FlashInfer EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." + f"Flashinfer MoE is enabled. Shared expert fusion is disabled." ) # Set missing default values if self.tokenizer_path is None: @@ -1176,9 +1175,9 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Enabling expert parallelism for moe. The ep size is equal to the tp size.", ) parser.add_argument( - "--enable-flashinfer-ep-moe", + "--enable-flashinfer-moe", action="store_true", - help="Enabling expert parallelism for moe using flashinfer backend. The ep size is equal to the tp size. ", + help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe", ) parser.add_argument( "--enable-deepep-moe", From 50830e1bb449c097cc696744e8b265ad1204cbbf Mon Sep 17 00:00:00 2001 From: alcanderian Date: Sun, 22 Jun 2025 07:51:54 +0000 Subject: [PATCH 3/8] [fix] fix tp --- .../srt/layers/moe/fused_moe_triton/layer.py | 31 ++++-------- .../srt/layers/quantization/modelopt_quant.py | 49 +++++++++---------- 2 files changed, 31 insertions(+), 49 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index e711d0de4136..6a82db210ff0 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -328,11 +328,8 @@ def __init__( ) self.tp_rank = get_tensor_model_parallel_rank() self.num_experts = num_experts - self.reduce_results = reduce_results self.expert_map = None self.enable_flashinfer_moe = enable_flashinfer_moe - if self.enable_flashinfer_moe: - self.reduce_results = True if enable_ep_moe: assert ( self.enable_flashinfer_moe @@ -344,27 +341,22 @@ def __init__( # Create a tensor of size num_experts filled with -1 self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32) # Create a expert map for the local experts - local_num_experts = num_experts // self.ep_size - if self.ep_rank < (self.ep_size - 1): - # Each non-last rank gets local_num_experts experts. - self.expert_map[ - self.ep_rank - * local_num_experts : (self.ep_rank + 1) - * local_num_experts - ] = torch.arange(0, local_num_experts, dtype=torch.int32) - else: - # All remaining experts are assigned to the last rank. - local_num_experts = num_experts - self.ep_rank * local_num_experts - self.expert_map[-local_num_experts:] = torch.arange( - 0, local_num_experts, dtype=torch.int32 - ) + assert num_experts % self.ep_size == 0 + self.local_num_experts = num_experts // self.ep_size + self.expert_map[ + self.ep_rank + * self.local_num_experts : (self.ep_rank + 1) + * self.local_num_experts + ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu") else: self.ep_size = 1 self.ep_rank = 0 + self.local_num_experts = num_experts self.routed_scaling_factor = routed_scaling_factor self.top_k = top_k assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results self.renormalize = renormalize self.use_grouped_topk = use_grouped_topk if self.use_grouped_topk: @@ -379,11 +371,6 @@ def __init__( self.use_presharded_weights = use_presharded_weights self.inplace = inplace self.no_combine = no_combine - self.local_num_experts = ( - torch.sum(self.expert_map != -1) - if self.expert_map is not None - else num_experts - ) if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 3efb0d5c17dc..d293a26aff96 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -29,7 +29,7 @@ requantize_with_max_scale, ) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.utils import get_bool_env_var, is_cuda +from sglang.srt.utils import is_cuda, next_power_of_2 if is_cuda(): from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant @@ -681,7 +681,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) - w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) + if self.enable_flashinfer_moe: + w13_input_scale = layer.w13_input_scale.max().to(torch.float32) + else: + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) layer.g1_alphas = Parameter( (w13_input_scale * w13_weight_scale_2).to(torch.float32), requires_grad=False, @@ -707,14 +710,19 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) # GEMM 2 + if self.enable_flashinfer_moe: + w2_input_scale = layer.w2_input_scale.max().to(torch.float32) + else: + w2_input_scale = layer.w2_input_scale + layer.g2_alphas = Parameter( - (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), + (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), requires_grad=False, ) # This is for quantization, so we need to invert it. layer.w2_input_scale_quant = Parameter( - (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False + (1 / w2_input_scale).to(torch.float32), requires_grad=False ) assert ( @@ -769,8 +777,6 @@ def apply( ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." - - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.topk import select_experts topk_weights, topk_ids = select_experts( @@ -791,39 +797,28 @@ def apply( assert ( not apply_router_weight_on_input ), "apply_router_weight_on_input is not supported for Flashinfer" - a1_gs = torch.min(layer.w13_input_scale_quant) - a2_gs = torch.min(layer.w2_input_scale_quant) - w1_blockscale = layer.w13_blockscale_swizzled - w2_blockscale = layer.w2_blockscale_swizzled - g1_alphas = layer.g1_alphas - g2_alphas = layer.g2_alphas - - quant_scales = [ - a1_gs, - w1_blockscale.view(torch.int32), - g1_alphas, - a2_gs, - w2_blockscale.view(torch.int32), - g2_alphas, - ] # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision # and fp4 quantized weights loaded from the checkpoint - out_dtype = x.dtype - output = x if inplace else torch.zeros_like(x) - x, x_sf = fp4_quantize(x, a1_gs) output = flashinfer_cutlass_fused_moe( x, topk_ids.to(torch.int), topk_weights, layer.w13_weight.view(torch.long), layer.w2_weight.view(torch.long), - out_dtype, - quant_scales=quant_scales, - input_sf=x_sf, + x.dtype, + quant_scales=[ + layer.w13_input_scale_quant, + layer.w13_blockscale_swizzled.view(torch.int32), + layer.g1_alphas, + layer.w2_input_scale_quant, + layer.w2_blockscale_swizzled.view(torch.int32), + layer.g2_alphas, + ], ep_size=ep_size, ep_rank=ep_rank, tp_size=tp_size, tp_rank=tp_rank, + tune_max_num_tokens=next_power_of_2(x.shape[0]) ) return output[0] From ad02472d70015c2917afadcc5d2e7aa30a5cd829 Mon Sep 17 00:00:00 2001 From: alcanderian Date: Sun, 22 Jun 2025 08:22:30 +0000 Subject: [PATCH 4/8] [opt] dual stream --- python/sglang/srt/models/deepseek_v2.py | 26 ++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 50f3e68a812e..acc7918f7cc2 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -226,6 +226,7 @@ def __init__( layer_id: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -238,6 +239,7 @@ def __init__( ) self.config = config self.layer_id = layer_id + self.alt_stream = alt_stream if self.tp_size > config.n_routed_experts: raise ValueError( @@ -347,10 +349,31 @@ def forward( self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None ) -> torch.Tensor: if not self._enable_deepep_moe: - return self.forward_normal(hidden_states) + if self.alt_stream is not None and self.num_fused_shared_experts == 0: + return self.forward_normal_dual_stream(hidden_states) + else: + return self.forward_normal(hidden_states) else: return self.forward_deepep(hidden_states, forward_batch) + def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor: + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + shared_output = self._forward_shared_experts(hidden_states) + with torch.cuda.stream(self.alt_stream): + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + if not _is_cuda: + final_hidden_states *= self.routed_scaling_factor + current_stream.wait_stream(self.alt_stream) + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states + def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self._forward_shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) @@ -1455,6 +1478,7 @@ def __init__( quant_config=quant_config, prefix=add_prefix("mlp", prefix), layer_id=self.layer_id, + alt_stream=alt_stream, ) else: if enable_moe_dense_fully_dp(): From e1a8dfb623d3c12ea87b5f12f960cf88ebdb8d3f Mon Sep 17 00:00:00 2001 From: alcanderian Date: Sun, 22 Jun 2025 08:23:37 +0000 Subject: [PATCH 5/8] lint --- python/sglang/srt/layers/quantization/modelopt_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index d293a26aff96..a5cf67a4d571 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -818,7 +818,7 @@ def apply( ep_rank=ep_rank, tp_size=tp_size, tp_rank=tp_rank, - tune_max_num_tokens=next_power_of_2(x.shape[0]) + tune_max_num_tokens=next_power_of_2(x.shape[0]), ) return output[0] From a2d918be1a30ca08652cdf19c98d0a70fd3eeab1 Mon Sep 17 00:00:00 2001 From: alcanderian Date: Sun, 22 Jun 2025 08:47:50 +0000 Subject: [PATCH 6/8] [opt] pdl --- python/sglang/srt/server_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 5dafe0a670d3..a6a2a0b90c0a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -249,6 +249,7 @@ def __post_init__(self): assert ( self.quantization == "modelopt_fp4" ), "modelopt_fp4 quantization is required for Flashinfer MOE" + os.environ["TRTLLM_ENABLE_PDL"] = "1" self.disable_shared_experts_fusion = True logger.warning( f"Flashinfer MoE is enabled. Shared expert fusion is disabled." From 6a8c3865ea096216cb921c8efd3b3f51c1040b69 Mon Sep 17 00:00:00 2001 From: alcanderian Date: Sun, 22 Jun 2025 10:19:14 +0000 Subject: [PATCH 7/8] opt --- python/sglang/srt/layers/quantization/modelopt_quant.py | 5 ++++- python/sglang/srt/models/deepseek_v2.py | 7 ++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index a5cf67a4d571..913a5bb99eea 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -435,6 +435,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.alpha = Parameter( layer.input_scale * layer.weight_scale_2, requires_grad=False ) + layer.input_scale_inv = Parameter( + (1 / input_scale_2).to(torch.float32), requires_grad=False + ) # Pad and blockwise interleave weight_scale scales = layer.weight_scale @@ -473,7 +476,7 @@ def apply( output_shape = [x_m, w_n] # Quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale) + x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv) assert x_fp4.dtype == torch.uint8 assert x_scale_interleaved.dtype == torch.float8_e4m3fn diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index acc7918f7cc2..d8c3fbc7641e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -349,7 +349,12 @@ def forward( self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None ) -> torch.Tensor: if not self._enable_deepep_moe: - if self.alt_stream is not None and self.num_fused_shared_experts == 0: + DUAL_STREAM_TOKEN_THRESHOLD = 2048 + if ( + self.alt_stream is not None + and self.num_fused_shared_experts == 0 + and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD + ): return self.forward_normal_dual_stream(hidden_states) else: return self.forward_normal(hidden_states) From 64459c357e1b76815a31c0d6c8f1906558ba29fa Mon Sep 17 00:00:00 2001 From: alcanderian Date: Sun, 22 Jun 2025 10:55:23 +0000 Subject: [PATCH 8/8] update threshold --- python/sglang/srt/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index d8c3fbc7641e..0f306c28e401 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -349,7 +349,7 @@ def forward( self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None ) -> torch.Tensor: if not self._enable_deepep_moe: - DUAL_STREAM_TOKEN_THRESHOLD = 2048 + DUAL_STREAM_TOKEN_THRESHOLD = 1024 if ( self.alt_stream is not None and self.num_fused_shared_experts == 0