Closed
Changes from 3 commits
95 commits
1e53e23
support w4a8 low latency deepep
ayrnb Jul 23, 2025
93cb396
clean code
ayrnb Jul 24, 2025
31d01f9
clean code
ayrnb Jul 24, 2025
157b979
clean code
ayrnb Jul 24, 2025
5dd0f87
[bug] fix pd completion protocol for batching support (#8317)
slin1237 Jul 24, 2025
f6e07f2
[router] fix pd model completion request (#8303)
slin1237 Jul 24, 2025
bfb118c
fix bug when eos_ids==0 (#8315)
bzantium Jul 24, 2025
2f86f3a
[router] add endpoint unit test (#8298)
slin1237 Jul 24, 2025
a167fd0
[code style] Clean dead triton kernel code in fused_moe and useless v…
BBuf Jul 24, 2025
96c5d85
fix
ayrnb Jul 24, 2025
0090240
fix
ayrnb Jul 24, 2025
8d1c5b9
chore: upgrade flashinfer v0.2.9rc1 (#8301)
Swipe4057 Jul 24, 2025
33c4b4d
[router] add streaming unit test (#8299)
slin1237 Jul 24, 2025
39fe1e8
[router] add request format unit test (#8300)
slin1237 Jul 24, 2025
145482f
HiCache Storage TP Refinement (#8307)
xiezhq-hermann Jul 25, 2025
d40846d
breakdown kernel update (#8334)
xiezhq-hermann Jul 25, 2025
f4674df
support idle batch for TBO (#8233)
sherry-1001 Jul 25, 2025
28d4d47
[Feature] Integrate quick allreduce and select the best allreduce imp…
lihaoyang-amd Jul 25, 2025
c0fb25e
DP Enhancement (#8280)
ch-wan Jul 25, 2025
7ad6b76
fix: Fix failed functional tests https://github.com/meta-llama/llama-…
ynwang007 Jul 25, 2025
af4b9ba
[AMD] Add silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, and gelu_qui…
hubertlu-tw Jul 25, 2025
15d2759
[CPU] Add tutorial docs for SGL on CPU (#8000)
ZailiWang Jul 25, 2025
70e37b9
chore: upgrade mooncake 0.3.5 (#8341)
ShangmingCai Jul 25, 2025
9045cc1
[torch.compile bug] avoid biased_grouped_topk_impl func repeatedly tr…
BBuf Jul 25, 2025
1b9cea5
[P/D] Support ipv6 in P/D scenario (#7858)
thefacetakt Jul 25, 2025
12cb760
Add H20-3e fused MoE kernel tuning configs for Qwen3-Coder-480B-A35B-…
Xu-Wenqing Jul 25, 2025
f8260f2
[Bugfix][Feat] Add XML-ish grammar in EBNFComposer and fix misc bugs …
CatherineSue Jul 25, 2025
ed2e313
Clean up server_args, triton cache manager (#8332)
merrymercy Jul 25, 2025
7181ec8
fix: upgrade nccl version (#8359)
zhyncs Jul 25, 2025
d8ee156
[Feat] Add reasoning parser for Qwen/Qwen3-235B-A22B-Thinking-2507 (#…
CatherineSue Jul 25, 2025
f8ca236
fix: kimi k2 xgrammar crash (#8367)
zhyncs Jul 25, 2025
58c468f
Fix FP4 MoE accuracy from missing routed_scaling_factor (#8333)
trevor-m Jul 25, 2025
3ec0b21
[CI] Fix flaky threshold (#8370)
merrymercy Jul 25, 2025
2272c2a
chore: bump v0.4.9.post4 (#8305)
zhyncs Jul 26, 2025
8af145b
Fix test_moe_fused_gate_combined sgl-kernel ci test (#8374)
ispobock Jul 26, 2025
e6312d2
Update Dockerfile.gb200 to latest sglang (#8356)
kyleliang-nv Jul 26, 2025
4fa44d6
chore: improve mmmu benchmark (#7000)
mickqian Jul 26, 2025
e236d8f
Save peak memory in logits processor (#8343)
ch-wan Jul 26, 2025
ce32bc2
Extract update_weights from RL Engine to SGLang to keep simplicity an…
hebiao064 Jul 26, 2025
5347567
chore: improvements on mm_utils (#7737)
mickqian Jul 26, 2025
3212c2a
vlm: optimize tensor transport (#6003)
mickqian Jul 26, 2025
da0c026
Tiny assert EPLB is used together with expert parallel (#8381)
fzyzcjy Jul 26, 2025
b7094a5
model: support intern-s1 (#8350)
RunningLeon Jul 26, 2025
5c705b1
Add perf tests for LoRA (#8314)
lifuhuang Jul 26, 2025
7615463
Remove slot usage in code to be backward-compatible with python 3.9 (…
lifuhuang Jul 27, 2025
62a6b7c
Add docker release flow for gb200 (#8394)
kyleliang-nv Jul 27, 2025
528bd1e
HiCache, check before terminate prefetching (#8372)
xiezhq-hermann Jul 27, 2025
426b749
Add nvfp4 scaled mm benchmark. (#8401)
HydraQYH Jul 27, 2025
b602f42
Urgent Fix: intern-s1 chat-template matching (#8403)
JustinTong0323 Jul 27, 2025
ed0fdbf
Tool to dump and compare internal activation tensors (#7976)
fzyzcjy Jul 27, 2025
62222bd
Minor tool for comparison of benchmark results (#7974)
fzyzcjy Jul 27, 2025
e34cf6a
Fix bench script making input data on L2 cache (#7739)
fzyzcjy Jul 27, 2025
85486b6
[NVIDIA] Add Flashinfer MoE blockscale fp8 backend (#8036)
kaixih Jul 27, 2025
91e3d15
Update Cutlass in sgl-kernel to v4.1 (#8392)
Fridge003 Jul 27, 2025
0bcc195
fix: minor fix TransportProxyTensor under tp (#8382)
mickqian Jul 27, 2025
2ab9702
[router] add different policies for p node and d node (#8395)
slin1237 Jul 27, 2025
2a1936d
Add A800 fused MoE kernel tuning configs for Qwen3-Coder-480B-A35B-In…
lambert0312 Jul 27, 2025
36d6f0b
fix: fix the missing metrics on non-rank0 nodes (#7720)
acelyc111 Jul 27, 2025
bf0f448
[2/N] MoE Refactor: Unify weight loader and quant methods (#8397)
ch-wan Jul 27, 2025
5c9c275
Use FlashInfer FP4 gemm. (#8241)
elfiegg Jul 27, 2025
44d600c
Support precomputed_embeddings for Llama 4 (#8156)
AlienKevin Jul 27, 2025
4d921f2
[hotfix] fix merge conflicts in FlashInferEPMoE (#8405)
ch-wan Jul 27, 2025
bf3352c
chore: update CODEOWNERS (#8407)
zhyncs Jul 27, 2025
10ee895
chore: upgrade flashinfer v0.2.9rc2 (#8406)
zhyncs Jul 27, 2025
b3eac16
Support triton kernels v3.4.0 for fused_moe (#8258)
yuan-luo Jul 27, 2025
22e00ee
[Bugfix] Prevent PD server crash from invalid grammar (#8062)
ShangmingCai Jul 27, 2025
95217a9
Change to use native arm runner (#8414)
kyleliang-nv Jul 27, 2025
df90645
Support overlapped lora updates (#8213)
lifuhuang Jul 27, 2025
b58c3c2
Support ue8m0 for triton quant kernel (#7603)
fzyzcjy Jul 27, 2025
e983d66
Fix: Improve test_openai_function_calling unit test and fix reasoning…
byjiang1996 Jul 27, 2025
b47eda3
bugfix: Fix multiple finish_reason chunks and tool_calls finish reaso…
CatherineSue Jul 27, 2025
58dd95f
Fix test_openai_server (#8419)
CatherineSue Jul 27, 2025
bb81dae
Fix docker buildx push error (#8425)
kyleliang-nv Jul 28, 2025
dd487e5
bugfix: Fix XGrammar backend to use model's EOS tokens for constraine…
CatherineSue Jul 28, 2025
fe6a445
[router] improve router logs and request id header (#8415)
slin1237 Jul 28, 2025
2810338
[feat] Support different attention backends for prefill and decode (…
Qiaolin-Yu Jul 28, 2025
4ad9737
chore: bump transformer to 4.54.0 (#8416)
hebiao064 Jul 28, 2025
2fd5c70
[PD] Fix abort_request for PD disaggregation (#8352)
ShangmingCai Jul 28, 2025
6d6a8bc
GLM-4.5 Model Support (#8224)
zRzRzRzRzRzRzR Jul 28, 2025
5922c0c
Remove zstd compression for building Dockerfile.gb200 (#8442)
kyleliang-nv Jul 28, 2025
484d0e0
doc: add bench_one_batch_server in the benchmark doc (#8441)
Qiaolin-Yu Jul 28, 2025
581e7dc
GLM-4.5 Model Support Follow-up (#8445)
byjiang1996 Jul 28, 2025
25f73c6
fix GLM4_MOE launch with compressed_tensor quant model (#8456)
zminglei Jul 28, 2025
fb4ce17
Fix per_token_group_quant_8bit when hidden_dim // group_size is not d…
strgrb Jul 28, 2025
2262369
Revert "[kernel] opt moe align block kernel by block/warp scan algori…
BBuf Jul 28, 2025
45bc170
chore: bump v0.4.9.post5 (#8458)
zhyncs Jul 28, 2025
a9dd3ec
fix:reorder topk experts to ensure shared expert replaces minimal sco…
erictanjn Jul 28, 2025
712877a
support w4a8 low latency deepep
ayrnb Jul 23, 2025
77351b7
clean code
ayrnb Jul 24, 2025
c15e34a
clean code
ayrnb Jul 24, 2025
f770ea6
clean code
ayrnb Jul 24, 2025
cfe7d62
fix
ayrnb Jul 24, 2025
d2afdb4
fix
ayrnb Jul 24, 2025
eb39568
Merge branch 'feat/w4a8_support_ll_deepep' of github.com:bytedance-ia…
ayrnb Jul 28, 2025
1e721d4
support cudagraph
ayrnb Jul 28, 2025
149 changes: 98 additions & 51 deletions python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -11,6 +11,7 @@
)

from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_ll_get_cutlass_w4a8_moe_mm_data,
post_reorder_triton_kernel,
pre_reorder_triton_kernel_for_cutlass_moe,
run_cutlass_moe_ep_preproess,
@@ -43,6 +44,7 @@ def cutlass_w4a8_moe(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
ep_mode: str = "ep",
) -> torch.Tensor:
"""
This function computes a w4a8-quantized Mixture of Experts (MoE) layer
@@ -83,10 +85,14 @@
Returns:
- torch.Tensor: The fp8 output tensor after applying the MoE layer.
"""
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
assert (
topk_weights.shape == topk_ids_.shape if topk_weights is not None else True
), "topk shape mismatch"
assert w1_q.dtype == torch.int8
assert w2_q.dtype == torch.int8
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
assert (
a.shape[1] // 2 == w1_q.shape[2] if ep_mode != "deepep_ll" else True
), "Hidden size mismatch w1"
assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
@@ -108,52 +114,79 @@ def cutlass_w4a8_moe(
m = a.size(0)
k = w1_q.size(2) * 2 # w1_q is transposed and packed
n = w2_q.size(2) * 2 # w2_q is transposed and packed
topk = topk_ids_.size(1)
topk = topk_ids_.size(1) if ep_mode == "ep" else 1

if apply_router_weight_on_input:
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"

device = a.device

_, src2dst, _ = run_cutlass_moe_ep_preproess(
local_topk_ids,
num_experts,
)
if ep_mode == "ep":
_, src2dst, _ = run_cutlass_moe_ep_preproess(
local_topk_ids,
num_experts,
)

gateup_input = torch.empty(
(m * topk, k),
device=device,
dtype=torch.float8_e4m3fn,
)
gateup_input = torch.empty(
(m * topk, k),
device=device,
dtype=torch.float8_e4m3fn,
)

pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
a,
gateup_input,
src2dst,
local_topk_ids,
a1_scale,
total_num_experts,
topk,
k,
BLOCK_SIZE=512,
)
pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
a,
gateup_input,
src2dst,
local_topk_ids,
a1_scale,
total_num_experts,
topk,
k,
BLOCK_SIZE=512,
)
elif ep_mode == "deepep_ll":
num_tokens = a.size(1)

else:
raise ValueError(f"Invalid ep_mode: {ep_mode}")

# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
# they are kept to allow for a quick switch of the permutation logic
# from the current triton kernel implementation to the cutlass-based one if needed.
a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
get_cutlass_w4a8_moe_mm_data(
local_topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
num_experts,
n,
k,
)
if ep_mode == "deepep_ll":
gateup_input_origin, expert_offsets, problem_sizes1, problem_sizes2 = (
deepep_ll_get_cutlass_w4a8_moe_mm_data(
a,
local_topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
num_experts,
n,
k,
)
)
gateup_input = torch.empty(
gateup_input_origin.shape, dtype=torch.float8_e4m3fn, device=device
)
sgl_per_tensor_quant_fp8(
gateup_input_origin, gateup_input, a1_scale.float(), True
)

else:
a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
get_cutlass_w4a8_moe_mm_data(
local_topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
num_experts,
n,
k,
)

c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
@@ -197,19 +230,33 @@ def cutlass_w4a8_moe(
128,
topk,
)

output = torch.empty_like(a)
post_reorder_triton_kernel[(m,)](
c2,
output,
src2dst,
topk_ids_,
topk_weights,
start_expert_id,
end_expert_id,
topk,
k,
0,
BLOCK_SIZE=512,
)
if ep_mode == "ep":
output = torch.empty_like(a)
post_reorder_triton_kernel[(m,)](
c2,
output,
src2dst,
local_topk_ids,
topk_weights,
start_expert_id,
end_expert_id,
topk,
k,
0,
BLOCK_SIZE=512,
)
elif ep_mode == "deepep_ll":
output = torch.zeros(
(len(local_topk_ids), num_tokens, k), device=device, dtype=c2.dtype
)
non_zero_indices = torch.nonzero(local_topk_ids, as_tuple=True)[0]
c2_index = 0
for expert_idx in non_zero_indices:
num_non_zero_rows = local_topk_ids[expert_idx].item()
output[expert_idx, :num_non_zero_rows] = c2[
c2_index : c2_index + num_non_zero_rows
]
c2_index += num_non_zero_rows
Comment on lines +258 to +263 (Contributor review, severity: medium)

This Python loop iterates over the active experts to scatter the results. For performance-critical code running on a GPU, it can become a bottleneck due to the overhead of launching many small operations from a Python loop. Consider vectorizing this operation or using a custom kernel for a more efficient implementation.
else:
output = c2
Comment on lines +264 to +265 (Contributor review, severity: medium)

This else block appears to be unreachable. The ep_mode is validated earlier in the function, and an unknown ep_mode raises a ValueError before execution can reach this point, making the else block dead code. Please remove it to improve code clarity and maintainability.
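For illustration, a tiny generic sketch of the validate-early pattern the reviewer is pointing at (hypothetical code, not from the PR):

```python
def dispatch(ep_mode: str) -> str:
    # Validate once, up front; the later branches can then be exhaustive
    # without a defensive (and unreachable) trailing `else`.
    if ep_mode not in ("ep", "deepep_ll"):
        raise ValueError(f"Invalid ep_mode: {ep_mode}")
    if ep_mode == "ep":
        return "post-reorder path"
    else:  # must be "deepep_ll" after the check above
        return "masked-scatter path"
```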

return output
72 changes: 72 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -1319,3 +1319,75 @@ def moe_ep_deepgemm_preprocess(
gateup_input,
gateup_input_scale,
)


@triton.jit
def compute_problem_sizes_w4a8_kernel(
masked_m_ptr,
problem_sizes1_ptr,
problem_sizes2_ptr,
n,
k,
num_experts,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = pid < num_experts
final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)
tl.store(problem_sizes1_ptr + pid * 3, 2 * n)
tl.store(problem_sizes1_ptr + pid * 3 + 1, final_occurrences)
tl.store(problem_sizes1_ptr + pid * 3 + 2, k)
tl.store(problem_sizes2_ptr + pid * 3, k)
tl.store(problem_sizes2_ptr + pid * 3 + 1, final_occurrences)
tl.store(problem_sizes2_ptr + pid * 3 + 2, n)


def compute_problem_sizes_w4a8(
masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
):
BLOCK_SIZE = 256
grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
compute_problem_sizes_w4a8_kernel[grid](
masked_m,
problem_sizes1,
problem_sizes2,
n,
k,
num_experts,
BLOCK_SIZE=BLOCK_SIZE,
)
return problem_sizes1, problem_sizes2


def deepep_ll_get_cutlass_w4a8_moe_mm_data(
hidden_states,
masked_m,
expert_offsets,
problem_sizes1,
problem_sizes2,
num_experts,
n,
k,
):
problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
)
masked_m_with_zero = torch.cat(
[torch.tensor([0], device=masked_m.device, dtype=masked_m.dtype), masked_m],
dim=0,
)
expert_offsets = torch.cumsum(masked_m_with_zero, dim=0)
expert_indices = torch.nonzero(masked_m, as_tuple=True)[0]
hidden_states_real = hidden_states[expert_indices]
hidden_states_real_reshaped = hidden_states_real.view(
-1, hidden_states_real.size(-1)
)
non_zero_rows_mask = (hidden_states_real_reshaped != 0).any(dim=-1)
hidden_states_real_reshaped = hidden_states_real_reshaped[non_zero_rows_mask]
logger.info(f"masked_m {masked_m}")
return (
hidden_states_real_reshaped,
expert_offsets.to(torch.int32),
problem_sizes1.to(torch.int32),
problem_sizes2.to(torch.int32),
)
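For readers tracing the new helper, the sketch below reproduces its offset and packing logic in isolation; the sizes and values are illustrative assumptions, not the library API:

```python
import torch

# Illustrative sizes: 4 local experts, DeepEP slot capacity 8, hidden dim 16.
num_experts, capacity, hidden = 4, 8, 16
masked_m = torch.tensor([3, 0, 5, 2])          # valid rows per expert
hidden_states = torch.zeros(num_experts, capacity, hidden)
for e, m in enumerate(masked_m.tolist()):      # fill only the valid rows
    hidden_states[e, :m] = torch.randn(m, hidden)

# Exclusive prefix sum over masked_m -> start row of each expert
# in the packed activation matrix.
expert_offsets = torch.cumsum(
    torch.cat([masked_m.new_zeros(1), masked_m]), dim=0
)                                              # tensor([ 0,  3,  3,  8, 10])

# Keep experts with work, flatten, then drop the all-zero padding rows.
active = torch.nonzero(masked_m, as_tuple=True)[0]
flat = hidden_states[active].reshape(-1, hidden)
packed = flat[(flat != 0).any(dim=-1)]         # shape: (masked_m.sum(), hidden)

assert packed.size(0) == int(masked_m.sum())
```

Note that padding rows are detected by being all-zero, mirroring the helper: a genuinely all-zero activation row would be dropped as well, so the counts in `masked_m` (as in the vectorized sketch earlier) remain the more robust source of truth.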
46 changes: 45 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -47,13 +47,15 @@
get_bool_env_var,
is_hip,
is_npu,
set_weight_attrs,
)

_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip


if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul

@@ -954,10 +956,52 @@ def forward(
else:
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
elif resolved_deepep_mode == DeepEPMode.low_latency:
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
if self.use_w4afp8:
return self.forward_cutlass_w4a8_masked(
hidden_states, masked_m, ep_mode="deepep_ll"
)
else:
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

def forward_cutlass_w4a8_masked(
self, hidden_states: torch.Tensor, masked_m: torch.Tensor, ep_mode: str
):

total_m = torch.sum(masked_m)
if total_m > 0:
output = cutlass_w4a8_moe(
self.start_expert_id,
self.end_expert_id,
self.num_experts,
hidden_states,
self.w13_weight,
self.w2_weight,
self.w13_weight_scale_inv,
self.w2_weight_scale_inv,
None,
None,
masked_m,
self.quant_method.a_strides1,
self.quant_method.b_strides1,
self.quant_method.c_strides1,
self.quant_method.a_strides2,
self.quant_method.b_strides2,
self.quant_method.c_strides2,
self.quant_method.s_strides13,
self.quant_method.s_strides2,
self.quant_method.expert_offsets,
self.quant_method.problem_sizes1,
self.quant_method.problem_sizes2,
self.w13_input_scale,
self.w2_input_scale,
ep_mode=ep_mode,
)
return output.to(torch.bfloat16)
else:
return hidden_states.to(torch.bfloat16)

def forward_normal(
self,
hidden_states: torch.Tensor,