 from vllm.attention import Attention
 from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, ParallelConfig, VllmConfig
+from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
 from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.interfaces import MixtureOfExperts
 
 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
@@ -62,10 +63,14 @@ def custom_routing_function(
     def __init__(self,
                  config: Llama4TextConfig,
                  quant_config: Optional[QuantizationConfig] = None,
+                 parallel_config: Optional[ParallelConfig] = None,
                  prefix: str = ""):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
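+        # Expert-parallel (EP) process group; its size determines how the
+        # physical experts are partitioned across ranks below.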
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
 
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(config.hidden_size,
@@ -84,6 +89,21 @@ def __init__(self,
             reduce_results=False,
         )
 
+        # Load balancing settings.
+        eplb_config = parallel_config.eplb_config if parallel_config else None
+        self.enable_eplb = parallel_config.enable_eplb \
+            if parallel_config else False
+        self.n_redundant_experts = eplb_config.num_redundant_experts \
+            if eplb_config else 0
+
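+        # Expert bookkeeping for EPLB: "logical" experts are the routed
+        # experts defined by the checkpoint, while "physical" experts also
+        # count the redundant replicas EPLB may add; local counts are per
+        # EP rank.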
+        self.n_routed_experts: int = config.num_local_experts
+        self.n_logical_experts = self.n_routed_experts
+        self.n_shared_experts: int = 1
+        self.n_local_experts: int = config.num_local_experts
+        self.n_physical_experts = (self.n_local_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_expert,
             num_experts=config.num_local_experts,
@@ -96,6 +116,8 @@ def __init__(self,
             renormalize=False,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
         )
 
     def forward(self, hidden_states):
@@ -262,6 +284,7 @@ def __init__(
         config: Llama4TextConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        parallel_config: Optional[ParallelConfig] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -293,6 +316,7 @@ def __init__(
             self.feed_forward = Llama4MoE(
                 config=config,
                 quant_config=quant_config,
+                parallel_config=parallel_config,
                 prefix=f"{prefix}.feed_forward",
             )
         else:
@@ -641,7 +665,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Llama4ForCausalLM(LlamaForCausalLM):
+class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
 
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -663,6 +687,66 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                          prefix=prefix,
                          layer_type=Llama4DecoderLayer)
 
+        self.expert_weights = []
+
+        self.moe_layers: list[FusedMoE] = []
+        example_moe = None
+        for layer in self.model.layers:
+            assert isinstance(layer, Llama4DecoderLayer)
+            if isinstance(layer.feed_forward, Llama4MoE):
+                # Use the last MoE layer; the first layers may be dense.
+                example_moe = layer.feed_forward
+                self.moe_layers.append(layer.feed_forward.experts)
+
+        if example_moe is None:
+            raise RuntimeError("No Llama4MoE layer found in model.layers.")
+
+        # Set MoE hyperparameters (MixtureOfExperts interface attributes).
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_expert_groups = 1
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_shared_experts = example_moe.n_shared_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
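+        """Register each MoE layer's expert weights and wire the shared
+        EPLB tensors (expert load view, logical-to-physical map, replica
+        counts) into the layer."""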
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
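+        """Refresh the expert counts after EPLB changes the number of
+        physical experts, then rebuild each MoE layer's expert map."""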
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
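+        # Propagate the updated counts to every MoE layer and rebuild its
+        # expert map.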
+        for layer in self.model.layers:
+            if isinstance(layer.feed_forward, Llama4MoE):
+                moe = layer.feed_forward
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def _init_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "",