
Commit 221363b

Merge pull request #2 from dsxsteven/model_register
add gpu_model_register
2 parents: 47091ee + 1f7ead3

File tree: 5 files changed (+172 −219 lines)
vllm/distributed/eplb/gpu_model_register.py

Lines changed: 116 additions & 0 deletions

@@ -0,0 +1,116 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import types
+import typing
+import torch
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.models.utils import is_pp_missing_parameter
+from typing import Callable
+
+def set_eplb_state(
+    self,
+    expert_load_view: torch.Tensor,
+    logical_to_physical_map: torch.Tensor,
+    logical_replica_count: torch.Tensor,
+) -> None:
+    for layer_idx, layer in enumerate(self.moe_layers):
+        # Register the expert weights.
+        self.expert_weights.append(layer.get_expert_weights())
+        layer.set_eplb_state(
+            moe_layer_idx=layer_idx,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
+        )
+
+def update_physical_experts_metadata(
+    self,
+    num_physical_experts: int,
+    num_local_physical_experts: int,
+) -> None:
+    assert self.num_local_physical_experts == num_local_physical_experts
+    self.num_physical_experts = num_physical_experts
+    self.num_local_physical_experts = num_local_physical_experts
+    self.num_redundant_experts = (num_physical_experts -
+                                  self.num_logical_experts)
+    for layer in self.model.layers:
+        if isinstance(layer.mlp, type(self.example_moe)):
+            moe = layer.mlp
+            moe.n_local_physical_experts = num_local_physical_experts
+            moe.n_physical_experts = num_physical_experts
+            moe.n_redundant_experts = self.num_redundant_experts
+            moe.experts.update_expert_map()
+
+def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+    # Params for weights, fp8 weight scales, fp8 activation scales
+    # (param_name, weight_name, expert_id, shard_id)
+    return FusedMoE.make_expert_params_mapping(
+        ckpt_gate_proj_name="gate_proj",
+        ckpt_down_proj_name="down_proj",
+        ckpt_up_proj_name="up_proj",
+        num_experts=self.config.n_routed_experts,
+        num_redundant_experts=self.num_redundant_experts)
+
+def load_expert_weight(self, mapping, name, loaded_weight, params_dict):
+    ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale",
+                       ".v_scale", "_v_scale", ".weight_scale",
+                       "_weight_scale", ".input_scale", "_input_scale")
+
+    expert_matched = False
+    is_continue = False
+    success = False
+    name_mapped = ''
+    param_name, weight_name, expert_id, shard_id = mapping
+    if weight_name not in name:
+        is_continue = True
+        return expert_matched, is_continue, success, name_mapped
+
+    # Anyway, this is an expert weight and should not be
+    # attempted to load as other weights later
+    expert_matched = True
+
+    # Do not modify `name` since the loop may continue here
+    # Instead, create a new variable
+    name_mapped = name.replace(weight_name, param_name)
+
+    if is_pp_missing_parameter(name_mapped, self):
+        is_continue = True
+        return expert_matched, is_continue, success, name_mapped
+
+    # Skip loading extra parameters for GPTQ/modelopt models.
+    if name_mapped.endswith(ignore_suffixes) \
+            and name_mapped not in params_dict:
+        is_continue = True
+        return expert_matched, is_continue, success, name_mapped
+
+    param = params_dict[name_mapped]
+    # We should ask the weight loader to return success or not
+    # here since otherwise we may skip experts with other
+    # available replicas.
+    weight_loader = typing.cast(Callable[..., bool],
+                                param.weight_loader)
+    success = weight_loader(param,
+                            loaded_weight,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True)
+    return expert_matched, is_continue, success, name_mapped
+
+def model_register(model):
+    """
+    Registers custom methods related to Expert Parallel Load Balancing (EPLB)
+    onto the vLLM model instance, binding them with types.MethodType so that
+    EPLB can manage expert state and expert weight loading on this model.
+
+    Args:
+        model: The vLLM model instance to which the methods will be added.
+    """
+    model.set_eplb_state = types.MethodType(set_eplb_state, model)
+    model.load_expert_weight = types.MethodType(load_expert_weight, model)
+    model.update_physical_experts_metadata = \
+        types.MethodType(update_physical_experts_metadata, model)
+    model.model.get_expert_mapping = \
+        types.MethodType(get_expert_mapping, model.model)
+    print("register complete")
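The call site that invokes model_register is not part of this commit, so the following is only a minimal usage sketch; the surrounding variables (`model`, the EPLB tensors, and the expert counts) are assumptions standing in for whatever the EPLB machinery passes in.

    # Hypothetical caller: bind the EPLB helpers onto a loaded MoE model,
    # then drive them through the methods registered above.
    from vllm.distributed.eplb.gpu_model_register import model_register

    model_register(model)  # `model` wraps its decoder stack as `model.model`

    # After registration, the EPLB entry points are bound methods:
    model.set_eplb_state(expert_load_view, logical_to_physical_map,
                         logical_replica_count)
    model.update_physical_experts_metadata(num_physical_experts,
                                           num_local_physical_experts)
    expert_params_mapping = model.model.get_expert_mapping()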

vllm/model_executor/models/deepseek_v2.py

Lines changed: 18 additions & 75 deletions
@@ -842,60 +842,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_expert_groups = config.n_group

         self.moe_layers: list[FusedMoE] = []
-        example_moe = None
+        self.example_moe = None
         for layer in self.model.layers:
             if isinstance(layer, PPMissingLayer):
                 continue

             assert isinstance(layer, DeepseekV2DecoderLayer)
             if isinstance(layer.mlp, DeepseekV2MoE):
                 # Pick last one layer since the first ones may be dense layers.
-                example_moe = layer.mlp
+                self.example_moe = layer.mlp
                 self.moe_layers.append(layer.mlp.experts)

-        if example_moe is None:
+        if self.example_moe is None:
             raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")

-        self.num_logical_experts = example_moe.n_logical_experts
-        self.num_physical_experts = example_moe.n_physical_experts
-        self.num_local_physical_experts = example_moe.n_local_physical_experts
-        self.num_routed_experts = example_moe.n_routed_experts
-        self.num_shared_experts = example_moe.n_shared_experts
-        self.num_redundant_experts = example_moe.n_redundant_experts
-
-    def set_eplb_state(
-        self,
-        expert_load_view: torch.Tensor,
-        logical_to_physical_map: torch.Tensor,
-        logical_replica_count: torch.Tensor,
-    ) -> None:
-        for layer_idx, layer in enumerate(self.moe_layers):
-            # Register the expert weights.
-            self.expert_weights.append(layer.get_expert_weights())
-            layer.set_eplb_state(
-                moe_layer_idx=layer_idx,
-                expert_load_view=expert_load_view,
-                logical_to_physical_map=logical_to_physical_map,
-                logical_replica_count=logical_replica_count,
-            )
-
-    def update_physical_experts_metadata(
-        self,
-        num_physical_experts: int,
-        num_local_physical_experts: int,
-    ) -> None:
-        assert self.num_local_physical_experts == num_local_physical_experts
-        self.num_physical_experts = num_physical_experts
-        self.num_local_physical_experts = num_local_physical_experts
-        self.num_redundant_experts = (num_physical_experts -
-                                      self.num_logical_experts)
-        for layer in self.model.layers:
-            if isinstance(layer.mlp, DeepseekV2MoE):
-                moe = layer.mlp
-                moe.n_local_physical_experts = num_local_physical_experts
-                moe.n_physical_experts = num_physical_experts
-                moe.n_redundant_experts = self.num_redundant_experts
-                moe.experts.update_expert_map()
+        self.num_logical_experts = self.example_moe.n_logical_experts
+        self.num_physical_experts = self.example_moe.n_physical_experts
+        self.num_local_physical_experts = self.example_moe.n_local_physical_experts
+        self.num_routed_experts = self.example_moe.n_routed_experts
+        self.num_shared_experts = self.example_moe.n_shared_experts
+        self.num_redundant_experts = self.example_moe.n_redundant_experts

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -927,16 +893,10 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("fused_qkv_a_proj", "q_a_proj", 0),
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
-
+        from vllm.distributed.eplb.gpu_model_register import get_expert_mapping, load_expert_weight
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
-            num_redundant_experts=self.num_redundant_experts)
-
+        expert_params_mapping = get_expert_mapping(self)
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
@@ -982,34 +942,17 @@ def load_weights(self, weights: Iterable[tuple[str,
                 break
             else:
                 is_expert_weight = False
+                is_continue = False
                 for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-
-                    # Anyway, this is an expert weight and should not be
-                    # attempted to load as other weights later
-                    is_expert_weight = True
-
-                    # Do not modify `name` since the loop may continue here
-                    # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
+                    expert_matched, is_continue, success, name_mapped = \
+                        load_expert_weight(self, mapping, name,
+                                           loaded_weight, params_dict)
+                    if expert_matched:
+                        is_expert_weight = True

-                    if is_pp_missing_parameter(name_mapped, self):
+                    if is_continue:
                         continue

-                    param = params_dict[name_mapped]
-                    # We should ask the weight loader to return success or not
-                    # here since otherwise we may skip experts with other
-                    # available replicas.
-                    weight_loader = typing.cast(Callable[..., bool],
-                                                param.weight_loader)
-                    success = weight_loader(param,
-                                            loaded_weight,
-                                            name_mapped,
-                                            shard_id=shard_id,
-                                            expert_id=expert_id,
-                                            return_success=True)
                     if success:
                         name = name_mapped
                         break
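For reference, the rewritten expert branch of load_weights condenses to the loop below. This is a sketch assembled from the hunk above (unrelated context omitted), meant only to spell out the contract of the four values returned by load_expert_weight.

    is_expert_weight = False
    is_continue = False
    for mapping in expert_params_mapping:
        # (matched an expert weight, skip this mapping, loaded successfully,
        #  checkpoint name remapped to the model's parameter name)
        expert_matched, is_continue, success, name_mapped = \
            load_expert_weight(self, mapping, name, loaded_weight, params_dict)
        if expert_matched:
            is_expert_weight = True
        if is_continue:
            continue
        if success:
            name = name_mapped
            break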

vllm/model_executor/models/glm4_moe.py

Lines changed: 18 additions & 60 deletions
@@ -495,15 +495,6 @@ def make_empty_intermediate_tensors(
                          device=device),
         })

-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -514,10 +505,10 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-
+        from vllm.distributed.eplb.gpu_model_register import get_expert_mapping, load_expert_weight
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
-        expert_params_mapping = self.get_expert_mapping()
+        expert_params_mapping = get_expert_mapping(self)
         for name, loaded_weight in weights:
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:
@@ -547,34 +538,17 @@ def load_weights(self, weights: Iterable[tuple[str,
                 break
             else:
                 is_expert_weight = False
+                is_continue = False
                 for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-
-                    # Anyway, this is an expert weight and should not be
-                    # attempted to load as other weights later
-                    is_expert_weight = True
+                    expert_matched, is_continue, success, name_mapped = \
+                        load_expert_weight(self, mapping, name,
+                                           loaded_weight, params_dict)
+                    if expert_matched:
+                        is_expert_weight = True

-                    # Do not modify `name` since the loop may continue here
-                    # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
-
-                    if is_pp_missing_parameter(name_mapped, self):
+                    if is_continue:
                         continue

-                    param = params_dict[name_mapped]
-                    # We should ask the weight loader to return success or not
-                    # here since otherwise we may skip experts with other
-                    # available replicas.
-                    weight_loader = typing.cast(Callable[..., bool],
-                                                param.weight_loader)
-                    success = weight_loader(param,
-                                            loaded_weight,
-                                            name_mapped,
-                                            shard_id=shard_id,
-                                            expert_id=expert_id,
-                                            return_success=True)
                     if success:
                         name = name_mapped
                         break
@@ -648,42 +622,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_expert_groups = config.n_group

         self.moe_layers: list[FusedMoE] = []
-        example_moe = None
+        self.example_moe = None
         for layer in self.model.layers:
             if isinstance(layer, PPMissingLayer):
                 continue

             assert isinstance(layer, Glm4MoeDecoderLayer)
             if isinstance(layer.mlp, Glm4MoE):
                 # Pick last one layer since the first ones may be dense layers.
-                example_moe = layer.mlp
+                self.example_moe = layer.mlp
                 self.moe_layers.append(layer.mlp.experts)

-        if example_moe is None:
+        if self.example_moe is None:
             raise RuntimeError("No Glm4MoE layer found in model.layers.")

-        self.num_logical_experts = example_moe.n_logical_experts
-        self.num_physical_experts = example_moe.n_physical_experts
-        self.num_local_physical_experts = example_moe.n_local_physical_experts
-        self.num_routed_experts = example_moe.n_routed_experts
-        self.num_shared_experts = example_moe.n_shared_experts
-        self.num_redundant_experts = example_moe.n_redundant_experts
-
-    def set_eplb_state(
-        self,
-        expert_load_view: torch.Tensor,
-        logical_to_physical_map: torch.Tensor,
-        logical_replica_count: torch.Tensor,
-    ) -> None:
-        for layer_idx, layer in enumerate(self.moe_layers):
-            # Register the expert weights.
-            self.expert_weights.append(layer.get_expert_weights())
-            layer.set_eplb_state(
-                moe_layer_idx=layer_idx,
-                expert_load_view=expert_load_view,
-                logical_to_physical_map=logical_to_physical_map,
-                logical_replica_count=logical_replica_count,
-            )
+        self.num_logical_experts = self.example_moe.n_logical_experts
+        self.num_physical_experts = self.example_moe.n_physical_experts
+        self.num_local_physical_experts = self.example_moe.n_local_physical_experts
+        self.num_routed_experts = self.example_moe.n_routed_experts
+        self.num_shared_experts = self.example_moe.n_shared_experts
+        self.num_redundant_experts = self.example_moe.n_redundant_experts

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
