This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 63850f2

Qwen 2 refactored (#349)
draft

Co-authored-by: [email protected] <rshaw@neuralmagic>

1 parent a39d8bf; commit 63850f2

File tree: 4 files changed (+112, -101 lines)


vllm/model_executor/layers/linear.py

Lines changed: 25 additions & 13 deletions
@@ -89,7 +89,8 @@ def apply_moe(self,
                   layer: torch.nn.Module,
                   x: torch.Tensor,
                   router_logits: torch.Tensor,
-                  top_k: int) -> torch.Tensor:
+                  top_k: int,
+                  renormalize: bool = True) -> torch.Tensor:
         raise NotImplementedError


@@ -157,14 +158,15 @@ def apply_moe(self,
                   layer: torch.nn.Module,
                   x: torch.Tensor,
                   router_logits: torch.Tensor,
-                  top_k: int) -> torch.Tensor:
+                  top_k: int,
+                  renormalize: bool = True) -> torch.Tensor:

         return fused_moe(x,
                          layer.w13_weight,
                          layer.w2_weight,
                          router_logits,
                          top_k,
-                         renormalize=True,
+                         renormalize=renormalize,
                          inplace=True)


@@ -876,7 +878,9 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig]=None,
+        reduce_results: bool = False,
+        renormalize: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()

@@ -886,6 +890,8 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = top_k
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        self.reduce_results = reduce_results
+        self.renormalize = renormalize

         if quant_config is None:
             self.quant_method: Optional[
@@ -906,22 +912,28 @@ def weight_loader(self,
                       param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor,
                       weight_name: str,
+                      shard_id: int,
                       expert_id: int):
         tp_rank = get_tensor_model_parallel_rank()
         param_data = param.data
         shard_size = self.intermediate_size_per_partition
         shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

-        # FIXME: This is going to be brittle.
-        if weight_name.endswith("w1.weight"):
+        # w1, gate_proj case: Load into first shard of w13.
+        if shard_id == 0:
             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
+        # w3, up_proj case: Load into second shard of w13.
+        elif shard_id == 2:
             param_data[expert_id,
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w2.weight"):
+        # w2, down_proj case: Load into only shard of w2.
+        elif shard_id == 1:
             param_data[expert_id, :, :] = loaded_weight[:, shard]
+        else:
+            raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")

-        # FIXME: This is going to be brittle.
+        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
+        # Follow up PR to enable fp8 for other MoE models.
         if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
             if param_data[expert_id] != 1 and (param_data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
@@ -936,7 +948,6 @@ def weight_loader(self,
             assert "w1" in weight_name or "w3" in weight_name
             shard_id = 0 if "w1" in weight_name else 1
             param_data[expert_id][shard_id] = loaded_weight
-


     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
@@ -945,9 +956,10 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         final_hidden_states = self.quant_method.apply_moe(self,
                                                           x=hidden_states,
                                                           router_logits=router_logits,
-                                                          top_k=self.top_k)
-
-        if self.tp_size > 1:
+                                                          top_k=self.top_k,
+                                                          renormalize=self.renormalize)
+
+        if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)

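For readers new to the fused layout, the sketch below illustrates, with throwaway tensors outside of vLLM, how the three per-expert projections land in the two fused parameters that weight_loader fills: gate_proj (shard_id 0) goes into the first half of w13, up_proj (shard_id 2) into the second half, and down_proj (shard_id 1) into w2, with each tensor-parallel rank keeping only its slice of the intermediate dimension. The names and sizes here are made up for illustration; this is not code from the commit.

import torch

# Toy sizes, chosen only for illustration.
num_experts, hidden, intermediate, tp_size, tp_rank = 2, 8, 16, 2, 0
shard_size = intermediate // tp_size  # intermediate_size_per_partition
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

# Fused parameters, one slice per rank (same shapes the weight_loader indexes into).
w13 = torch.empty(num_experts, 2 * shard_size, hidden)
w2 = torch.empty(num_experts, hidden, shard_size)

# Full (unsharded) checkpoint weights for one expert.
gate_proj = torch.randn(intermediate, hidden)  # w1  -> shard_id == 0
up_proj = torch.randn(intermediate, hidden)    # w3  -> shard_id == 2
down_proj = torch.randn(hidden, intermediate)  # w2  -> shard_id == 1

expert_id = 0
w13[expert_id, 0:shard_size, :] = gate_proj[shard, :]             # first shard of w13
w13[expert_id, shard_size:2 * shard_size, :] = up_proj[shard, :]  # second shard of w13
w2[expert_id, :, :] = down_proj[:, shard]                         # only shard of w2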
vllm/model_executor/layers/quantization/fp8.py

Lines changed: 3 additions & 2 deletions
@@ -393,14 +393,15 @@ def apply_moe(self,
                   layer: torch.nn.Module,
                   x: torch.Tensor,
                   router_logits: torch.Tensor,
-                  top_k: int) -> torch.Tensor:
+                  top_k: int,
+                  renormalize: bool=True) -> torch.Tensor:

         return fused_moe(x,
                          layer.w13_weight,
                          layer.w2_weight,
                          router_logits,
                          top_k,
-                         renormalize=True,
+                         renormalize=renormalize,
                          inplace=True,
                          use_fp8=True,
                          w1_scale=layer.w13_scale,

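The renormalize flag that now threads through apply_moe controls whether the selected top-k routing weights are rescaled to sum to one per token; Mixtral keeps it enabled, while Qwen2-MoE ties it to config.norm_topk_prob. Below is a rough, unfused sketch of that routing step under the usual softmax-then-top-k formulation; the real fused_moe path performs this inside fused kernels rather than as separate torch ops.

import torch

def topk_routing(router_logits: torch.Tensor, top_k: int, renormalize: bool):
    """Pick top_k experts per token and return their routing weights."""
    # router_logits: (num_tokens, num_experts)
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        # Rescale the k selected weights so they sum to 1 for each token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

weights, ids = topk_routing(torch.randn(4, 8), top_k=2, renormalize=True)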
vllm/model_executor/models/mixtral.py

Lines changed: 21 additions & 20 deletions
@@ -78,22 +78,22 @@ def __init__(
                                      params_dtype=params_dtype,
                                      quant_config=None)

-        self.mlp = FusedMoELinear(num_experts=num_experts,
-                                  top_k=top_k,
-                                  hidden_size=hidden_size,
-                                  intermediate_size=intermediate_size,
-                                  params_dtype=params_dtype,
-                                  quant_config=quant_config)
+        self.experts = FusedMoELinear(num_experts=num_experts,
+                                      top_k=top_k,
+                                      hidden_size=hidden_size,
+                                      intermediate_size=intermediate_size,
+                                      params_dtype=params_dtype,
+                                      reduce_results=True,
+                                      renormalize=True,
+                                      quant_config=quant_config)


     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.mlp(hidden_states=hidden_states,
-                                       router_logits=router_logits)
-
+        final_hidden_states = self.experts(hidden_states,router_logits)
         return final_hidden_states.view(num_tokens, hidden_size)


@@ -372,25 +372,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

         expert_params_mapping = [
             # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("mlp.w13_scale" if weight_name in ["w1", "w3"] else "mlp.w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_scale" if weight_name in ["w1", "w3"] else "experts.w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, shard_id)
             for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ] + [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id)
-            ("mlp.w13_weight" if weight_name in ["w1", "w3"] else "mlp.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            ("experts.w13_weight" if weight_name in ["w1", "w3"] else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
             for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ] + [
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
-            ("mlp.a13_scale" if weight_name in ["w1", "w3"] else "mlp.a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
+            ("experts.a13_scale" if weight_name in ["w1", "w3"] else "experts.a2_scale",
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id, shard_id)
             for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ]

         params_dict = dict(self.named_parameters())
@@ -410,7 +410,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
@@ -419,6 +419,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader(param,
                                   loaded_weight,
                                   weight_name,
+                                  shard_id=shard_id,
                                   expert_id=expert_id)
                     break
                 else:

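To make the new four-tuple mapping concrete, here is what the expert-weight group of expert_params_mapping expands to for a toy config with two local experts (the scale groups follow the same pattern). This is just the comprehension from the diff evaluated with a made-up expert count, not additional code from the commit.

num_local_experts = 2  # toy value for illustration

expert_params_mapping = [
    # (param_name, weight_name, expert_id, shard_id)
    ("experts.w13_weight" if weight_name in ["w1", "w3"] else "experts.w2_weight",
     f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
    for expert_id in range(num_local_experts)
    for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
]

# expert_params_mapping ==
# [("experts.w13_weight", "experts.0.w1.weight", 0, 0),
#  ("experts.w2_weight",  "experts.0.w2.weight", 0, 1),
#  ("experts.w13_weight", "experts.0.w3.weight", 0, 2),
#  ("experts.w13_weight", "experts.1.w1.weight", 1, 0),
#  ("experts.w2_weight",  "experts.1.w2.weight", 1, 1),
#  ("experts.w13_weight", "experts.1.w3.weight", 1, 2)]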
vllm/model_executor/models/qwen2_moe.py

Lines changed: 63 additions & 66 deletions
@@ -35,9 +35,9 @@
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (FusedMoELinear,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -93,25 +93,22 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.n_routed_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.n_routed_experts:
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.n_routed_experts}.")
-
-        self.experts = nn.ModuleList([
-            Qwen2MoeMLP(hidden_size=config.hidden_size,
-                        intermediate_size=config.moe_intermediate_size,
-                        hidden_act=config.hidden_act,
-                        quant_config=quant_config,
-                        reduce_results=False)
-            for idx in range(self.n_routed_experts)
-        ])
-        self.pack_params()
+                f"the number of experts {config.num_experts}.")
+
+        self.experts = FusedMoELinear(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+        )

         self.gate = ReplicatedLinear(config.hidden_size,
                                      self.n_routed_experts,
@@ -131,25 +128,6 @@ def __init__(
                                                1,
                                                bias=False)

-    def pack_params(self):
-        w1 = []
-        w2 = []
-        for expert in self.experts:
-            w1.append(expert.gate_up_proj.weight)
-            w2.append(expert.down_proj.weight)
-        self.w1 = torch._utils._flatten_dense_tensors(w1)
-        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
-        for data, param in zip(w1s, w1):
-            param.data = data
-        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
-
-        self.w2 = torch._utils._flatten_dense_tensors(w2)
-        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
-        for data, param in zip(w2s, w2):
-            param.data = data
-
-        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -162,18 +140,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w1,
-                                        self.w2,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=self.config.norm_topk_prob,
-                                        inplace=True)
-
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)

         return final_hidden_states.view(num_tokens, hidden_dim)

@@ -284,6 +257,7 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config,
         )
+
         if (layer_idx not in config.mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
@@ -426,21 +400,35 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]

+        expert_params_mapping = [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"] else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(self.config.num_experts)
+            for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
+        ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_expert." in name)
-                        and name not in params_dict):
-                    continue
                 if name not in params_dict:
                     continue

@@ -449,17 +437,26 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_expert." in name)
-                        and name not in params_dict):
-                    continue
-                if name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)

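The long comment added to the stacked-projection loop is easier to follow with a concrete checkpoint key. The string-only trace below (the key is a typical Qwen2-MoE name chosen for illustration, not taken from the commit) shows how name would be corrupted if expert weights were not skipped before name.replace.

# Hypothetical checkpoint key for one expert's gate projection.
name = "model.layers.0.mlp.experts.0.gate_proj.weight"

# Without the early `if "mlp.experts" in name: continue`, the stacked entry
# ("gate_up_proj", "gate_proj", 0) matches and renames the key:
name = name.replace("gate_proj", "gate_up_proj")
# -> "model.layers.0.mlp.experts.0.gate_up_proj.weight"

# That key is not in params_dict (the fused parameter is mlp.experts.w13_weight),
# so the loop continues, but `name` stays rebound.  The next stacked entry
# ("gate_up_proj", "up_proj", 1) now also matches the rewritten key:
name = name.replace("up_proj", "gate_up_proj")
# -> "model.layers.0.mlp.experts.0.gate_gate_up_proj.weight"

# expert_params_mapping then looks for "experts.0.gate_proj.weight", which no
# longer appears in the corrupted key, so the expert weight is silently skipped.
# Checking for "mlp.experts" before the replace avoids the whole chain.
print(name)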