
Commit 1e15aae

[Bugfix][Quantization] Fix FP8 + EP (#13784)
Signed-off-by: Tyler Michael Smith <[email protected]>
1 parent 51010a1 commit 1e15aae

6 files changed: +22 −22 lines changed


vllm/model_executor/layers/fused_moe/layer.py (15 additions, 15 deletions)

@@ -260,7 +260,7 @@ class FusedMoE(torch.nn.Module):
 
     def __init__(
         self,
-        num_experts: int,
+        num_experts: int,  # Global number of experts
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
@@ -291,7 +291,8 @@ def __init__(
         else:
             self.ep_size = 1
         self.top_k = top_k
-        self.num_experts = num_experts  # Global number of experts
+        self.global_num_experts = num_experts
+        self.local_num_experts = self.global_num_experts // self.ep_size
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
@@ -308,39 +309,38 @@ def __init__(
 
         if self.ep_size > 1:
             # Create a tensor of size num_experts filled with -1
-            self.expert_map = torch.full((self.num_experts, ),
+            self.expert_map = torch.full((self.global_num_experts, ),
                                          -1,
                                          dtype=torch.int32)
             # Create a expert map for the local experts
-            local_num_experts = num_experts // self.ep_size
             ep_rank = get_tensor_model_parallel_rank()
             if ep_rank < (self.ep_size - 1):
                 # Each non-last rank gets local_num_experts experts.
-                self.expert_map[ep_rank * local_num_experts:
-                                (ep_rank + 1) * local_num_experts] = \
-                    torch.arange(0, local_num_experts, dtype=torch.int32)
+                self.expert_map[ep_rank * self.local_num_experts:
+                                (ep_rank + 1) * self.local_num_experts] = \
+                    torch.arange(0, self.local_num_experts, dtype=torch.int32)
             else:
                 # All remaining experts are assigned to the last rank.
-                local_num_experts = num_experts - ep_rank * local_num_experts
-                self.expert_map[-local_num_experts:] = \
-                    torch.arange(0, local_num_experts, dtype=torch.int32)
+                self.local_num_experts = (self.global_num_experts -
+                                          ep_rank * self.local_num_experts)
+                self.expert_map[-self.local_num_experts:] = \
+                    torch.arange(0, self.local_num_experts, dtype=torch.int32)
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
 
+        # Note: get_quant_method will look at the layer's local_num_experts
+        # for heuristic purposes, so it must be initialized first.
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
                 UnquantizedFusedMoEMethod())
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None
 
-        local_num_experts = torch.sum(self.expert_map != -1) \
-            if self.expert_map is not None else num_experts
-
         moe_quant_params = {
-            "num_experts": local_num_experts,
+            "num_experts": self.local_num_experts,
             "hidden_size": hidden_size,
             "intermediate_size_per_partition":
             self.intermediate_size_per_partition,
@@ -647,7 +647,7 @@ def forward(self, hidden_states: torch.Tensor,
             top_k=self.top_k,
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
-            global_num_experts=self.num_experts,
+            global_num_experts=self.global_num_experts,
             expert_map=self.expert_map,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
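
To make the renaming concrete: each EP rank owns a contiguous slice of the global experts, and expert_map translates a global expert id into a local one (or -1 for experts held by other ranks). Below is a minimal standalone sketch of that partitioning with made-up sizes; it mirrors the constructor logic above but is not the FusedMoE class itself.

import torch

def build_expert_map(global_num_experts: int, ep_size: int, ep_rank: int):
    # Simplified restatement of the logic in FusedMoE.__init__ above:
    # global expert id -> local expert id on this rank, or -1 if the
    # expert lives on a different rank.
    local_num_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    if ep_rank < ep_size - 1:
        # Each non-last rank gets local_num_experts experts.
        expert_map[ep_rank * local_num_experts:
                   (ep_rank + 1) * local_num_experts] = \
            torch.arange(0, local_num_experts, dtype=torch.int32)
    else:
        # The last rank takes all remaining experts.
        local_num_experts = global_num_experts - ep_rank * local_num_experts
        expert_map[-local_num_experts:] = \
            torch.arange(0, local_num_experts, dtype=torch.int32)
    return expert_map, local_num_experts

# Illustrative sizes (not from a real model): 10 global experts, ep_size=4.
# Ranks 0-2 each hold 2 experts; the last rank holds the remaining 4.
for rank in range(4):
    expert_map, n_local = build_expert_map(10, 4, rank)
    print(rank, n_local, expert_map.tolist())

The fix keeps self.local_num_experts in sync with this map, so the quantization code downstream can size per-expert tensors by the per-rank count instead of the global one.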

vllm/model_executor/layers/quantization/awq_marlin.py (1 addition, 1 deletion)

@@ -136,7 +136,7 @@ def get_quant_method(self, layer: torch.nn.Module,
                     self.full_config).get_quant_method(layer, prefix)
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
-            if layer.num_experts > 32:
+            if layer.local_num_experts > 32:
                 # For MoEs with many experts the moe_wna16 kernel is faster
                 return MoeWNA16Config.from_config(
                     self.full_config).get_quant_method(layer, prefix)
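
The heuristic itself is unchanged; only its input is. A small sketch of what the condition now evaluates under expert parallelism, using illustrative numbers (a hypothetical 64-expert model on 4 EP ranks), not vLLM's actual config objects; the same change appears in gptq_marlin.py below.

# Hypothetical example values, not taken from a real checkpoint.
global_num_experts = 64
ep_size = 4
local_num_experts = global_num_experts // ep_size  # 16 experts per rank

# Before this commit the check read the global count (64 > 32), which would
# route to the moe_wna16 kernel even though each rank only holds 16 experts.
use_moe_wna16 = local_num_experts > 32
print(use_moe_wna16)  # False -> stay on the Marlin MoE path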

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py (1 addition, 1 deletion)

@@ -190,7 +190,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
+            for expert_id in range(layer.local_num_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
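
This loop (and the identical ones in fp8.py and quark_moe.py below) folds the two per-shard w1/w3 scales into one per-expert scale. With expert parallelism the scale tensor only has local_num_experts rows, so iterating over the global count would run past it. A rough sketch of the rescaling, with plain float tensors standing in for the real per_tensor_dequantize / FP8 requantization kernels and made-up shapes:

import torch

local_num_experts, shard_size, hidden_size = 4, 8, 16   # illustrative shapes
# w1 and w3 are stacked along dim 1; one scale per (expert, shard).
w13_weight = torch.randn(local_num_experts, 2 * shard_size, hidden_size)
w13_weight_scale = torch.rand(local_num_experts, 2) + 0.5

max_w13_scales = w13_weight_scale.max(dim=1).values
for expert_id in range(local_num_experts):       # per-rank count, not global
    start = 0
    for shard_id in range(2):
        rows = slice(start, start + shard_size)
        # Dequantize the shard with its own scale, then express it in terms
        # of the shared (max) per-expert scale.
        dq_weight = (w13_weight[expert_id, rows, :] *
                     w13_weight_scale[expert_id, shard_id])
        w13_weight[expert_id, rows, :] = dq_weight / max_w13_scales[expert_id]
        start += shard_size
# After the loop, max_w13_scales replaces the per-shard scales.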

vllm/model_executor/layers/quantization/fp8.py (3 additions, 3 deletions)

@@ -573,11 +573,11 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
             layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
-                layer.num_experts,
+                layer.local_num_experts,
                 dtype=torch.float32,
                 device=w13_weight.device),
                                                         requires_grad=False)
-            for expert in range(layer.num_experts):
+            for expert in range(layer.local_num_experts):
                 w13_weight[expert, :, :], layer.w13_weight_scale[
                     expert] = ops.scaled_fp8_quant(
                         layer.w13_weight.data[expert, :, :])
@@ -644,7 +644,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
+            for expert_id in range(layer.local_num_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
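
The first hunk is where FP8 + EP previously broke: under expert parallelism each rank's merged w13 weight only covers its local experts, so both the re-initialized scale tensor and the quantization loop must be sized by local_num_experts. A self-contained approximation of that per-expert, per-tensor FP8 quantization, with a simple stand-in for ops.scaled_fp8_quant and made-up shapes (requires a PyTorch build that has torch.float8_e4m3fn):

import torch

def per_tensor_fp8_quant(w: torch.Tensor):
    # Stand-in for ops.scaled_fp8_quant: pick a scale so the largest
    # magnitude maps to the float8_e4m3 maximum, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = w.abs().max().clamp(min=1e-12) / finfo.max
    q = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale

local_num_experts, rows, hidden_size = 4, 16, 32     # illustrative shapes
w13_weight = torch.randn(local_num_experts, rows, hidden_size)

# One scale per *local* expert; sizing this with the global expert count
# would not match the per-rank weight tensor when EP is enabled.
w13_weight_scale = torch.ones(local_num_experts, dtype=torch.float32)
w13_quant = torch.empty_like(w13_weight, dtype=torch.float8_e4m3fn)
for expert in range(local_num_experts):
    w13_quant[expert], w13_weight_scale[expert] = per_tensor_fp8_quant(
        w13_weight[expert, :, :])

The second hunk is the same shared-scale merge sketched after compressed_tensors_moe.py above.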

vllm/model_executor/layers/quantization/gptq_marlin.py (1 addition, 1 deletion)

@@ -153,7 +153,7 @@ def override_quantization_method(cls, hf_quant_cfg,
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, FusedMoE):
-            if layer.num_experts > 32:
+            if layer.local_num_experts > 32:
                 # For MoEs with many experts the moe_wna16 kernel is faster
                 return MoeWNA16Config.from_config(
                     self.full_config).get_quant_method(layer, prefix)

vllm/model_executor/layers/quantization/quark/quark_moe.py (1 addition, 1 deletion)

@@ -174,7 +174,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
+            for expert_id in range(layer.local_num_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
