Merged
20 changes: 20 additions & 0 deletions tests/quantization/test_blackwell_moe.py
@@ -170,3 +170,23 @@ def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatc
def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
    can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT)


def test_gptoss_dp2_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
    can_initialize(
        "openai/gpt-oss-20b",
        extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
        hf_overrides=HF_OVERRIDE_TEXT,
    )


def test_gptoss_dp2_mxfp4bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
    can_initialize(
        "openai/gpt-oss-20b",
        extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
        hf_overrides=HF_OVERRIDE_TEXT,
    )
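
To run just the new DP=2 cases locally, a selection like the following works (hypothetical keyword filter; assumes a multi-GPU Blackwell host):

import pytest

pytest.main(["tests/quantization/test_blackwell_moe.py", "-k", "dp2", "-v"])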
20 changes: 20 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -517,6 +517,26 @@ def mxfp4_w4a16_moe_quant_config(
)


def mxfp4_mxfp8_moe_quant_config(
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for mxfp8 activations and mxfp4 weights.
    """
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc("mxfp8"),
        _a2=FusedMoEQuantDesc("mxfp8"),
        _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
        _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
    )


def ocp_mx_moe_quant_config(
    quant_dtype: str,
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
18 changes: 18 additions & 0 deletions vllm/model_executor/layers/quantization/mxfp4.py
@@ -18,6 +18,7 @@
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    mxfp4_mxfp8_moe_quant_config,
    mxfp4_w4a16_moe_quant_config,
    ocp_mx_moe_quant_config,
)
@@ -747,6 +748,23 @@ def get_fused_moe_quant_config(
                w1_scale=w1_scale,
                w2_scale=w2_scale,
            )
        elif self.mxfp4_backend in [
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
        ]:
            return mxfp4_mxfp8_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )
        elif self.mxfp4_backend in [Mxfp4Backend.SM100_FI_MXFP4_BF16]:
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )
        else:
            w1_scale = layer.w13_weight_scale
            w2_scale = layer.w2_weight_scale
@@ -18,4 +18,7 @@ def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"`pip install flashinfer`"
) from err

    return mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False)
    x_q, x_scales = mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False)
    if x_scales.ndim == 1:
        x_scales = x_scales.view(x.size(0), -1)
    return x_q, x_scales
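
For intuition, a minimal shape sketch of the un-flattening above (assumes one scale per 32-element block along the last dimension, the MX block size; the tensors here are synthetic, not FlashInfer outputs):

import torch

M, K = 4, 128
x = torch.randn(M, K)
x_scales = torch.empty(M * (K // 32))  # hypothetical flattened per-block scales
if x_scales.ndim == 1:
    x_scales = x_scales.view(x.size(0), -1)
assert x_scales.shape == (M, K // 32)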
21 changes: 20 additions & 1 deletion vllm/model_executor/warmup/kernel_warmup.py
@@ -11,6 +11,7 @@
import torch

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.platforms import current_platform
@@ -24,6 +25,20 @@
logger = init_logger(__name__)


def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
"""
Record known issues with vllm + flashinfer autotune here. Return True if
and only if flashinfer autotune will run through without issues.
"""
    return not (
        vllm_config.parallel_config.data_parallel_size > 1
        and (
            envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
            or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
        )
    )

Contributor:
@varun-sundar-rabindranath could we add this extra requirement?

and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"

In our testing without VLLM_ALL2ALL_BACKEND="deepep_high_throughput", GPT-OSS has no issues

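A minimal sketch of the check with that extra requirement folded in (illustrative only; not what this diff merges):

def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
    # Sketch: only treat DP > 1 with the FlashInfer MXFP4 MoE paths as
    # unsupported when the DeepEP high-throughput all2all backend is selected.
    return not (
        vllm_config.parallel_config.data_parallel_size > 1
        and (
            envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
            or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
        )
        and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
    )
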
Comment on lines +28 to +39

Member:
We should add a skipped failing test case to tests/quantization/test_blackwell_moe.py to keep track of known failures.
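
For reference, a minimal sketch of such a skip-marked known-failure case (hypothetical test name; reuses the can_initialize helper and HF_OVERRIDE_TEXT already used in this file):

@pytest.mark.skip(reason="Known issue: FlashInfer autotune with DP > 1 MXFP4 MoE")
def test_gptoss_dp2_mxfp4mxfp8_known_failure(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
    can_initialize(
        "openai/gpt-oss-20b",
        extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
        hf_overrides=HF_OVERRIDE_TEXT,
    )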


Contributor Author:
Done 👍 Updated blackwell tests to execute these cases. PTAL! Thanks!



def kernel_warmup(worker: "Worker"):
    # Deep GEMM warmup
    do_deep_gemm_warmup = (
@@ -37,7 +52,11 @@ def kernel_warmup(worker: "Worker"):
        deep_gemm_warmup(model, max_tokens)

    # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
    if has_flashinfer() and current_platform.has_device_capability(90):
    if (
        has_flashinfer()
        and current_platform.has_device_capability(90)
        and flashinfer_autotune_supported(worker.vllm_config)
    ):
        flashinfer_autotune(worker.model_runner)

    # FlashInfer attention warmup