
Commit 597672d

Load w13/w2_input_scale for all experts
Signed-off-by: Shu Wang <[email protected]>
1 parent 2935092 commit 597672d

File tree: 2 files changed (+23, −10 lines)


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 11 additions & 5 deletions
@@ -1226,6 +1226,7 @@ def __init__(
             "intermediate_size_per_partition": self.intermediate_size_per_partition,
             "params_dtype": params_dtype,
             "weight_loader": self.weight_loader,
+            "global_num_experts": self.global_num_experts,
         }
         # need full intermediate size pre-sharding for WNA16 act order
         if self.quant_method.__class__.__name__ in (
@@ -1546,13 +1547,16 @@ def weight_loader(
             param.data[:, :dim1, :dim2].copy_(loaded_weight)
             return True if return_success else None

-        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
-        if expert_id == -1:
+        quant_method_name = self.quant_method.__class__.__name__
+        global_expert_id = expert_id
+        expert_id = self._map_global_expert_id_to_local_expert_id(global_expert_id)
+        is_modeloptnvfp4 = quant_method_name == "ModelOptNvFp4FusedMoE"
+        is_input_scale = "input_scale" in weight_name
+        if expert_id == -1 and not (is_modeloptnvfp4 and is_input_scale):
             # Failed to load this param since it's not local to this rank
             return False if return_success else None
         # Hereafter, `expert_id` is local physical id

-        quant_method_name = self.quant_method.__class__.__name__
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -1621,7 +1625,7 @@ def weight_loader(
         expert_data = param.data if full_load else param.data[expert_id]

         # Case input scale: input_scale loading is only supported for fp8
-        if "input_scale" in weight_name:
+        if is_input_scale:
             # this is needed for compressed-tensors only
             loaded_weight = loaded_weight.to(param.data.device)

@@ -1637,7 +1641,9 @@ def weight_loader(
                 )

             self._load_single_value(
-                param=param, loaded_weight=loaded_weight, expert_id=expert_id
+                param=param,
+                loaded_weight=loaded_weight,
+                expert_id=global_expert_id if is_modeloptnvfp4 else expert_id,
             )
             return True if return_success else None
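For context on the weight_loader change above: under expert parallelism each rank owns only a slice of the global experts, so a global expert id that is not local maps to -1 and the parameter is normally skipped. The new condition keeps ModelOpt NVFP4 input scales from every global expert so that a single, rank-consistent maximum can be taken later. Below is a minimal sketch of that logic, assuming a contiguous expert partition; the helper names and layout are illustrative, not vLLM's exact implementation.

def map_global_to_local(global_expert_id: int, ep_rank: int, num_local: int) -> int:
    # Assume rank r owns the contiguous range [r * num_local, (r + 1) * num_local).
    start = ep_rank * num_local
    if start <= global_expert_id < start + num_local:
        return global_expert_id - start
    return -1  # expert is not local to this rank


def skip_loading(weight_name: str, quant_method_name: str, local_expert_id: int) -> bool:
    # Mirrors the new condition: skip non-local params, except ModelOpt NVFP4
    # input scales, which are loaded for every global expert.
    is_modeloptnvfp4 = quant_method_name == "ModelOptNvFp4FusedMoE"
    is_input_scale = "input_scale" in weight_name
    return local_expert_id == -1 and not (is_modeloptnvfp4 and is_input_scale)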

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 12 additions & 5 deletions
@@ -1224,6 +1224,7 @@ def create_weights(
         weight_dtype = torch.uint8
         weight_scale_dtype = torch.float8_e4m3fn
         weight_loader = extra_weight_attrs.get("weight_loader")
+        global_num_experts = extra_weight_attrs.get("global_num_experts")
         # GEMM 1
         w13_weight = ModelWeightParameter(
             data=torch.empty(
@@ -1303,15 +1304,16 @@ def create_weights(
         )

         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(global_num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)

         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, dtype=torch.float32),
+            data=torch.empty(global_num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+
         layer.register_parameter("w2_input_scale", w2_input_scale)

     def prepare_static_weights_for_trtllm_fp4_moe(
@@ -1464,7 +1466,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

         # Common processing for input scales and alphas
-        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+        w13_input_scale = (
+            layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
+        )
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,
@@ -1476,14 +1480,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         )

         # GEMM 2 processing
+        w2_input_scale = (
+            layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
+        )
         layer.g2_alphas = Parameter(
-            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )

         # This is for quantization, so we need to invert it.
         layer.w2_input_scale_quant = Parameter(
-            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )

         # TensorRT-LLM specific processing
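The create_weights and process_weights_after_loading changes fit together: the input-scale buffers are now sized by global_num_experts so every expert's scale is present on every rank, and the reduction takes one max over all of them instead of a per-local-expert max, then replicates it across the local experts. A small illustration of the reduction follows; the sizes are made up for the example and are not taken from the patch.

import torch

global_num_experts, local_num_experts = 8, 2
# One scale per (expert, w1/w3 half) for GEMM 1, one per expert for GEMM 2.
w13_input_scale = torch.rand(global_num_experts, 2, dtype=torch.float32)
w2_input_scale = torch.rand(global_num_experts, dtype=torch.float32)

# Single max over all global experts, replicated for each local expert, so every
# expert-parallel rank ends up with the same quantization scale.
w13 = w13_input_scale.max().to(torch.float32).expand(local_num_experts)
w2 = w2_input_scale.max().to(torch.float32).expand(local_num_experts)

assert w13.shape == (local_num_experts,)
assert torch.equal(w13, w13_input_scale.max().repeat(local_num_experts))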
