
Commit 1ac6359

Varun Sundar Rabindranath authored and mgoin committed
[Bugfix] Fix gpt-oss w4a8 DP/EP on B200 (vllm-project#26729)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent e2bac43 · commit 1ac6359

5 files changed (+82, −2 lines)


tests/quantization/test_blackwell_moe.py

Lines changed: 20 additions & 0 deletions
@@ -170,3 +170,23 @@ def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatc
 def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
     can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT)
+
+
+def test_gptoss_dp2_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
+    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
+    can_initialize(
+        "openai/gpt-oss-20b",
+        extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
+        hf_overrides=HF_OVERRIDE_TEXT,
+    )
+
+
+def test_gptoss_dp2_mxfp4bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
+    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
+    can_initialize(
+        "openai/gpt-oss-20b",
+        extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
+        hf_overrides=HF_OVERRIDE_TEXT,
+    )
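
For reference, the two new DP=2 cases can be selected by name. A hypothetical invocation (it assumes a Blackwell-class GPU and the FlashInfer/DeepEP dependencies this test file already requires):

# Hypothetical: run only the new dp2 tests from this file.
import pytest
pytest.main(["tests/quantization/test_blackwell_moe.py", "-k", "dp2", "-v"])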

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 20 additions & 0 deletions
@@ -517,6 +517,26 @@ def mxfp4_w4a16_moe_quant_config(
     )
 
 
+def mxfp4_mxfp8_moe_quant_config(
+    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
+    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
+    a1_scale: torch.Tensor | None = None,
+    a2_scale: torch.Tensor | None = None,
+    w1_bias: torch.Tensor | None = None,
+    w2_bias: torch.Tensor | None = None,
+    block_shape: list[int] | None = None,
+) -> FusedMoEQuantConfig:
+    """
+    Construct a quant config for mxfp8 activations and mxfp4 weights.
+    """
+    return FusedMoEQuantConfig(
+        _a1=FusedMoEQuantDesc("mxfp8"),
+        _a2=FusedMoEQuantDesc("mxfp8"),
+        _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
+        _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
+    )
+
+
 def ocp_mx_moe_quant_config(
     quant_dtype: str,
     w1_scale: Union[torch.Tensor, "PrecisionConfig"],

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 18 additions & 0 deletions
@@ -18,6 +18,7 @@
 from vllm.model_executor.layers.fused_moe import modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
+    mxfp4_mxfp8_moe_quant_config,
     mxfp4_w4a16_moe_quant_config,
     ocp_mx_moe_quant_config,
 )
@@ -747,6 +748,23 @@ def get_fused_moe_quant_config(
                 w1_scale=w1_scale,
                 w2_scale=w2_scale,
             )
+        elif self.mxfp4_backend in [
+            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
+            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
+        ]:
+            return mxfp4_mxfp8_moe_quant_config(
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+            )
+        elif self.mxfp4_backend in [Mxfp4Backend.SM100_FI_MXFP4_BF16]:
+            return mxfp4_w4a16_moe_quant_config(
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+            )
         else:
             w1_scale = layer.w13_weight_scale
             w2_scale = layer.w2_weight_scale

vllm/model_executor/layers/quantization/utils/mxfp8_utils.py

Lines changed: 4 additions & 1 deletion
@@ -18,4 +18,7 @@ def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
             "`pip install flashinfer`"
         ) from err
 
-    return mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False)
+    x_q, x_scales = mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False)
+    if x_scales.ndim == 1:
+        x_scales = x_scales.view(x.size(0), -1)
+    return x_q, x_scales
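
The reshape above matters because the quantizer can return the per-group scale factors as one flat tensor, while downstream MoE code indexes them per token. A minimal sketch of the same view() call with made-up sizes (the 32-element group size and the shapes are assumptions for illustration, not taken from the commit):

import torch

# Hypothetical shapes: 4 tokens, hidden size 128, one scale per 32-element group.
num_tokens, hidden, group = 4, 128, 32
x = torch.randn(num_tokens, hidden)

flat_scales = torch.rand(num_tokens * (hidden // group))  # stand-in for a 1-D x_scales
scales_2d = flat_scales.view(x.size(0), -1)               # the same reshape as the fix
print(scales_2d.shape)                                     # torch.Size([4, 4])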

vllm/model_executor/warmup/kernel_warmup.py

Lines changed: 20 additions & 1 deletion
@@ -11,6 +11,7 @@
 import torch
 
 import vllm.envs as envs
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
 from vllm.platforms import current_platform
@@ -24,6 +25,20 @@
 logger = init_logger(__name__)
 
 
+def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
+    """
+    Record known issues with vllm + flashinfer autotune here. Return True if
+    and only if flashinfer autotune will run through without issues.
+    """
+    return not (
+        vllm_config.parallel_config.data_parallel_size > 1
+        and (
+            envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+            or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+        )
+    )
+
+
 def kernel_warmup(worker: "Worker"):
     # Deep GEMM warmup
     do_deep_gemm_warmup = (
@@ -37,7 +52,11 @@ def kernel_warmup(worker):
         deep_gemm_warmup(model, max_tokens)
 
     # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
-    if has_flashinfer() and current_platform.has_device_capability(90):
+    if (
+        has_flashinfer()
+        and current_platform.has_device_capability(90)
+        and flashinfer_autotune_supported(worker.vllm_config)
+    ):
         flashinfer_autotune(worker.model_runner)
 
     # FlashInfer attention warmup
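
Read as a truth table, the new guard only skips autotune when data parallelism is combined with one of the FlashInfer MXFP4 MoE paths. A standalone restatement with plain booleans (hypothetical helper, not part of the commit):

# Hypothetical restatement of flashinfer_autotune_supported() with plain values.
def supported(dp_size: int, mxfp4_bf16: bool, mxfp4_mxfp8: bool) -> bool:
    return not (dp_size > 1 and (mxfp4_bf16 or mxfp4_mxfp8))

assert supported(1, True, False)       # single DP rank: autotune still runs
assert supported(2, False, False)      # DP=2 without the MXFP4 MoE paths: runs
assert not supported(2, False, True)   # DP=2 + MXFP4-MXFP8 path: autotune skipped
assert not supported(2, True, False)   # DP=2 + MXFP4-BF16 path: autotune skipped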
