
Commit e2bac43

wenscarl authored and mgoin committed
[ModelOpt] Load w13/w2_input_scale for all experts, nvfp4 (vllm-project#26135)
Signed-off-by: Shu Wang <[email protected]>
Signed-off-by: Shu Wang. <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent 79090ca commit e2bac43

3 files changed: +58 −9 lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 22 additions & 4 deletions
@@ -49,6 +49,9 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    is_flashinfer_supporting_global_sf,
+)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
@@ -1289,6 +1292,7 @@ def __init__(
             "intermediate_size_per_partition": self.intermediate_size_per_partition,
             "params_dtype": params_dtype,
             "weight_loader": self.weight_loader,
+            "global_num_experts": self.global_num_experts,
         }
         # need full intermediate size pre-sharding for WNA16 act order
         if self.quant_method.__class__.__name__ in (
@@ -1632,13 +1636,25 @@ def weight_loader(
             param.data[:, :dim1, :dim2].copy_(loaded_weight)
             return True if return_success else None

-        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
-        if expert_id == -1:
+        quant_method_name = self.quant_method.__class__.__name__
+        global_expert_id = expert_id
+        expert_id = self._map_global_expert_id_to_local_expert_id(global_expert_id)
+
+        allow_flashinfer = getattr(self.quant_method, "allow_flashinfer", False)
+        moe_backend = getattr(self.quant_method, "flashinfer_moe_backend", None)
+
+        use_global_sf = (
+            allow_flashinfer
+            and is_flashinfer_supporting_global_sf(moe_backend)
+            and "input_scale" in weight_name
+            and quant_method_name == "ModelOptNvFp4FusedMoE"
+        )
+
+        if expert_id == -1 and not use_global_sf:
             # Failed to load this param since it's not local to this rank
             return False if return_success else None
         # Hereafter, `expert_id` is local physical id

-        quant_method_name = self.quant_method.__class__.__name__
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -1723,7 +1739,9 @@ def weight_loader(
         )

         self._load_single_value(
-            param=param, loaded_weight=loaded_weight, expert_id=expert_id
+            param=param,
+            loaded_weight=loaded_weight,
+            expert_id=global_expert_id if use_global_sf else expert_id,
         )
         return True if return_success else None
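Reading note for the weight_loader hunks above: the loader now computes use_global_sf before the non-local-expert early return, so ModelOpt NVFP4 input_scale entries are loaded for every expert (indexed by the global expert id) whenever the FlashInfer backend shares input scales across experts. The snippet below is a standalone sketch of just that decision; the helper name and plain boolean inputs are invented for illustration and are not vLLM's API.

# Standalone sketch of the gating logic (hypothetical helper, not vLLM code):
# decide whether a weight should still be loaded for a non-local expert and
# whether the global expert id should be used as the index.
def resolve_load_decision(
    local_expert_id: int,          # -1 when the expert is not on this rank
    allow_flashinfer: bool,
    backend_supports_global_sf: bool,
    weight_name: str,
    quant_method_name: str,
) -> tuple[bool, bool]:
    use_global_sf = (
        allow_flashinfer
        and backend_supports_global_sf
        and "input_scale" in weight_name
        and quant_method_name == "ModelOptNvFp4FusedMoE"
    )
    should_load = (local_expert_id != -1) or use_global_sf
    return should_load, use_global_sf

# A non-local expert's w13_input_scale is now still loaded on this rank:
print(resolve_load_decision(-1, True, True, "w13_input_scale", "ModelOptNvFp4FusedMoE"))
# -> (True, True)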

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 30 additions & 5 deletions
@@ -49,6 +49,7 @@
     build_flashinfer_fp8_cutlass_moe_prepare_finalize,
     flashinfer_cutlass_moe_fp8,
     get_flashinfer_moe_backend,
+    is_flashinfer_supporting_global_sf,
     register_moe_scaling_factors,
     rotate_flashinfer_fp8_moe_weights,
     select_cutlass_fp8_gemm_impl,
@@ -1217,6 +1218,7 @@ def create_weights(
         weight_dtype = torch.uint8
         weight_scale_dtype = torch.float8_e4m3fn
         weight_loader = extra_weight_attrs.get("weight_loader")
+        global_num_experts = extra_weight_attrs.get("global_num_experts")
         # GEMM 1
         w13_weight = ModelWeightParameter(
             data=torch.empty(
@@ -1295,14 +1297,19 @@ def create_weights(
             {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
         )

+        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
+            self.flashinfer_moe_backend
+        )
+        global_scale_num_experts = global_num_experts if use_global_sf else num_experts
+
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)

         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, dtype=torch.float32),
+            data=torch.empty(global_scale_num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_input_scale", w2_input_scale)
@@ -1457,7 +1464,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

         # Common processing for input scales and alphas
-        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
+            self.flashinfer_moe_backend
+        )
+        if use_global_sf:
+            # For backends provided by Flashinfer, the input global scales are
+            # shared across all experts.
+            w13_input_scale = (
+                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
+            )
+        else:
+            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,
@@ -1469,14 +1486,22 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         )

         # GEMM 2 processing
+        if use_global_sf:
+            # For backends provided by Flashinfer, the input global scales are
+            # shared across all experts.
+            w2_input_scale = (
+                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
+            )
+        else:
+            w2_input_scale = layer.w2_input_scale
         layer.g2_alphas = Parameter(
-            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )

         # This is for quantization, so we need to invert it.
         layer.w2_input_scale_quant = Parameter(
-            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )

         # TensorRT-LLM specific processing
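To see what the shared-scale branch in process_weights_after_loading computes, here is a small numeric example with made-up values: one global max is taken over all loaded scales and broadcast back to a per-expert vector, whereas the previous path kept a per-expert max over dim=1.

import torch

num_experts = 4
w13_input_scale = torch.tensor(
    [[0.5, 0.7], [0.6, 0.4], [0.9, 0.3], [0.2, 0.8]], dtype=torch.float32
)

# Shared-scale path (use_global_sf=True): one global max for every expert.
shared = w13_input_scale.max().to(torch.float32).expand(num_experts)
print(shared)       # tensor([0.9000, 0.9000, 0.9000, 0.9000])

# Previous per-expert path: max over the two scales of each expert.
per_expert = w13_input_scale.max(dim=1).values.to(torch.float32)
print(per_expert)   # tensor([0.7000, 0.6000, 0.9000, 0.8000])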

vllm/model_executor/layers/quantization/utils/flashinfer_utils.py

Lines changed: 6 additions & 0 deletions
@@ -263,3 +263,9 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
         f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
         f" expected one of {allowed_backends}"
     )
+
+
+def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
+    # TODO(shuw@nvidia): Update when new backends are added.
+    backends_supporting_global_sf = (FlashinferMoeBackend.CUTLASS,)
+    return backend in backends_supporting_global_sf
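The new helper simply whitelists the backends that share global input scales, and at the moment only CUTLASS qualifies. A self-contained mock for quick experimentation (the enum here is a stand-in, not the real FlashinferMoeBackend from flashinfer_utils.py):

from enum import Enum

class FlashinferMoeBackend(Enum):      # mock stand-in for the real enum
    CUTLASS = "cutlass"
    TENSORRT_LLM = "trtllm"

def is_flashinfer_supporting_global_sf(backend) -> bool:
    # Only the CUTLASS backend shares input global scales across experts today.
    return backend in (FlashinferMoeBackend.CUTLASS,)

print(is_flashinfer_supporting_global_sf(FlashinferMoeBackend.CUTLASS))       # True
print(is_flashinfer_supporting_global_sf(FlashinferMoeBackend.TENSORRT_LLM))  # False
print(is_flashinfer_supporting_global_sf(None))                               # False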
