
Commit cda869d

Add eplb support to Llama4

Signed-off-by: ilmarkov <[email protected]>
1 parent 98395a6, commit cda869d

File tree: 2 files changed, 91 additions and 4 deletions

vllm/model_executor/models/llama.py

Lines changed: 4 additions & 1 deletion
@@ -34,7 +34,7 @@
 from vllm.attention import Attention, AttentionType
 from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -242,6 +242,7 @@ def __init__(
         config: LlamaConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        parallel_config: Optional[ParallelConfig] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -338,6 +339,7 @@ def __init__(self,
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        parallel_config = vllm_config.parallel_config

         self.config = config
         self.quant_config = quant_config
@@ -360,6 +362,7 @@ def __init__(self,
             lambda prefix: layer_type(config=config,
                                       cache_config=cache_config,
                                       quant_config=quant_config,
+                                      parallel_config=parallel_config,
                                       prefix=prefix),
             prefix=f"{prefix}.layers",
         )
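
To make the plumbing above easier to follow, here is a standalone sketch of the same pattern under assumed names (ToyParallelConfig, ToyLayer, and ToyModel are illustrative stand-ins, not vLLM classes): the model reads the parallel config off its top-level config object and forwards it into every layer constructor, so nested modules can consult EPLB settings without a global lookup.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyParallelConfig:
    enable_eplb: bool = False
    num_redundant_experts: int = 0


class ToyLayer:
    def __init__(self, parallel_config: Optional[ToyParallelConfig] = None):
        # Mirror the Optional[ParallelConfig] = None handling in the diff:
        # default to "EPLB off" when no parallel config is supplied.
        self.enable_eplb = (parallel_config.enable_eplb
                            if parallel_config else False)


class ToyModel:
    def __init__(self, num_layers: int,
                 parallel_config: Optional[ToyParallelConfig] = None):
        # Forward the same config object into every layer, as the
        # make_layers(...) lambda does in the hunk above.
        self.layers = [ToyLayer(parallel_config=parallel_config)
                       for _ in range(num_layers)]


model = ToyModel(4, parallel_config=ToyParallelConfig(enable_eplb=True))
print([layer.enable_eplb for layer in model.layers])  # [True, True, True, True]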

vllm/model_executor/models/llama4.py

Lines changed: 87 additions & 3 deletions
@@ -27,8 +27,8 @@
 from vllm.attention import Attention
 from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, ParallelConfig, VllmConfig
+from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -39,6 +39,7 @@
 from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.interfaces import MixtureOfExperts

 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
@@ -62,10 +63,14 @@ def custom_routing_function(
     def __init__(self,
                  config: Llama4TextConfig,
                  quant_config: Optional[QuantizationConfig] = None,
+                 parallel_config: Optional[ParallelConfig] = None,
                  prefix: str = ""):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()

         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(config.hidden_size,
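
The three new attributes above come straight from the expert-parallel process group. Below is a minimal, single-process sketch of what rank() and size() return on a torch.distributed process group; the gloo/TCP setup is an assumption for illustration only, whereas vLLM obtains the group via get_ep_group().device_group.

import torch.distributed as dist

# Single-process group, just to show the API surface the MoE layer relies on.
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
ep_group = dist.group.WORLD
print(ep_group.rank(), ep_group.size())  # 0 1
dist.destroy_process_group()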
@@ -84,6 +89,21 @@ def __init__(self,
             reduce_results=False,
         )

+        # Load balancing settings.
+        eplb_config = parallel_config.eplb_config if parallel_config else None
+        self.enable_eplb = parallel_config.enable_eplb \
+            if parallel_config else False
+        self.n_redundant_experts = eplb_config.num_redundant_experts \
+            if eplb_config else 0
+
+        self.n_routed_experts: int = config.num_local_experts
+        self.n_logical_experts = self.n_routed_experts
+        self.n_shared_experts: int = 1
+        self.n_local_experts: int = config.num_local_experts
+        self.n_physical_experts = (self.n_local_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_expert,
             num_experts=config.num_local_experts,
@@ -96,6 +116,8 @@
             renormalize=False,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
         )

     def forward(self, hidden_states):
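
A worked example of the expert-count bookkeeping that the settings above feed into; the numbers (16 routed experts, 2 redundant experts, EP size 2) are assumptions chosen for illustration, not taken from any real Llama4 config.

num_local_experts = 16      # would come from config.num_local_experts
num_redundant_experts = 2   # would come from eplb_config.num_redundant_experts
ep_size = 2                 # size of the expert-parallel group

n_logical_experts = num_local_experts                             # 16
n_physical_experts = num_local_experts + num_redundant_experts    # 18
n_local_physical_experts = n_physical_experts // ep_size          # 9 per EP rank

print(n_logical_experts, n_physical_experts, n_local_physical_experts)  # 16 18 9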
@@ -262,6 +284,7 @@ def __init__(
         config: Llama4TextConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        parallel_config: Optional[ParallelConfig] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -293,6 +316,7 @@ def __init__(
             self.feed_forward = Llama4MoE(
                 config=config,
                 quant_config=quant_config,
+                parallel_config=parallel_config,
                 prefix=f"{prefix}.feed_forward",
             )
         else:
@@ -641,7 +665,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params


-class Llama4ForCausalLM(LlamaForCausalLM):
+class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):

     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -663,6 +687,66 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             prefix=prefix,
             layer_type=Llama4DecoderLayer)

+        self.expert_weights = []
+
+        self.moe_layers: list[FusedMoE] = []
+        example_moe = None
+        for layer in self.model.layers:
+            assert isinstance(layer, Llama4DecoderLayer)
+            if isinstance(layer.feed_forward, Llama4MoE):
+                # Pick the last MoE layer, since the first layers may be dense.
+                example_moe = layer.feed_forward
+                self.moe_layers.append(layer.feed_forward.experts)
+
+        if example_moe is None:
+            raise RuntimeError("No Llama4MoE layer found in model.layers.")
+
+        # Set MoE hyperparameters
+        self.num_moe_layers = len(self.moe_layers)
+        print(f"num_moe_layers: {self.num_moe_layers}")
+        self.num_expert_groups = 1
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_shared_experts = example_moe.n_shared_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            print(f"set eplb state layer_idx: {layer_idx}")
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
+        for layer in self.model.layers:
+            if isinstance(layer.feed_forward, Llama4MoE):
+                moe = layer.feed_forward
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def _init_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "",
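
For orientation, a hypothetical, self-contained sketch of the call order an EPLB manager could drive against the hooks added above. _StubExperts is a made-up stand-in for the real SharedFusedMoE layer, and the tensor shapes are assumptions for illustration only, not vLLM's actual EPLB layout.

import torch


class _StubExperts:
    """Stand-in for SharedFusedMoE: records the shared EPLB views it is given."""

    def get_expert_weights(self):
        return []  # a real layer returns its per-expert weight tensors

    def set_eplb_state(self, moe_layer_idx, expert_load_view,
                       logical_to_physical_map, logical_replica_count):
        # The layer keeps a view into the global load tensor for its own row.
        self.load_view = expert_load_view[moe_layer_idx]


num_moe_layers, num_logical, num_redundant = 2, 16, 2  # assumed sizes
num_physical = num_logical + num_redundant

expert_load_view = torch.zeros(num_moe_layers, num_physical, dtype=torch.int64)
logical_to_physical_map = torch.full(
    (num_moe_layers, num_logical, 1 + num_redundant), -1, dtype=torch.int64)
logical_replica_count = torch.ones(num_moe_layers, num_logical,
                                   dtype=torch.int64)

# Same registration loop as set_eplb_state() above, run against the stub layers.
moe_layers = [_StubExperts() for _ in range(num_moe_layers)]
expert_weights = []
for layer_idx, layer in enumerate(moe_layers):
    expert_weights.append(layer.get_expert_weights())
    layer.set_eplb_state(
        moe_layer_idx=layer_idx,
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )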
