26 changes: 22 additions & 4 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -46,6 +46,9 @@
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
is_flashinfer_supporting_global_sf,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
@@ -1226,6 +1229,7 @@ def __init__(
"intermediate_size_per_partition": self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
"global_num_experts": self.global_num_experts,
}
# need full intermediate size pre-sharding for WNA16 act order
if self.quant_method.__class__.__name__ in (
@@ -1546,13 +1550,25 @@ def weight_loader(
param.data[:, :dim1, :dim2].copy_(loaded_weight)
return True if return_success else None

expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
quant_method_name = self.quant_method.__class__.__name__
global_expert_id = expert_id
expert_id = self._map_global_expert_id_to_local_expert_id(global_expert_id)

allow_flashinfer = getattr(self.quant_method, "allow_flashinfer", False)
moe_backend = getattr(self.quant_method, "flashinfer_moe_backend", None)

use_global_sf = (
    allow_flashinfer
    and is_flashinfer_supporting_global_sf(moe_backend)
    and "input_scale" in weight_name
    and quant_method_name == "ModelOptNvFp4FusedMoE"
)
Comment on lines +1643 to +1651
Would be good to put these three lines together and leave a comment on what `use_global_sf` means in this case, since we are in fused_moe/layer.py.


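A minimal sketch of one way to follow this suggestion. The helper name `_wants_global_input_scale` is hypothetical and not part of the diff; only the condition itself and the `is_flashinfer_supporting_global_sf` import come from the change above.

```python
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    is_flashinfer_supporting_global_sf,
)


def _wants_global_input_scale(quant_method, weight_name: str) -> bool:
    """Return True when this input scale is shared across all experts.

    This is the FlashInfer CUTLASS NVFP4 case, where a single global input
    scale is loaded instead of one scale per local expert.
    """
    allow_flashinfer = getattr(quant_method, "allow_flashinfer", False)
    moe_backend = getattr(quant_method, "flashinfer_moe_backend", None)
    return (
        allow_flashinfer
        and is_flashinfer_supporting_global_sf(moe_backend)
        and "input_scale" in weight_name
        and type(quant_method).__name__ == "ModelOptNvFp4FusedMoE"
    )
```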
if expert_id == -1 and not use_global_sf:
    # Failed to load this param since it's not local to this rank
    return False if return_success else None
# Hereafter, `expert_id` is the local physical expert id

quant_method_name = self.quant_method.__class__.__name__
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
@@ -1637,7 +1653,9 @@ def weight_loader(
)

self._load_single_value(
param=param, loaded_weight=loaded_weight, expert_id=expert_id
param=param,
loaded_weight=loaded_weight,
expert_id=global_expert_id if use_global_sf else expert_id,
)
return True if return_success else None

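For context, a toy illustration of why `weight_loader` now keeps both ids: regular expert weights are indexed by the local expert id and skipped when the expert is not on this rank, while a globally shared input scale is indexed by the global expert id and must still be loaded. The contiguous expert-parallel layout and the mapping function below are assumptions for illustration only; the real mapping is `_map_global_expert_id_to_local_expert_id`.

```python
# Toy sketch: 8 global experts, 2 EP ranks, 4 local experts per rank.
GLOBAL_NUM_EXPERTS = 8
NUM_LOCAL_EXPERTS = 4
RANK = 1
OFFSET = RANK * NUM_LOCAL_EXPERTS


def map_global_to_local(global_expert_id: int) -> int:
    local_id = global_expert_id - OFFSET
    return local_id if 0 <= local_id < NUM_LOCAL_EXPERTS else -1


for gid in range(GLOBAL_NUM_EXPERTS):
    lid = map_global_to_local(gid)
    # Regular weights: only loaded when the expert is local (lid != -1).
    # Shared input scales (use_global_sf): loaded at index gid on every rank.
    print(f"global={gid} local={lid} load_weight={lid != -1}")
```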
35 changes: 30 additions & 5 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -49,6 +49,7 @@
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
is_flashinfer_supporting_global_sf,
register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl,
@@ -1224,6 +1225,7 @@ def create_weights(
weight_dtype = torch.uint8
weight_scale_dtype = torch.float8_e4m3fn
weight_loader = extra_weight_attrs.get("weight_loader")
global_num_experts = extra_weight_attrs.get("global_num_experts")
# GEMM 1
w13_weight = ModelWeightParameter(
data=torch.empty(
@@ -1302,14 +1304,19 @@
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)

use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
    self.flashinfer_moe_backend
)
global_scale_num_experts = global_num_experts if use_global_sf else num_experts

w13_input_scale = PerTensorScaleParameter(
data=torch.empty(num_experts, 2, dtype=torch.float32),
data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)

w2_input_scale = PerTensorScaleParameter(
data=torch.empty(num_experts, dtype=torch.float32),
data=torch.empty(global_scale_num_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
@@ -1464,7 +1471,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

# Common processing for input scales and alphas
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
    self.flashinfer_moe_backend
)
if use_global_sf:
    # For backends provided by Flashinfer, the input global scales are
    # shared across all experts.
    w13_input_scale = (
        layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
    )
else:
    w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False,
@@ -1476,14 +1493,22 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
)

# GEMM 2 processing
if use_global_sf:
    # For backends provided by Flashinfer, the input global scales are
    # shared across all experts.
    w2_input_scale = (
        layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
    )
else:
    w2_input_scale = layer.w2_input_scale
layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
(w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)

# This is for quantization, so we need to invert it.
layer.w2_input_scale_quant = Parameter(
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
(1 / w2_input_scale).to(torch.float32), requires_grad=False
)

# TensorRT-LLM specific processing
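A standalone sketch (made-up scale values) of the two reductions above: the global-scale path collapses all loaded scales to a single max and broadcasts it over the local experts, while the per-expert path keeps one value per local expert.

```python
import torch

num_local_experts = 2

# Global-scale path: scales were loaded for all (here 4) global experts.
w13_input_scale = torch.tensor([[0.5, 0.7], [0.6, 0.4], [0.9, 0.3], [0.2, 0.8]])
shared = w13_input_scale.max().to(torch.float32).expand(num_local_experts)
print(shared)  # tensor([0.9000, 0.9000]) -- one value shared by every local expert

# Per-expert path: scales exist only for the local experts, reduced row-wise.
local_scale = torch.tensor([[0.5, 0.7], [0.6, 0.4]])
per_expert = local_scale.max(dim=1).values.to(torch.float32)
print(per_expert)  # tensor([0.7000, 0.6000])
```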
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -263,3 +263,9 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
f" expected one of {allowed_backends}"
)


def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
    # TODO(shuw@nvidia): Update when new backends are added.
    backends_supporting_global_sf = (FlashinferMoeBackend.CUTLASS,)
    return backend in backends_supporting_global_sf
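A minimal usage sketch of the new helper, assuming `FlashinferMoeBackend` is importable from the same module, as the signature above suggests:

```python
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    FlashinferMoeBackend,
    is_flashinfer_supporting_global_sf,
)

# Only the CUTLASS backend currently reports support for global scale factors.
assert is_flashinfer_supporting_global_sf(FlashinferMoeBackend.CUTLASS)
assert not is_flashinfer_supporting_global_sf(None)
```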