@@ -24,15 +24,9 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         raise NotImplementedError

     @abstractmethod
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool = True,
-              use_grouped_topk: bool = False,
-              num_expert_group: Optional[int] = None,
-              topk_group: Optional[int] = None) -> torch.Tensor:
+    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
+              router_logits: torch.Tensor, top_k: int, renormalize: bool,
+              use_grouped_topk: bool) -> torch.Tensor:
         raise NotImplementedError

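Note for readers of the trimmed abstract signature: the sketch below is not part of this diff. The function name, shapes, and the SiLU-gated MLP are assumptions that mirror the unquantized path, and it ignores the grouped-topk branch. It spells out, in naive dense PyTorch rather than the fused Triton kernel, what an `apply` implementation is expected to compute.

import torch


def reference_moe_apply(x: torch.Tensor,              # [num_tokens, hidden]
                        w13: torch.Tensor,            # [E, 2 * inter, hidden]
                        w2: torch.Tensor,             # [E, hidden, inter]
                        router_logits: torch.Tensor,  # [num_tokens, E]
                        top_k: int,
                        renormalize: bool) -> torch.Tensor:
    # Route each token to its top_k experts.
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    inter = w2.shape[-1]
    out = torch.zeros_like(x)
    for token in range(x.shape[0]):
        for slot in range(top_k):
            e = int(topk_ids[token, slot])
            # w13 stacks gate_proj (first half) and up_proj (second half).
            gate_up = w13[e] @ x[token]                        # [2 * inter]
            act = torch.nn.functional.silu(gate_up[:inter]) * gate_up[inter:]
            out[token] += topk_weights[token, slot].to(x.dtype) * (w2[e] @ act)
    return out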
@@ -61,66 +55,78 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        topk_group: Optional[int] = None,
-    ) -> torch.Tensor:
-        return self.forward(x, layer.w13_weight, layer.w2_weight,
-                            router_logits, top_k, renormalize,
-                            use_grouped_topk, num_expert_group, topk_group)
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        num_expert_group: Optional[int],
-        topk_group: Optional[int],
-    ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
-        return fused_moe(x,
-                         w1,
-                         w2,
-                         router_logits,
-                         top_k,
-                         renormalize=renormalize,
-                         inplace=True,
-                         use_grouped_topk=use_grouped_topk,
-                         num_expert_group=num_expert_group,
-                         topk_group=topk_group)
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool,
+              use_grouped_topk: bool,
+              topk_group: Optional[int] = None,
+              num_expert_group: Optional[int] = None) -> torch.Tensor:
+
+        return self.forward(x=x,
+                            layer=layer,
+                            router_logits=router_logits,
+                            top_k=top_k,
+                            renormalize=renormalize,
+                            use_grouped_topk=use_grouped_topk,
+                            topk_group=topk_group,
+                            num_expert_group=num_expert_group)
+
+    def forward_cuda(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     use_grouped_topk: bool,
+                     top_k: int,
+                     router_logits: torch.Tensor,
+                     renormalize: bool,
+                     topk_group: Optional[int] = None,
+                     num_expert_group: Optional[int] = None) -> torch.Tensor:
+
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            fused_experts)
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group)
+
+        return fused_experts(hidden_states=x,
+                             w1=layer.w13_weight,
+                             w2=layer.w2_weight,
+                             topk_weights=topk_weights,
+                             topk_ids=topk_ids,
+                             inplace=True)

     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError(
             "The CPU backend currently does not support MoE.")

-    def forward_tpu(
-        self,
-        x: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        num_expert_group: Optional[int],
-        topk_group: Optional[int],
-    ) -> torch.Tensor:
+    def forward_tpu(self,
+                    layer: torch.nn.Module,
+                    x: torch.Tensor,
+                    use_grouped_topk: bool,
+                    top_k: int,
+                    router_logits: torch.Tensor,
+                    renormalize: bool,
+                    topk_group: Optional[int] = None,
+                    num_expert_group: Optional[int] = None) -> torch.Tensor:
+
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
         assert topk_group is None
-        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
+        return fused_moe(hidden_states=x,
+                         w1=layer.w13_weight,
+                         w2=layer.w2_weight,
+                         topk=top_k,
+                         gating_output=router_logits,
+                         renormalize=renormalize)


 class FusedMoE(torch.nn.Module):
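The CUDA path now separates routing from the expert GEMMs: `FusedMoE.select_experts` produces `(topk_weights, topk_ids)` and `fused_experts` consumes them. The toy snippet below is illustrative only; it re-derives the non-grouped top-k in plain PyTorch rather than calling the Triton-backed `fused_topk`, just to show the shape contract between the two steps.

import torch

num_tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(num_tokens, num_experts)

# What fused_topk(renormalize=True) conceptually returns on the default path.
probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

print(topk_ids.shape)      # torch.Size([4, 2]): expert index per (token, slot)
print(topk_weights.shape)  # torch.Size([4, 2]): routing weight per (token, slot)
# fused_experts then groups tokens by expert, runs the w13/w2 GEMMs, and
# combines the weighted results, writing into x when inplace=True.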
@@ -195,67 +201,98 @@ def __init__(

     def weight_loader(self, param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor, weight_name: str,
-                      shard_id: int, expert_id: int):
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if param_data[expert_id] != 1 and (param_data[expert_id] -
-                                               loaded_weight).abs() > 1e-5:
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}")
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            # If we are in merged column case (gate_up_proj)
-            # shard_id 0 == gate_proj / w1
-            # shard_id 2 == up_proj / w3
-            if shard_id == 0 or shard_id == 2:
-                # We have to keep the weight scales of w1 and w3 because
-                # we need to re-quantize w1/w3 weights after weight loading.
-                idx = 0 if shard_id == 0 else 1
-                param_data[expert_id][idx] = loaded_weight
-            # If we are in the row parallel case (down_proj)
-            # shard_id 1 == down_proj / w2
-            else:
-                param_data[expert_id] = loaded_weight
-        # Weights
+                      shard_id: str, expert_id: int) -> None:
+        if shard_id not in ("w1", "w2", "w3"):
+            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
+                             f"got {shard_id}.")
+
+        # Special case for fp8 scales.
+        if getattr(param, "is_fp8_scale", False):
+            self._load_fp8_scale(param.data, loaded_weight, weight_name,
+                                 shard_id, expert_id)
+            return
+
+        expert_data = param.data[expert_id]
+        tp_rank = get_tensor_model_parallel_rank()
+
+        # If transposed, weight is saved as [input_dim, output_dim]
+        # Otherwise, weight is saved as [output_dim, input_dim]
+        # Default is not transposed/input dim is dim 1
+        input_dim = getattr(param, "input_dim", 1)
+        output_dim = getattr(param, "output_dim", 0)
+
+        # Index the loaded weight for tp sharding.
+        # down_proj: "RowParallel" so tp sharding on input_dim
+        if shard_id == "w2":
+            shard_dim = input_dim
+            shard_size = expert_data.shape[shard_dim]
+        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
+        elif shard_id in ("w1", "w3"):
+            shard_dim = output_dim
+            shard_size = expert_data.shape[output_dim] // 2
+        offset = shard_size * tp_rank
+        loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)
+
+        # Narrow parameter and load.
+        # w1, gate_proj: Load into first logical weight of w13.
+        if shard_id == "w1":
+            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
+            expert_data.copy_(loaded_weight)
+        # w3, up_proj: Load into second logical weight of w13.
+        elif shard_id == "w3":
+            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+            expert_data.copy_(loaded_weight)
+        # w2, down_proj: Load into only logical weight of w2.
+        elif shard_id == "w2":
+            expert_data.copy_(loaded_weight)
         else:
-            tp_rank = get_tensor_model_parallel_rank()
-            shard_size = self.intermediate_size_per_partition
-            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-
-            # w1, gate_proj case: Load into first shard of w13.
-            if shard_id == 0:
-                param_data[expert_id,
-                           0:shard_size, :] = loaded_weight[shard, :]
-            # w3, up_proj case: Load into second shard of w13.
-            elif shard_id == 2:
-                param_data[expert_id, shard_size:2 *
-                           shard_size, :] = loaded_weight[shard, :]
-            # w2, down_proj case: Load into only shard of w2.
-            elif shard_id == 1:
-                param_data[expert_id, :, :] = loaded_weight[:, shard]
-            else:
-                raise ValueError(
-                    f"Shard id must be in [0,1,2] but got {shard_id}")
+            raise ValueError(
+                f"Expected shard_id w1,w2 or w3 but got {shard_id}")
+
+    @staticmethod
+    def select_experts(hidden_states: torch.Tensor,
+                       router_logits: torch.Tensor,
+                       top_k: int,
+                       use_grouped_topk: bool,
+                       renormalize: bool,
+                       topk_group: Optional[int] = None,
+                       num_expert_group: Optional[int] = None):
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            fused_topk, grouped_topk)
+
+        # DeepSeek-V2 uses grouped_topk
+        if use_grouped_topk:
+            assert topk_group is not None
+            assert num_expert_group is not None
+            topk_weights, topk_ids = grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group)
+        else:
+            topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
+                                                gating_output=router_logits,
+                                                topk=top_k,
+                                                renormalize=renormalize)
+
+        return topk_weights, topk_ids

     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         assert self.quant_method is not None

         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
-            self,
+            layer=self,
             x=hidden_states,
             router_logits=router_logits,
             top_k=self.top_k,
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
-            num_expert_group=self.num_expert_group,
-            topk_group=self.topk_group)
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group)

         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
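The new `weight_loader` shards checkpoint tensors with `narrow` instead of slice assignment. Below is a small walk-through of where the gate/up/down weights land; the shapes and the `tp_size=2` setup are made up for illustration, and `expert_w13`/`expert_w2` stand in for `param.data[expert_id]`.

import torch

hidden, inter_full, tp_rank = 4, 16, 1     # pretend we are rank 1 of tp_size=2
inter_part = inter_full // 2               # per-rank intermediate size

# Local slice of w13 for one expert: [2 * inter_part, hidden]
expert_w13 = torch.zeros(2 * inter_part, hidden)
# Checkpoint gate_proj ("w1") weight for that expert: [inter_full, hidden]
ckpt_w1 = torch.randn(inter_full, hidden)

shard_dim = 0                               # output_dim for w1/w3
shard_size = expert_w13.shape[shard_dim] // 2
shard = ckpt_w1.narrow(shard_dim, shard_size * tp_rank, shard_size)
expert_w13.narrow(shard_dim, 0, shard_size).copy_(shard)   # w1 -> first half
# "w3" (up_proj) would land in narrow(shard_dim, shard_size, shard_size).

# down_proj ("w2") is row-parallel, so it shards the input dim (dim 1) instead.
ckpt_w2 = torch.randn(hidden, inter_full)
expert_w2 = torch.zeros(hidden, inter_part)
expert_w2.copy_(ckpt_w2.narrow(1, inter_part * tp_rank, inter_part))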
@@ -267,35 +304,42 @@ def forward(self, hidden_states: torch.Tensor,
     def make_expert_params_mapping(
             cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
             ckpt_up_proj_name: str,
-            num_experts: int) -> List[Tuple[str, str, int, int]]:
-
-        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
-        gate_down_up = [
-            ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
-        ]
+            num_experts: int) -> List[Tuple[str, str, int, str]]:

         return [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_scale"
-             if weight_name in gate_up else "experts.w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
-             shard_id) for expert_id in range(num_experts)
-            for shard_id, weight_name in enumerate(gate_down_up)
-        ] + [
-            # These are the weights for the experts
             # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_weight"
-             if weight_name in gate_up else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(num_experts)
-            for shard_id, weight_name in enumerate(gate_down_up)
-        ] + [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.a13_scale"
-             if weight_name in gate_up else "experts.a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
-             shard_id) for expert_id in range(num_experts)
-            for shard_id, weight_name in enumerate(gate_down_up)
+            ("experts.w13_" if weight_name
+             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
+             f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
+            for expert_id in range(num_experts) for shard_id, weight_name in [
+                ("w1", ckpt_gate_proj_name),
+                ("w2", ckpt_down_proj_name),
+                ("w3", ckpt_up_proj_name),
+            ]
         ]
+
+    def _load_fp8_scale(self, param: torch.nn.Parameter,
+                        loaded_weight: torch.Tensor, weight_name: str,
+                        shard_id: str, expert_id: int) -> None:
+        param_data = param.data
+
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
+            if param_data[expert_id] != 1 and (param_data[expert_id] -
+                                               loaded_weight).abs() > 1e-5:
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}")
+            param_data[expert_id] = loaded_weight
+        # Weight scales
+        elif "weight_scale" in weight_name:
+            # If we are in merged column case (gate_up_proj)
+            if shard_id in ("w1", "w3"):
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == "w1" else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj)
+            else:
+                param_data[expert_id] = loaded_weight
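The mapping now returns string shard ids and name prefixes rather than full parameter names, so a single set of entries can match weights and scale tensors alike. A quick illustration follows; the `w1`/`w2`/`w3` checkpoint names mirror Mixtral-style expert layouts, and the import path assumes the method stays a classmethod on `FusedMoE` in this module.

from vllm.model_executor.layers.fused_moe.layer import FusedMoE

mapping = FusedMoE.make_expert_params_mapping(ckpt_gate_proj_name="w1",
                                              ckpt_down_proj_name="w2",
                                              ckpt_up_proj_name="w3",
                                              num_experts=2)
for param_prefix, ckpt_prefix, expert_id, shard_id in mapping:
    print(param_prefix, ckpt_prefix, expert_id, shard_id)
# experts.w13_ experts.0.w1. 0 w1
# experts.w2_ experts.0.w2. 0 w2
# experts.w13_ experts.0.w3. 0 w3
# experts.w13_ experts.1.w1. 1 w1
# experts.w2_ experts.1.w2. 1 w2
# experts.w13_ experts.1.w3. 1 w3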