Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
4b53bc9
upd
BBuf Mar 30, 2025
953a000
upd
BBuf Mar 30, 2025
500e3e2
Merge branch 'main' into support_r1_shared_expers_fusion
zhyncs Mar 30, 2025
5dac1c2
upd
BBuf Mar 31, 2025
5cff889
Merge branch 'support_r1_shared_expers_fusion' of github.com:sgl-proj…
BBuf Mar 31, 2025
e771349
fix acc bug
BBuf Mar 31, 2025
c69675a
upd
BBuf Mar 31, 2025
4180d63
fix circular import
BBuf Mar 31, 2025
128480e
upd
BBuf Mar 31, 2025
3d223ba
upd
BBuf Apr 1, 2025
3fd2706
upd
BBuf Apr 1, 2025
6c33e52
upd
BBuf Apr 1, 2025
5d892f8
Merge branch 'main' into support_r1_shared_expers_fusion
zhyncs Apr 1, 2025
4bf262a
upd
BBuf Apr 1, 2025
42678e1
upd
BBuf Apr 1, 2025
3bfa90a
upd
BBuf Apr 1, 2025
e3c5c3d
fix amd
BBuf Apr 1, 2025
2a5af12
fix ci
BBuf Apr 1, 2025
2f185f9
refine
BBuf Apr 1, 2025
3bb4fc1
Merge branch 'main' into support_r1_shared_expers_fusion
BBuf Apr 2, 2025
99abc77
refine
BBuf Apr 2, 2025
f8c8c70
refine
BBuf Apr 2, 2025
2a4bc93
ud
BBuf Apr 2, 2025
9c797af
ud
BBuf Apr 2, 2025
0261301
ud
BBuf Apr 2, 2025
cd3782d
upd
BBuf Apr 2, 2025
3d8a840
refine
BBuf Apr 2, 2025
2419300
upd
BBuf Apr 2, 2025
90ee3f8
Merge branch 'main' into support_r1_shared_expers_fusion
BBuf Apr 2, 2025
1df1bee
Merge branch 'main' into support_r1_shared_expers_fusion
BBuf Apr 3, 2025
40dc2c6
fix ci
BBuf Apr 3, 2025
fb5b17a
lint
BBuf Apr 3, 2025
943e986
add warmup for bench_serving tools
BBuf Apr 3, 2025
dab29e2
refine
BBuf Apr 3, 2025
1bba429
Merge branch 'main' into support_r1_shared_expers_fusion
BBuf Apr 3, 2025
556cba3
Merge branch 'main' into support_r1_shared_expers_fusion
BBuf Apr 3, 2025
64f261e
Merge branch 'main' into support_r1_shared_expers_fusion
BBuf Apr 3, 2025
9a6832a
fix bug
BBuf Apr 3, 2025
dbcae93
Merge branch 'main' into support_r1_shared_expers_fusion
BBuf Apr 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
import os
import argparse
import json
import time
Expand Down Expand Up @@ -400,6 +401,9 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = config.n_routed_experts
n_share_fusion_experts = int(os.getenv("SHARE_EXPERTS_FUSION_REPLICA", "0"))
if n_share_fusion_experts > 0:
E = E + n_share_fusion_experts
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DeepSeek-V2 has 2 shared experts. Should we multiply the number of replicas by the number of shared experts?

topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import triton.language as tl

