Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions python/sglang/srt/epmoe_permute_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# python3 ./sglang_rebalance/python/sglang/srt/epmoe_permute_tensor.py
"""Random expert-permutation tables for EP MoE load-balancing (EPLB) experiments.

Builds, for each layer, a random permutation of expert ids plus its inverse
("back mapping"), so that routed expert ids can be remapped onto physically
rebalanced expert slots and mapped back.
"""
import torch

# Fixed seed so every rank/process builds identical permutation tables.
torch.manual_seed(42)

# Defaults correspond to the target model: 61 MoE layers x 256 routed experts.
# TODO(review): pass num_layers / num_experts in from the model config instead
# of hard-coding them here.
NUM_LAYERS = 61
NUM_EXPERTS = 256


def build_permute_tensors(
    num_layers: int = NUM_LAYERS, num_experts: int = NUM_EXPERTS
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(permute, back_mapping)``, each of shape (num_layers, num_experts).

    ``permute[l]`` is a random permutation of ``range(num_experts)``.
    ``back_mapping`` is its per-row inverse, i.e.
    ``back_mapping[l, permute[l, i]] == i`` for every layer ``l`` and slot ``i``.

    Draws from torch's global RNG, so results depend on the seed state at the
    time of the call.
    """
    permute = torch.stack(
        [torch.randperm(num_experts) for _ in range(num_layers)], dim=0
    )
    # argsort of a permutation yields its inverse permutation; this replaces a
    # Python double loop with one vectorized call (dtype stays torch.long).
    back_mapping = torch.argsort(permute, dim=1)
    return permute, back_mapping


# Module-level tables kept for backward compatibility with existing importers.
EP_PERMUTE_TENSOR, EP_BACK_MAPPING_TENSOR = build_permute_tensors()
7 changes: 7 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
ep_back_mapping_tensor: Optional[torch.Tensor] = None,
):
super().__init__()

Expand Down Expand Up @@ -202,6 +203,8 @@ def __init__(

self.grouped_gemm_runner = None

self.ep_back_mapping_tensor = ep_back_mapping_tensor.to(torch.cuda.current_device()) if ep_back_mapping_tensor is not None else None

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None

Expand All @@ -222,6 +225,9 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
)

if self.ep_back_mapping_tensor is not None:
topk_ids = self.ep_back_mapping_tensor[topk_ids]

reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
topk_ids, self.num_experts
Expand Down Expand Up @@ -410,6 +416,7 @@ def weight_loader(
shard_id: str,
expert_id: int,
) -> None:
expert_id = self.ep_back_mapping_tensor[expert_id] if self.ep_back_mapping_tensor is not None else expert_id
if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
return
expert_id = expert_id - self.start_expert_id
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, get_compiler_backend
from sglang.srt.epmoe_permute_tensor import EP_PERMUTE_TENSOR, EP_BACK_MAPPING_TENSOR

if TYPE_CHECKING:
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
Expand All @@ -81,6 +82,9 @@
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"enable_eplb_moe": ServerArgs.enable_eplb_moe,
"ep_load_tensor": EP_PERMUTE_TENSOR,
"ep_back_mapping_tensor": EP_BACK_MAPPING_TENSOR,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
}
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def __init__(
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
"enable_deepep_moe": server_args.enable_deepep_moe,
"enable_eplb_moe": server_args.enable_eplb_moe,
"deepep_mode": server_args.deepep_mode,
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
Expand Down
26 changes: 19 additions & 7 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
ep_back_mapping_tensor: Optional[torch.Tensor] = None,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
Expand Down Expand Up @@ -198,7 +199,7 @@ def __init__(
MoEImpl = (
DeepEPMoE
if global_server_args_dict["enable_deepep_moe"]
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
else (EPMoE if global_server_args_dict["enable_ep_moe"] or global_server_args_dict["enable_eplb_moe"] else FusedMoE)
)

self.experts = MoEImpl(
Expand All @@ -213,6 +214,7 @@ def __init__(
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
prefix=add_prefix("experts", prefix),
ep_back_mapping_tensor=ep_back_mapping_tensor if global_server_args_dict["enable_eplb_moe"] else None,
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["enable_deepep_moe"]
Expand Down Expand Up @@ -291,6 +293,7 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
print(f"final_hidden_states: {final_hidden_states}")
return final_hidden_states

def forward_deepep(
Expand Down Expand Up @@ -1081,12 +1084,21 @@ def is_sparse_layer(l: int):
)

if is_nextn or is_sparse_layer(layer_id):
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.is_sparse = True
if global_server_args_dict["enable_eplb_moe"]:
ep_back_mapping_tensor = global_server_args_dict["ep_back_mapping_tensor"][layer_id]
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
ep_back_mapping_tensor=ep_back_mapping_tensor,
)
else:
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.is_sparse = True
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
Expand Down
13 changes: 13 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class ServerArgs:
enable_dp_attention: bool = False
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_eplb_moe: bool = False

deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
Expand Down Expand Up @@ -320,6 +322,12 @@ def __post_init__(self):
logger.info(
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
# Load Balance MoE
if self.enable_eplb_moe:
self.ep_size = self.tp_size
logger.info(
f"Load Balance MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)

# Speculative Decoding
if self.speculative_algorithm == "NEXTN":
Expand Down Expand Up @@ -1097,6 +1105,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enabling DeepEP MoE implementation for EP MoE.",
)
parser.add_argument(
"--enable-eplb-moe",
action="store_true",
help="Enabling Load Balance MoE implementation for EP MoE.",
)
parser.add_argument(
"--deepep-mode",
type=str,
Expand Down