                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (FusedMoELinear,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -93,25 +93,22 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.n_routed_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.n_routed_experts:
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.n_routed_experts}.")
-
-        self.experts = nn.ModuleList([
-            Qwen2MoeMLP(hidden_size=config.hidden_size,
-                        intermediate_size=config.moe_intermediate_size,
-                        hidden_act=config.hidden_act,
-                        quant_config=quant_config,
-                        reduce_results=False)
-            for idx in range(self.n_routed_experts)
-        ])
-        self.pack_params()
+                f"the number of experts {config.num_experts}.")
+
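+        # The fused layer replaces the per-expert Qwen2MoeMLP modules and the
+        # pack_params() flattening removed above: each expert's gate_proj and
+        # up_proj are stacked into w13_weight and its down_proj into w2_weight.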
+        self.experts = FusedMoELinear(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+        )
 
         self.gate = ReplicatedLinear(config.hidden_size,
-                                     self.n_routed_experts,
+                                     config.num_experts,
@@ -131,25 +128,6 @@ def __init__(
                                                    1,
                                                    bias=False)
 
-    def pack_params(self):
-        w1 = []
-        w2 = []
-        for expert in self.experts:
-            w1.append(expert.gate_up_proj.weight)
-            w2.append(expert.down_proj.weight)
-        self.w1 = torch._utils._flatten_dense_tensors(w1)
-        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
-        for data, param in zip(w1s, w1):
-            param.data = data
-        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
-
-        self.w2 = torch._utils._flatten_dense_tensors(w2)
-        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
-        for data, param in zip(w2s, w2):
-            param.data = data
-
-        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -162,18 +140,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w1,
-                                        self.w2,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=self.config.norm_topk_prob,
-                                        inplace=True)
-
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
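+        # The experts are built with reduce_results=False, so the shared
+        # expert's output is added first and a single all-reduce (only when
+        # tp_size > 1) then combines the partial sums across ranks.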
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
@@ -284,6 +257,7 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config,
         )
+
         if (layer_idx not in config.mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
@@ -426,21 +400,35 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
426400 ("gate_up_proj" , "up_proj" , 1 ),
427401 ]
428402
+        expert_params_mapping = [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
+             else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(self.config.num_experts)
+            for shard_id, weight_name in enumerate(
+                ["gate_proj", "down_proj", "up_proj"])
+        ]
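+        # For expert 0, for example, this expands to:
+        #   ("experts.w13_weight", "experts.0.gate_proj.weight", 0, 0)
+        #   ("experts.w2_weight", "experts.0.down_proj.weight", 0, 1)
+        #   ("experts.w13_weight", "experts.0.up_proj.weight", 0, 2)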
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_expert." in name)
-                        and name not in params_dict):
-                    continue
                 if name not in params_dict:
                     continue
 
@@ -449,17 +437,26 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_expert." in name)
-                        and name not in params_dict):
-                    continue
-                if name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for (param_name, weight_name, expert_id,
+                     shard_id) in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)