from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.int8_kernel import (
per_token_group_quant_int8,
per_token_quant_int8,
Expand Down Expand Up @@ -510,6 +509,8 @@ def invoke_fused_moe_kernel(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> None:
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

Expand Down Expand Up @@ -638,6 +639,10 @@ def get_moe_configs(
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if int(os.getenv("SHARE_EXPERTS_FUSION_REPLICA", "0")) > 0:
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs/shared_experts_fusion", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using configuration from %s for MoE layer.", config_file_path)
Expand Down
18 changes: 16 additions & 2 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,13 @@ def grouped_topk(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
share_fusion: int = 0,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The naming is a little bit confusing. Is it identical to n_share_fusion_experts in the previous files? Why do we need a different name?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'll handle it.

):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
num_experts = scores.shape[1]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
Expand All @@ -122,9 +124,20 @@ def grouped_topk(
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if share_fusion:
topk_ids[:, -1] = torch.randint(low=num_experts,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: wondering whether randint will be a little bit slower - shall we use something like round-robin

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can call torch.randint(.., out=topk_ids[...]) to save data copy

high=num_experts + share_fusion,
size=(topk_ids.size(0), ),
dtype=topk_ids.dtype,
device=topk_ids.device)
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) * 1.0 / 2.5
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: wondering whether a*1.0 will be optimized as a, or will pytorch actually call kernels to do a computation. If the latter, maybe we can remove *1.0 to speed up a tiny little bit


if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights_sum = topk_weights.sum(
dim=-1,
keepdim=True) if share_fusion == 0 else topk_weights[:, :-1].sum(
dim=-1, keepdim=True)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: another way to reduce a bit of code duplication may be

topk_weights_for_sum = topk_weights if share_fusion == 0 else topk_weights[:, :-1]
topk_weights_sum = topk_weights_for_sum.sum(dim=-1, keepdim=True)

topk_weights = topk_weights / topk_weights_sum

return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

Expand Down Expand Up @@ -210,7 +223,7 @@ def select_experts(
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
):
# DeekSeekv2 uses grouped_top_k
# DeepSeek V2/V3/R1 series models use grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
Expand All @@ -222,6 +235,7 @@ def select_experts(
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
share_fusion=int(os.getenv("SHARE_EXPERTS_FUSION_REPLICA", "0")) > 0,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it seems the env var is used several times, thus maybe we can extract it to a global variable (constant)

)
else:
topk_weights, topk_ids = biased_grouped_topk(
Expand Down
41 changes: 36 additions & 5 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Inference-only DeepseekV2 model."""

import os
from tqdm import tqdm
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
Expand Down Expand Up @@ -166,6 +167,7 @@ def __init__(
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.n_share_fusion_experts = int(os.getenv("SHARE_EXPERTS_FUSION_REPLICA", "0"))
self.routed_scaling_factor = config.routed_scaling_factor
if self.tp_size > config.n_routed_experts:
raise ValueError(
Expand All @@ -187,8 +189,8 @@ def __init__(
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
)
self.experts = MoEImpl(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
num_experts=config.n_routed_experts + self.n_share_fusion_experts,
top_k=config.num_experts_per_tok + min(self.n_share_fusion_experts, 1),
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
Expand Down Expand Up @@ -256,8 +258,10 @@ def forward(
return self.forward_deepep(hidden_states, forward_mode)

def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.n_shared_experts is not None:
if self.n_shared_experts is not None and self.n_share_fusion_experts == 0:
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
final_hidden_states = (
Expand Down Expand Up @@ -1309,6 +1313,7 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.n_share_fusion_experts = int(os.getenv("SHARE_EXPERTS_FUSION_REPLICA", "0"))
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
Expand Down Expand Up @@ -1342,7 +1347,33 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]

if self.n_share_fusion_experts != 0:
weights_list = list(weights)
weights_dict = {k: v for (k, v) in weights_list}
suffix_list = [
'down_proj.weight', 'down_proj.weight_scale_inv',
'gate_proj.weight', 'gate_proj.weight_scale_inv',
'up_proj.weight', 'up_proj.weight_scale_inv'
]
current_device = torch.cuda.current_device()
is_master = (current_device == 0)
for moe_layer in tqdm(range(self.config.num_hidden_layers),
desc=f"Cloning {self.n_share_fusion_experts} "
"replicas of shared expert into MoE",
disable=not is_master):
if moe_layer < self.config.first_k_dense_replace:
continue
for num_repeat in range(self.n_share_fusion_experts):
for suffix in suffix_list:
weights_list.append((
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + num_repeat}"
f".{suffix}", weights_dict[
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is it possible that we remove the original shared_experts weight after this, then we save a bit of memory by making that shared_expert never load, and we can remove the logic of

if self.n_shared_experts is not None and self.n_share_fusion_experts == 0:
            shared_output = self.shared_experts(...)

above, since now it is directly None

f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"]
.clone()))
weights = weights_list

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = (
Expand All @@ -1354,7 +1385,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_experts=self.config.n_routed_experts + self.n_share_fusion_experts,
)

params_dict = dict(self.named_parameters())
Expand Down
2 changes: 1 addition & 1 deletion sgl-kernel/tests/test_moe_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def moe_align_block_size_triton(
[32, 64, 128, 256], # block_size
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens
[1, 2, 4, 8, 16, 32, 64], # topk
[64, 160, 256], # num_experts
[64, 160, 256, 257, 260], # num_experts
)
),
)
Expand Down
Loading