24 changes: 24 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -11,6 +11,7 @@
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -84,6 +85,7 @@ def __init__(
if params_dtype is None:
params_dtype = torch.get_default_dtype()

self.layer_id = layer_id
self.top_k = top_k
self.hidden_size = hidden_size
self.tp_size = (
@@ -374,6 +376,28 @@ def weight_loader(
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
physical_expert_ids = (
get_global_expert_location_metadata().logical_to_all_physical(
self.layer_id, expert_id
)
)
for physical_expert_id in physical_expert_ids:
self._weight_loader_physical(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=physical_expert_id,
)

def _weight_loader_physical(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
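
Note on the change above: `weight_loader` now fans each logical expert's checkpoint weight out to every physical replica reported by the EPLB expert-location metadata, and delegates the per-replica work to the new `_weight_loader_physical`. A minimal, self-contained sketch of that fan-out, using a toy mapping table in place of `get_global_expert_location_metadata().logical_to_all_physical(...)` (the table and the `load_expert_weight` helper below are illustrative stand-ins, not sglang APIs):

```python
# Minimal sketch (not the sglang implementation) of the logical->physical fan-out
# performed by the new weight_loader. With EPLB, one logical expert can be
# replicated across several physical expert slots, so its checkpoint weight has
# to be written into every replica. The toy table below stands in for
# get_global_expert_location_metadata().logical_to_all_physical(layer_id, expert_id).

from typing import Dict, List, Tuple

# Hypothetical mapping: (layer_id, logical_expert_id) -> physical expert ids.
LOGICAL_TO_PHYSICAL: Dict[Tuple[int, int], List[int]] = {
    (0, 0): [0, 4],  # logical expert 0 of layer 0 is replicated on slots 0 and 4
    (0, 1): [1],     # logical expert 1 has a single physical slot
}


def load_expert_weight(layer_id: int, logical_expert_id: int, weight: object) -> None:
    """Fan one logical expert's checkpoint weight out to all of its physical replicas."""
    for physical_expert_id in LOGICAL_TO_PHYSICAL[(layer_id, logical_expert_id)]:
        # In the real code this is _weight_loader_physical(...), which additionally
        # maps the global physical id to a rank-local expert id and skips the load
        # when that mapping returns -1 (the expert does not live on this rank).
        print(f"layer {layer_id}: loading weight into physical expert {physical_expert_id}")


load_expert_weight(layer_id=0, logical_expert_id=0, weight=None)
```

Loading into every replica keeps duplicated physical experts consistent with the single logical expert stored in the checkpoint.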
1 change: 1 addition & 0 deletions python/sglang/srt/models/deepseek_v2.py
@@ -2155,6 +2155,7 @@ def determine_num_fused_shared_experts(

if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
f"{disable_reason} Shared experts fusion optimization is disabled.",
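
Note on the change above: when a `disable_reason` is set, the model turns on the global `disable_shared_experts_fusion` flag; the added line also resets `self.num_fused_shared_experts` to 0 so the counter agrees with the flag. A hedged sketch of that invariant (a simplified stand-in, not the actual `determine_num_fused_shared_experts` logic):

```python
# Hedged sketch (simplified stand-in, not the actual model code): the fused
# shared-expert count must be zero whenever the fusion is disabled, so the two
# pieces of state are updated together.

from typing import Optional


def resolve_num_fused_shared_experts(requested: int, disable_reason: Optional[str]) -> int:
    """Return the effective fused shared-expert count for a given disable reason."""
    if disable_reason is not None:
        # Mirrors the patched behaviour: disabling the fusion also zeroes the count.
        return 0
    return requested


assert resolve_num_fused_shared_experts(1, None) == 1
assert resolve_num_fused_shared_experts(1, "quantization method not supported") == 0
```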