Commit 5962e70

FlashInfer NVFP4 MoE with EP & 2-stream shared expert (#7327)
Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com>
Co-authored-by: alcanderian <alcanderian@gmail.com>
1 parent edc21cc commit 5962e70

6 files changed, +182 −20 lines changed

python/sglang/srt/layers/moe/ep_moe/layer.py

Lines changed: 3 additions & 0 deletions
@@ -1295,6 +1295,9 @@ def forward_deepgemm_masked(
 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
+    if global_server_args_dict["enable_flashinfer_moe"]:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FusedMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
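The inline comment above is the key constraint: when both enable_flashinfer_moe and enable_ep_moe are set, dispatch must still resolve to FusedMoE, which handles expert parallelism itself. A minimal sketch of that precedence (not part of the commit; a plain dict stands in for global_server_args_dict and strings stand in for the classes):

def pick_moe_impl(args: dict) -> str:
    # Same ordering as get_moe_impl_class above.
    if args.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if args.get("enable_flashinfer_moe"):
        # Checked before enable_ep_moe, so the FlashInfer-backed FusedMoE wins.
        return "FusedMoE"
    if args.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"

# Both flags set: FusedMoE (FlashInfer path with internal EP), not EPMoE.
assert pick_moe_impl({"enable_flashinfer_moe": True, "enable_ep_moe": True}) == "FusedMoE"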

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 62 additions & 10 deletions
@@ -314,6 +314,8 @@ def __init__(
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        enable_flashinfer_moe: Optional[bool] = False,
+        enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()

@@ -324,9 +326,34 @@ def __init__(
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.num_experts = num_experts
+        self.expert_map = None
+        self.enable_flashinfer_moe = enable_flashinfer_moe
+        if enable_ep_moe:
+            assert (
+                self.enable_flashinfer_moe
+            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
+            self.ep_size = self.tp_size
+            self.ep_rank = self.tp_rank
+            self.tp_size = 1
+            self.tp_rank = 0
+            # Create a tensor of size num_experts filled with -1
+            self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
+            # Create a expert map for the local experts
+            assert num_experts % self.ep_size == 0
+            self.local_num_experts = num_experts // self.ep_size
+            self.expert_map[
+                self.ep_rank
+                * self.local_num_experts : (self.ep_rank + 1)
+                * self.local_num_experts
+            ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
+        else:
+            self.ep_size = 1
+            self.ep_rank = 0
+            self.local_num_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
-        self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results

@@ -344,19 +371,20 @@ def __init__(
         self.use_presharded_weights = use_presharded_weights
         self.inplace = inplace
         self.no_combine = no_combine
-        self.local_num_experts = num_experts

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
                 UnquantizedFusedMoEMethod()
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
+            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
+                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None

         self.quant_method.create_weights(
             layer=self,
-            num_experts=num_experts,
+            num_experts=self.local_num_experts,
             hidden_size=hidden_size,
             # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,

@@ -450,12 +478,15 @@ def _load_w13(

         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
         # w3, up_proj: Load into second logical weight of w13.
+        # trtllm cutlass kernel assumes differently
+        assert shard_id in ("w1", "w3")
+        switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
+        if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
+            start = shard_size
         else:
-            assert shard_id == "w3"
-            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+            start = 0
+        expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)

     def _load_w2(

@@ -509,6 +540,11 @@ def _load_g_idx(
         assert shard_id in ("w1", "w3")
         expert_data.copy_(loaded_weight)

+    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
+        if self.expert_map is None:
+            return expert_id
+        return self.expert_map[expert_id].item()
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,

@@ -517,6 +553,13 @@ def weight_loader(
         shard_id: str,
         expert_id: int,
     ) -> None:
+        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+        if expert_id == -1:
+            return
+
+        # TP rank is set to 0 if EP is enabled
+        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality

@@ -541,15 +584,14 @@ def weight_loader(
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

         expert_data = param.data[expert_id]
-        tp_rank = get_tensor_model_parallel_rank()

         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
         # should be whatever dimension intermediate_size is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim = ~shard_dim
+            shard_dim = int(not shard_dim)

         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:

@@ -690,9 +732,19 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
             routed_scaling_factor=self.routed_scaling_factor,
+            **(
+                dict(
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                    ep_rank=self.ep_rank,
+                    ep_size=self.ep_size,
+                )
+                if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
+                else {}
+            ),
         )

-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states
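The expert_map built in __init__ above is what makes the EP path work: a global expert id maps to a local slot on the owning EP rank and to -1 everywhere else, so weight_loader can silently skip experts that do not live on the current rank. A standalone torch-only sketch of that mapping (the function name and example sizes are illustrative, not taken from the commit):

import torch

def build_expert_map(num_experts: int, ep_size: int, ep_rank: int) -> torch.Tensor:
    # Mirrors the construction in FusedMoE.__init__: -1 means "not on this rank".
    assert num_experts % ep_size == 0
    local_num_experts = num_experts // ep_size
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    expert_map[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts] = torch.arange(
        0, local_num_experts, dtype=torch.int32
    )
    return expert_map

# 8 global experts over 4 EP ranks; rank 1 owns global experts 2 and 3.
expert_map = build_expert_map(num_experts=8, ep_size=4, ep_rank=1)
print(expert_map.tolist())   # [-1, -1, 0, 1, -1, -1, -1, -1]
print(expert_map[3].item())  # global expert 3 -> local slot 1
print(expert_map[5].item())  # -1 -> weight_loader returns without loading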

python/sglang/srt/layers/quantization/modelopt_quant.py

Lines changed: 62 additions & 8 deletions
@@ -29,11 +29,17 @@
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, next_power_of_2

 if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

+try:
+    from flashinfer import fp4_quantize as fp4_quantize
+    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
+except ImportError:
+    flashinfer_cutlass_fused_moe = None
+
 # Initialize logger for the module
 logger = logging.getLogger(__name__)

@@ -429,6 +435,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.alpha = Parameter(
             layer.input_scale * layer.weight_scale_2, requires_grad=False
         )
+        layer.input_scale_inv = Parameter(
+            (1 / input_scale_2).to(torch.float32), requires_grad=False
+        )

         # Pad and blockwise interleave weight_scale
         scales = layer.weight_scale

@@ -467,7 +476,7 @@ def apply(
         output_shape = [x_m, w_n]

         # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)
+        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)

         assert x_fp4.dtype == torch.uint8
         assert x_scale_interleaved.dtype == torch.float8_e4m3fn

@@ -521,6 +530,7 @@ def __init__(self, quant_config: ModelOptFp4Config):
                 " quantization. Please use Blackwell and"
                 " above."
             )
+        self.enable_flashinfer_moe = False

     def create_weights(
         self,

@@ -674,7 +684,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

-        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+        if self.enable_flashinfer_moe:
+            w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+        else:
+            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,

@@ -700,14 +713,19 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)

         # GEMM 2
+        if self.enable_flashinfer_moe:
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        else:
+            w2_input_scale = layer.w2_input_scale
+
         layer.g2_alphas = Parameter(
-            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )

         # This is for quantization, so we need to invert it.
         layer.w2_input_scale_quant = Parameter(
-            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )

         assert (

@@ -727,11 +745,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.cutlass_moe_params = CutlassMoEParams(
             CutlassMoEType.BlockscaledFP4,
             device,
-            num_experts=layer.num_experts,
+            num_experts=layer.num_experts,  # global num experts
             intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
             hidden_size=layer.w13_weight.shape[2] * 2,
         )  # k

+    @property
+    def load_up_proj_weight_first(self) -> bool:
+        # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
+        return self.enable_flashinfer_moe
+
     def apply(
         self,
         layer: torch.nn.Module,

@@ -750,11 +773,13 @@ def apply(
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        ep_rank: Optional[int] = None,
+        ep_size: Optional[int] = None,
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> torch.Tensor:

         assert activation == "silu", "Only SiLU activation is supported."
-
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts

         topk_weights, topk_ids = select_experts(

@@ -771,6 +796,35 @@
             routed_scaling_factor=routed_scaling_factor,
         )

+        if self.enable_flashinfer_moe:
+            assert (
+                not apply_router_weight_on_input
+            ), "apply_router_weight_on_input is not supported for Flashinfer"
+            # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
+            # and fp4 quantized weights loaded from the checkpoint
+            output = flashinfer_cutlass_fused_moe(
+                x,
+                topk_ids.to(torch.int),
+                topk_weights,
+                layer.w13_weight.view(torch.long),
+                layer.w2_weight.view(torch.long),
+                x.dtype,
+                quant_scales=[
+                    layer.w13_input_scale_quant,
+                    layer.w13_blockscale_swizzled.view(torch.int32),
+                    layer.g1_alphas,
+                    layer.w2_input_scale_quant,
+                    layer.w2_blockscale_swizzled.view(torch.int32),
+                    layer.g2_alphas,
+                ],
+                ep_size=ep_size,
+                ep_rank=ep_rank,
+                tp_size=tp_size,
+                tp_rank=tp_rank,
+                tune_max_num_tokens=next_power_of_2(x.shape[0]),
+            )
+            return output[0]
+
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

         return cutlass_moe_fp4(
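The FlashInfer branch collapses the per-expert activation scales to a single global scalar (.max() over the whole tensor), while the existing CUTLASS path keeps one scale per expert (.max(dim=1)). A torch-only illustration of the two reductions; the example shape [num_experts, 2] is assumed here for illustration, not taken from the commit:

import torch

# Pretend per-expert input scales for the two w13 shards of three experts.
w13_input_scale = torch.tensor([[0.9, 1.1],
                                [1.4, 1.0],
                                [0.7, 0.8]])

# Existing CUTLASS path: one scale per expert.
per_expert_scale = w13_input_scale.max(dim=1).values.to(torch.float32)  # tensor([1.1, 1.4, 0.8])

# FlashInfer path: one scalar shared by all experts on this rank.
global_scale = w13_input_scale.max().to(torch.float32)                  # tensor(1.4)

print(per_expert_scale.tolist(), global_scale.item())

Separately, the FlashInfer call rounds tune_max_num_tokens up with next_power_of_2(x.shape[0]), presumably so kernel tuning buckets batch sizes by powers of two rather than tuning per exact token count.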

python/sglang/srt/managers/schedule_batch.py

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
+    "enable_flashinfer_moe",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
     "deepep_config",

python/sglang/srt/models/deepseek_v2.py

Lines changed: 39 additions & 1 deletion
@@ -226,6 +226,7 @@ def __init__(
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()

@@ -238,6 +239,7 @@ def __init__(
         )
         self.config = config
         self.layer_id = layer_id
+        self.alt_stream = alt_stream

         if self.tp_size > config.n_routed_experts:
             raise ValueError(

@@ -275,6 +277,15 @@ def __init__(
                 if global_server_args_dict["enable_deepep_moe"]
                 else {}
             ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )

         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:

@@ -338,10 +349,36 @@ def forward(
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
-            return self.forward_normal(hidden_states)
+            DUAL_STREAM_TOKEN_THRESHOLD = 1024
+            if (
+                self.alt_stream is not None
+                and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            ):
+                return self.forward_normal_dual_stream(hidden_states)
+            else:
+                return self.forward_normal(hidden_states)
         else:
             return self.forward_deepep(hidden_states, forward_batch)

+    def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+        with torch.cuda.stream(self.alt_stream):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+            if not _is_cuda:
+                final_hidden_states *= self.routed_scaling_factor
+        current_stream.wait_stream(self.alt_stream)
+        final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)

@@ -1446,6 +1483,7 @@ def __init__(
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
+                alt_stream=alt_stream,
             )
         else:
             if enable_moe_dense_fully_dp():
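forward_normal_dual_stream overlaps the shared-expert MLP (left on the current stream) with routing plus the routed experts (pushed onto alt_stream), then joins the streams before summing the two branches; the fast path is only taken for batches of at most 1024 tokens when no shared experts are fused into the routed ones. A generic two-stream sketch of the same pattern, with nn.Linear placeholders standing in for the shared experts, the gate, and the routed experts (none of these module names come from the commit):

import torch
import torch.nn as nn

class DualStreamBlock(nn.Module):
    def __init__(self, hidden: int, alt_stream: torch.cuda.Stream):
        super().__init__()
        self.shared_experts = nn.Linear(hidden, hidden)  # placeholder for the shared-expert MLP
        self.gate = nn.Linear(hidden, 8)                 # placeholder for the router
        self.experts = nn.Linear(hidden, hidden)         # placeholder for the routed experts
        self.alt_stream = alt_stream

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        current = torch.cuda.current_stream()
        self.alt_stream.wait_stream(current)      # alt_stream must see x before reading it
        shared = self.shared_experts(x)           # runs on the current stream
        with torch.cuda.stream(self.alt_stream):  # routing + experts run concurrently
            router_logits = self.gate(x)
            routed = self.experts(x)              # a real MoE would also consume router_logits
        current.wait_stream(self.alt_stream)      # join before combining the branches
        return routed + shared

if torch.cuda.is_available():
    block = DualStreamBlock(64, torch.cuda.Stream()).cuda()
    out = block(torch.randn(16, 64, device="cuda"))
    print(out.shape)  # torch.Size([16, 64])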
