
Commit d3bdfd3

[Misc] Update Fused MoE weight loading (#7334)
1 parent fb377d7 commit d3bdfd3

File tree

6 files changed: +264, -201 lines


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 180 additions & 136 deletions
@@ -24,15 +24,9 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         raise NotImplementedError
 
     @abstractmethod
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool = True,
-              use_grouped_topk: bool = False,
-              num_expert_group: Optional[int] = None,
-              topk_group: Optional[int] = None) -> torch.Tensor:
+    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
+              router_logits: torch.Tensor, top_k: int, renormalize: bool,
+              use_grouped_topk: bool) -> torch.Tensor:
         raise NotImplementedError
 
 
@@ -61,66 +55,78 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        topk_group: Optional[int] = None,
-    ) -> torch.Tensor:
-        return self.forward(x, layer.w13_weight, layer.w2_weight,
-                            router_logits, top_k, renormalize,
-                            use_grouped_topk, num_expert_group, topk_group)
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        num_expert_group: Optional[int],
-        topk_group: Optional[int],
-    ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
-        return fused_moe(x,
-                         w1,
-                         w2,
-                         router_logits,
-                         top_k,
-                         renormalize=renormalize,
-                         inplace=True,
-                         use_grouped_topk=use_grouped_topk,
-                         num_expert_group=num_expert_group,
-                         topk_group=topk_group)
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool,
+              use_grouped_topk: bool,
+              topk_group: Optional[int] = None,
+              num_expert_group: Optional[int] = None) -> torch.Tensor:
+
+        return self.forward(x=x,
+                            layer=layer,
+                            router_logits=router_logits,
+                            top_k=top_k,
+                            renormalize=renormalize,
+                            use_grouped_topk=use_grouped_topk,
+                            topk_group=topk_group,
+                            num_expert_group=num_expert_group)
+
+    def forward_cuda(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     use_grouped_topk: bool,
+                     top_k: int,
+                     router_logits: torch.Tensor,
+                     renormalize: bool,
+                     topk_group: Optional[int] = None,
+                     num_expert_group: Optional[int] = None) -> torch.Tensor:
+
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            fused_experts)
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group)
+
+        return fused_experts(hidden_states=x,
+                             w1=layer.w13_weight,
+                             w2=layer.w2_weight,
+                             topk_weights=topk_weights,
+                             topk_ids=topk_ids,
+                             inplace=True)
 
     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError(
             "The CPU backend currently does not support MoE.")
 
-    def forward_tpu(
-        self,
-        x: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        num_expert_group: Optional[int],
-        topk_group: Optional[int],
-    ) -> torch.Tensor:
+    def forward_tpu(self,
+                    layer: torch.nn.Module,
+                    x: torch.Tensor,
+                    use_grouped_topk: bool,
+                    top_k: int,
+                    router_logits: torch.Tensor,
+                    renormalize: bool,
+                    topk_group: Optional[int] = None,
+                    num_expert_group: Optional[int] = None) -> torch.Tensor:
+
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
         assert topk_group is None
-        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
+        return fused_moe(hidden_states=x,
+                         w1=layer.w13_weight,
+                         w2=layer.w2_weight,
+                         topk=top_k,
+                         gating_output=router_logits,
+                         renormalize=renormalize)
 
 
 class FusedMoE(torch.nn.Module):
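
The refactor above splits the unquantized path into two explicit stages: expert selection (`FusedMoE.select_experts`) and the fused expert computation (`fused_experts`), with the expert weights now read from `layer.w13_weight` / `layer.w2_weight` instead of being passed in. The sketch below is a naive PyTorch reference of those two stages for the non-grouped routing path, assuming the SiLU-gated layout used by this file (`w13` stacks gate_proj and up_proj per expert, `w2` is down_proj); it is illustrative only, not the vLLM kernel.

```python
import torch
import torch.nn.functional as F


def naive_select_experts(router_logits: torch.Tensor, top_k: int,
                         renormalize: bool):
    # router_logits: [num_tokens, num_experts]
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    topk_weights, topk_ids = torch.topk(routing_weights, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


def naive_experts(x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor):
    # x: [num_tokens, hidden], w13: [E, 2*I, hidden], w2: [E, hidden, I]
    intermediate_size = w2.shape[-1]
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for weight, expert in zip(topk_weights[t], topk_ids[t]):
            gate_up = w13[expert] @ x[t]                   # [2*I]
            gate = gate_up[:intermediate_size]
            up = gate_up[intermediate_size:]
            expert_out = w2[expert] @ (F.silu(gate) * up)  # [hidden]
            out[t] += weight.to(x.dtype) * expert_out
    return out
```

For a batch `x` of shape `[num_tokens, hidden_size]`, `naive_experts(x, w13, w2, *naive_select_experts(logits, top_k, renormalize=True))` mirrors, up to numerics, what `forward_cuda` now computes through `select_experts` followed by `fused_experts(..., inplace=True)`.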
@@ -195,67 +201,98 @@ def __init__(
 
     def weight_loader(self, param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor, weight_name: str,
-                      shard_id: int, expert_id: int):
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if param_data[expert_id] != 1 and (param_data[expert_id] -
-                                               loaded_weight).abs() > 1e-5:
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}")
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            # If we are in merged column case (gate_up_proj)
-            # shard_id 0 == gate_proj / w1
-            # shard_id 2 == up_proj / w3
-            if shard_id == 0 or shard_id == 2:
-                # We have to keep the weight scales of w1 and w3 because
-                # we need to re-quantize w1/w3 weights after weight loading.
-                idx = 0 if shard_id == 0 else 1
-                param_data[expert_id][idx] = loaded_weight
-            # If we are in the row parallel case (down_proj)
-            # shard_id 1 == down_proj / w2
-            else:
-                param_data[expert_id] = loaded_weight
-        # Weights
+                      shard_id: str, expert_id: int) -> None:
+        if shard_id not in ("w1", "w2", "w3"):
+            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
+                             f"got {shard_id}.")
+
+        # Special case for fp8 scales.
+        if getattr(param, "is_fp8_scale", False):
+            self._load_fp8_scale(param.data, loaded_weight, weight_name,
+                                 shard_id, expert_id)
+            return
+
+        expert_data = param.data[expert_id]
+        tp_rank = get_tensor_model_parallel_rank()
+
+        # If transposed, weight is saved as [input_dim, output_dim]
+        # Otherwise, weight is saved as [output_dim, input_dim]
+        # Default is not transposed/input dim is dim 1
+        input_dim = getattr(param, "input_dim", 1)
+        output_dim = getattr(param, "output_dim", 0)
+
+        # Index the loaded weight for tp sharding.
+        # down_proj: "RowParallel" so tp sharding on input_dim
+        if shard_id == "w2":
+            shard_dim = input_dim
+            shard_size = expert_data.shape[shard_dim]
+        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
+        elif shard_id in ("w1", "w3"):
+            shard_dim = output_dim
+            shard_size = expert_data.shape[output_dim] // 2
+        offset = shard_size * tp_rank
+        loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)
+
+        # Narrow parameter and load.
+        # w1, gate_proj: Load into first logical weight of w13.
+        if shard_id == "w1":
+            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
+            expert_data.copy_(loaded_weight)
+        # w3, up_proj: Load into second logical weight of w13.
+        elif shard_id == "w3":
+            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+            expert_data.copy_(loaded_weight)
+        # w2, down_proj: Load into only logical weight of w2.
+        elif shard_id == "w2":
+            expert_data.copy_(loaded_weight)
         else:
-            tp_rank = get_tensor_model_parallel_rank()
-            shard_size = self.intermediate_size_per_partition
-            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-
-            # w1, gate_proj case: Load into first shard of w13.
-            if shard_id == 0:
-                param_data[expert_id,
-                           0:shard_size, :] = loaded_weight[shard, :]
-            # w3, up_proj case: Load into second shard of w13.
-            elif shard_id == 2:
-                param_data[expert_id, shard_size:2 *
-                           shard_size, :] = loaded_weight[shard, :]
-            # w2, down_proj case: Load into only shard of w2.
-            elif shard_id == 1:
-                param_data[expert_id, :, :] = loaded_weight[:, shard]
-            else:
-                raise ValueError(
-                    f"Shard id must be in [0,1,2] but got {shard_id}")
+            raise ValueError(
+                f"Expected shard_id w1,w2 or w3 but got {shard_id}")
+
+    @staticmethod
+    def select_experts(hidden_states: torch.Tensor,
+                       router_logits: torch.Tensor,
+                       top_k: int,
+                       use_grouped_topk: bool,
+                       renormalize: bool,
+                       topk_group: Optional[int] = None,
+                       num_expert_group: Optional[int] = None):
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            fused_topk, grouped_topk)
+
+        # DeekSeekv2 uses grouped_top_k
+        if use_grouped_topk:
+            assert topk_group is not None
+            assert num_expert_group is not None
+            topk_weights, topk_ids = grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group)
+        else:
+            topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
+                                                gating_output=router_logits,
+                                                topk=top_k,
+                                                renormalize=renormalize)
+
+        return topk_weights, topk_ids
 
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         assert self.quant_method is not None
 
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
-            self,
+            layer=self,
             x=hidden_states,
             router_logits=router_logits,
             top_k=self.top_k,
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
-            num_expert_group=self.num_expert_group,
-            topk_group=self.topk_group)
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group)
 
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
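
The new `weight_loader` drops the integer shard ids and the manual slice arithmetic in favour of string shard ids plus `narrow` on both the checkpoint tensor and the destination expert slice: w1/w3 are "MergedColumnParallel", so each rank takes its slice of the output dimension and writes it into one half of `w13`, while w2 is "RowParallel" and is sharded along the input dimension. The snippet below is a self-contained sketch of that arithmetic with toy sizes and dummy tensors (all sizes and names here are illustrative, and no distributed setup is involved).

```python
import torch

# Toy sizes, for illustration only.
hidden_size = 8
intermediate_size = 4              # full (unsharded) intermediate size
tp_size, tp_rank = 2, 1            # pretend this is rank 1 of 2
intermediate_per_rank = intermediate_size // tp_size

# Per-expert destination slices, shaped like param.data[expert_id].
w13_expert = torch.empty(2 * intermediate_per_rank, hidden_size)  # output_dim=0
w2_expert = torch.empty(hidden_size, intermediate_per_rank)       # input_dim=1

# Checkpoint tensors are unsharded.
ckpt_w1 = torch.randn(intermediate_size, hidden_size)  # gate_proj
ckpt_w2 = torch.randn(hidden_size, intermediate_size)  # down_proj

# w1 ("MergedColumnParallel"): shard along the output dim, then copy into the
# first half of w13 (the w3 slice would go into the second half).
shard_dim = 0
shard_size = w13_expert.shape[shard_dim] // 2
offset = shard_size * tp_rank
w13_expert.narrow(shard_dim, 0, shard_size).copy_(
    ckpt_w1.narrow(shard_dim, offset, shard_size))

# w2 ("RowParallel"): shard along the input dim and copy the whole slice.
shard_dim = 1
shard_size = w2_expert.shape[shard_dim]
offset = shard_size * tp_rank
w2_expert.copy_(ckpt_w2.narrow(shard_dim, offset, shard_size))
```

With `tp_rank = 1`, the w1 slice covers rows 2-3 of `ckpt_w1` and lands in rows 0-1 of `w13_expert`; the matching w3 slice would land in rows 2-3.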
@@ -267,35 +304,42 @@ def forward(self, hidden_states: torch.Tensor,
     def make_expert_params_mapping(
             cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
             ckpt_up_proj_name: str,
-            num_experts: int) -> List[Tuple[str, str, int, int]]:
-
-        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
-        gate_down_up = [
-            ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
-        ]
+            num_experts: int) -> List[Tuple[str, str, int, str]]:
 
         return [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_scale"
-             if weight_name in gate_up else "experts.w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
-             shard_id) for expert_id in range(num_experts)
-            for shard_id, weight_name in enumerate(gate_down_up)
-        ] + [
-            # These are the weights for the experts
             # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_weight"
-             if weight_name in gate_up else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(num_experts)
-            for shard_id, weight_name in enumerate(gate_down_up)
-        ] + [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.a13_scale"
-             if weight_name in gate_up else "experts.a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
-             shard_id) for expert_id in range(num_experts)
-            for shard_id, weight_name in enumerate(gate_down_up)
+            ("experts.w13_" if weight_name
+             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
+             f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
+            for expert_id in range(num_experts) for shard_id, weight_name in [
+                ("w1", ckpt_gate_proj_name),
+                ("w2", ckpt_down_proj_name),
+                ("w3", ckpt_up_proj_name),
+            ]
         ]
+
+    def _load_fp8_scale(self, param: torch.nn.Parameter,
+                        loaded_weight: torch.Tensor, weight_name: str,
+                        shard_id: str, expert_id: int) -> None:
+        param_data = param.data
+
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
+            if param_data[expert_id] != 1 and (param_data[expert_id] -
+                                               loaded_weight).abs() > 1e-5:
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}")
+            param_data[expert_id] = loaded_weight
+        # Weight scales
+        elif "weight_scale" in weight_name:
+            # If we are in merged column case (gate_up_proj)
+            if shard_id in ("w1", "w3"):
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == "w1" else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj)
+            else:
+                param_data[expert_id] = loaded_weight
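
With the mapping reduced to a prefix pair plus a string shard id, one entry per expert and projection now covers the weight as well as its `weight_scale` and `input_scale`, since the checkpoint prefix `experts.{e}.{proj}.` is simply rewritten to `experts.w13_` or `experts.w2_`. Below is a hedged sketch of how a model's `load_weights` would consume the mapping; `weights` and `params_dict` are illustrative placeholders, and the loop is modeled on the pattern used by vLLM's MoE models rather than copied from this commit.

```python
from vllm.model_executor.layers.fused_moe import FusedMoE

# For num_experts=2 and gate/down/up projection names, the mapping contains
# tuples such as:
#   ("experts.w13_", "experts.0.gate_proj.", 0, "w1")
#   ("experts.w2_",  "experts.0.down_proj.", 0, "w2")
#   ("experts.w13_", "experts.0.up_proj.",   0, "w3")
expert_params_mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=2)


def load_expert_weights(weights, params_dict):
    """weights: iterable of (checkpoint_name, tensor); params_dict: named params."""
    for name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            # e.g. "...experts.0.gate_proj.weight" -> "...experts.w13_weight";
            # the same rewrite covers ".weight_scale" and ".input_scale".
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            param.weight_loader(param,
                                loaded_weight,
                                name,
                                shard_id=shard_id,
                                expert_id=expert_id)
            break
```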
