@@ -842,60 +842,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_expert_groups = config.n_group
 
         self.moe_layers: list[FusedMoE] = []
-        example_moe = None
+        self.example_moe = None
         for layer in self.model.layers:
             if isinstance(layer, PPMissingLayer):
                 continue
 
             assert isinstance(layer, DeepseekV2DecoderLayer)
             if isinstance(layer.mlp, DeepseekV2MoE):
                 # Pick last one layer since the first ones may be dense layers.
-                example_moe = layer.mlp
+                self.example_moe = layer.mlp
                 self.moe_layers.append(layer.mlp.experts)
 
-        if example_moe is None:
+        if self.example_moe is None:
             raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
 
-        self.num_logical_experts = example_moe.n_logical_experts
-        self.num_physical_experts = example_moe.n_physical_experts
-        self.num_local_physical_experts = example_moe.n_local_physical_experts
-        self.num_routed_experts = example_moe.n_routed_experts
-        self.num_shared_experts = example_moe.n_shared_experts
-        self.num_redundant_experts = example_moe.n_redundant_experts
-
-    def set_eplb_state(
-        self,
-        expert_load_view: torch.Tensor,
-        logical_to_physical_map: torch.Tensor,
-        logical_replica_count: torch.Tensor,
-    ) -> None:
-        for layer_idx, layer in enumerate(self.moe_layers):
-            # Register the expert weights.
-            self.expert_weights.append(layer.get_expert_weights())
-            layer.set_eplb_state(
-                moe_layer_idx=layer_idx,
-                expert_load_view=expert_load_view,
-                logical_to_physical_map=logical_to_physical_map,
-                logical_replica_count=logical_replica_count,
-            )
-
-    def update_physical_experts_metadata(
-        self,
-        num_physical_experts: int,
-        num_local_physical_experts: int,
-    ) -> None:
-        assert self.num_local_physical_experts == num_local_physical_experts
-        self.num_physical_experts = num_physical_experts
-        self.num_local_physical_experts = num_local_physical_experts
-        self.num_redundant_experts = (num_physical_experts -
-                                      self.num_logical_experts)
-        for layer in self.model.layers:
-            if isinstance(layer.mlp, DeepseekV2MoE):
-                moe = layer.mlp
-                moe.n_local_physical_experts = num_local_physical_experts
-                moe.n_physical_experts = num_physical_experts
-                moe.n_redundant_experts = self.num_redundant_experts
-                moe.experts.update_expert_map()
+        self.num_logical_experts = self.example_moe.n_logical_experts
+        self.num_physical_experts = self.example_moe.n_physical_experts
+        self.num_local_physical_experts = self.example_moe.n_local_physical_experts
+        self.num_routed_experts = self.example_moe.n_routed_experts
+        self.num_shared_experts = self.example_moe.n_shared_experts
+        self.num_redundant_experts = self.example_moe.n_redundant_experts
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
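Promoting `example_moe` to `self.example_moe` keeps a handle on one MoE layer after `__init__` returns, and deleting `set_eplb_state` / `update_physical_experts_metadata` here implies the EPLB bookkeeping now lives outside the model class. Below is a minimal sketch of what shared replacements could look like, reconstructed purely from the deleted method bodies; the function names, module placement, and the `model` parameter are assumptions, not the actual vLLM implementation.

```python
# Hedged sketch only: behaviour copied from the methods removed in this hunk,
# generalized to take the model as an argument. The real code may differ.
import torch


def set_eplb_state(model, expert_load_view: torch.Tensor,
                   logical_to_physical_map: torch.Tensor,
                   logical_replica_count: torch.Tensor) -> None:
    # Register every FusedMoE layer's expert weights and EPLB buffers.
    for layer_idx, layer in enumerate(model.moe_layers):
        model.expert_weights.append(layer.get_expert_weights())
        layer.set_eplb_state(
            moe_layer_idx=layer_idx,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )


def update_physical_experts_metadata(model, num_physical_experts: int,
                                     num_local_physical_experts: int) -> None:
    # Propagate rebalanced expert counts to the model and each MoE layer.
    assert model.num_local_physical_experts == num_local_physical_experts
    model.num_physical_experts = num_physical_experts
    model.num_local_physical_experts = num_local_physical_experts
    model.num_redundant_experts = (num_physical_experts -
                                   model.num_logical_experts)
    for layer in model.model.layers:
        if isinstance(layer.mlp, type(model.example_moe)):
            moe = layer.mlp
            moe.n_local_physical_experts = num_local_physical_experts
            moe.n_physical_experts = num_physical_experts
            moe.n_redundant_experts = model.num_redundant_experts
            moe.experts.update_expert_map()
```

Generic helpers like these can serve any model that exposes `moe_layers`, `expert_weights`, and `example_moe`, which is presumably why the attribute is now stored on `self`.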
@@ -927,16 +893,10 @@ def load_weights(self, weights: Iterable[tuple[str,
927893 ("fused_qkv_a_proj" , "q_a_proj" , 0 ),
928894 ("fused_qkv_a_proj" , "kv_a_proj_with_mqa" , 1 ),
929895 ]
930-
896+ from vllm . distributed . eplb . gpu_model_register import get_expert_mapping , load_expert_weight
931897 # Params for weights, fp8 weight scales, fp8 activation scales
932898 # (param_name, weight_name, expert_id, shard_id)
933- expert_params_mapping = FusedMoE .make_expert_params_mapping (
934- ckpt_gate_proj_name = "gate_proj" ,
935- ckpt_down_proj_name = "down_proj" ,
936- ckpt_up_proj_name = "up_proj" ,
937- num_experts = self .config .n_routed_experts ,
938- num_redundant_experts = self .num_redundant_experts )
939-
899+ expert_params_mapping = get_expert_mapping (self )
940900 params_dict = dict (self .named_parameters ())
941901 loaded_params : set [str ] = set ()
942902 for name , loaded_weight in weights :
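Here `get_expert_mapping(self)` stands in for the `FusedMoE.make_expert_params_mapping(...)` call deleted above. A plausible sketch of the helper, inferred only from that deleted call; the module path is the one imported in the hunk, but the body shown here is an assumption.

```python
# Hedged sketch: reproduces the deleted make_expert_params_mapping call.
# The real get_expert_mapping in vllm.distributed.eplb.gpu_model_register may
# read the checkpoint projection names from the model instead of hard-coding
# the DeepseekV2 ones used here.
from vllm.model_executor.layers.fused_moe import FusedMoE


def get_expert_mapping(model) -> list[tuple[str, str, int, str]]:
    # Returns (param_name, weight_name, expert_id, shard_id) tuples, one per
    # expert tensor expected in the checkpoint, including redundant replicas.
    return FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=model.config.n_routed_experts,
        num_redundant_experts=model.num_redundant_experts)
```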
@@ -982,34 +942,17 @@ def load_weights(self, weights: Iterable[tuple[str,
                 break
             else:
                 is_expert_weight = False
+                is_continue = False
                 for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-
-                    # Anyway, this is an expert weight and should not be
-                    # attempted to load as other weights later
-                    is_expert_weight = True
-
-                    # Do not modify `name` since the loop may continue here
-                    # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
+                    expert_matched, is_continue, success, name_mapped = \
+                        load_expert_weight(self, mapping, name,
+                                           loaded_weight, params_dict)
+                    if expert_matched:
+                        is_expert_weight = True
 
-                    if is_pp_missing_parameter(name_mapped, self):
+                    if is_continue:
                         continue
 
-                    param = params_dict[name_mapped]
-                    # We should ask the weight loader to return success or not
-                    # here since otherwise we may skip experts with other
-                    # available replicas.
-                    weight_loader = typing.cast(Callable[..., bool],
-                                                param.weight_loader)
-                    success = weight_loader(param,
-                                            loaded_weight,
-                                            name_mapped,
-                                            shard_id=shard_id,
-                                            expert_id=expert_id,
-                                            return_success=True)
                     if success:
                         name = name_mapped
                         break
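The new `load_expert_weight` helper now hides the per-mapping matching, the pipeline-parallel skip, and the `return_success=True` load, reporting back through the `(expert_matched, is_continue, success, name_mapped)` tuple unpacked above. A sketch reconstructed from the deleted loop body; the signature and return order are inferred from the call site, and the actual helper may differ.

```python
# Hedged sketch of load_expert_weight, rebuilt from the loop body it replaces.
import typing
from collections.abc import Callable

import torch

from vllm.model_executor.models.utils import is_pp_missing_parameter


def load_expert_weight(
    model,
    mapping: tuple[str, str, int, str],
    name: str,
    loaded_weight: torch.Tensor,
    params_dict: dict[str, torch.nn.Parameter],
) -> tuple[bool, bool, bool, str]:
    """Try to load one checkpoint tensor as an expert weight.

    Returns (expert_matched, is_continue, success, name_mapped), matching how
    the caller unpacks the result.
    """
    param_name, weight_name, expert_id, shard_id = mapping
    if weight_name not in name:
        # Not an expert weight for this mapping entry; caller moves on.
        return False, True, False, name

    # Do not modify `name`: the caller's loop may try further mappings.
    name_mapped = name.replace(weight_name, param_name)

    if is_pp_missing_parameter(name_mapped, model):
        # Parameter lives on another pipeline-parallel rank; count it as an
        # expert weight but skip loading it here.
        return True, True, False, name_mapped

    param = params_dict[name_mapped]
    # Ask the weight loader to report success so experts with replicas
    # elsewhere are not silently skipped.
    weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
    success = weight_loader(param,
                            loaded_weight,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True)
    return True, False, success, name_mapped
```

Returning flags instead of mutating `name` inside the helper preserves the original control flow: unmatched mappings continue, pipeline-missing parameters are marked as expert weights but skipped, and only a successful load rewrites `name` and breaks out of the loop.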