
Conversation

@varun-sundar-rabindranath (Contributor) commented Oct 13, 2025

Purpose

Running

VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1  VLLM_ALL2ALL_BACKEND="deepep_high_throughput"  vllm serve openai/gpt-oss-20b --data-parallel-size 2 --tensor-parallel-size 1 --enable-expert-parallel   --no-enable-prefix-caching 

from main runs into a few issues:

  1. Requirement for amd-quark: despite explicitly asking for VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 (i.e. MXFP4 weights and MXFP8 activations), the code path attempts to quantize the activations to MXFP4. At the moment, the amd-quark package is required for MXFP4 quantization.

Solution: Fix Mxfp4MoEMethod::get_fused_moe_quant_config to explicitly return the desired quant configs for the [SM100_FI_MXFP4_MXFP8_TRTLLM, SM100_FI_MXFP4_MXFP8_CUTLASS, SM100_FI_MXFP4_BF16] Mxfp4Backends. On main, it defaults to ocp_mx_moe_quant_config, which is incorrect.

  2. Incorrect activation quantization scale shape: the activations are quantized to mxfp8 by mxfp8_utils.py::mxfp8_e4m3_quantize, which returns the scales as a 1D tensor. The rest of the modular kernel code expects the scales to be 2D when the activations are block quantized (which is essentially what mxfp8 quantization does). For example, for activations of shape [16384, 3072], mxfp8_e4m3_quantize returns scales of shape [1572864], but the modular kernel code expects [16384, 96].

Solution: Reshape the scales in mxfp8_e4m3_quantize (see the sketch after this list).

  3. When run with DP and SM100_FI_MXFP4_MXFP8_TRTLLM or SM100_FI_MXFP4_BF16, flashinfer autotune blows up with an error. Note that this doesn't happen without DP.

Stop-gap fix: Don't do flashinfer autotune in this configuration.
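
Below is a minimal, self-contained sketch of the scale reshape from issue 2; the helper name is made up for illustration, the actual change lives inside mxfp8_utils.py::mxfp8_e4m3_quantize.

import torch

def reshape_mxfp8_scales(scales_1d: torch.Tensor, num_tokens: int, hidden_size: int) -> torch.Tensor:
    # mxfp8 quantizes along the hidden dimension in blocks of 32 elements,
    # so there is one scale per 32-element block of every token.
    block_size = 32
    assert hidden_size % block_size == 0
    return scales_1d.view(num_tokens, hidden_size // block_size)

# For activations of shape [16384, 3072], the flat [1572864] scale tensor
# becomes [16384, 96], matching the shapes in the description above.
flat_scales = torch.empty(16384 * 3072 // 32)
assert reshape_mxfp8_scales(flat_scales, 16384, 3072).shape == (16384, 96)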

Test Plan

server commands:

VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1 VLLM_ALL2ALL_BACKEND="deepep_high_throughput" vllm serve openai/gpt-oss-20b --data-parallel-size 2 --enable-expert-parallel --no-enable-prefix-caching
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1  VLLM_ALL2ALL_BACKEND="deepep_high_throughput"  vllm serve openai/gpt-oss-20b --data-parallel-size 2 --enable-expert-parallel --no-enable-prefix-caching --port 9010 
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1   vllm serve openai/gpt-oss-20b --data-parallel-size 1 --tensor-parallel-size 2 --no-enable-prefix-caching --port 9010

gpt-oss eval command:

 OPENAI_API_KEY=empty python -m gpt_oss.evals --model openai/gpt-oss-20b --eval gpqa --n-threads 128 --base-url http://localhost:9010/v1 --reasoning-effort low

Test Result

All servers yield a GPQA score of around 0.56, which is expected.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces fixes for running gpt-oss with w4a8 quantization on B200 hardware, specifically addressing issues with data and expert parallelism. The changes correctly route to the new w4a8 quantization configuration, fix an incorrect activation scale shape, and add a stop-gap to prevent flashinfer autotuning from failing in specific configurations. The fixes appear correct and well-targeted. I have one suggestion to make the autotune workaround more precise to avoid potential performance regressions on other hardware.

    Record known issues with vllm + flashinfer autotune here. Return True if
    and only if flashinfer autotune will run through without issues.
    """
    if (vllm_config.parallel_config.data_parallel_size > 1 and (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8)):
Contributor


Severity: high

The condition to disable flashinfer autotune is a bit broad. Since the issue is reported on B200 (SM100), it would be safer to scope this workaround to only SM100 devices. This will prevent autotune from being unnecessarily disabled on other hardware (like SM90) where it might work correctly.

Suggested change
- if (vllm_config.parallel_config.data_parallel_size > 1 and (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8)):
+ if (vllm_config.parallel_config.data_parallel_size > 1 and
+         current_platform.is_device_capability(100) and
+         (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+          or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8)):


@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 519 to 538
def mxfp4_w4a8_moe_quant_config(
    quant_dtype: str,
    weight_dtype: str,
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for mxfp4 activations and mxfp4 weights.
    """
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(quant_dtype),
        _a2=FusedMoEQuantDesc(quant_dtype),
        _w1=FusedMoEQuantDesc(weight_dtype, None, w1_scale, None, None, w1_bias),
        _w2=FusedMoEQuantDesc(weight_dtype, None, w2_scale, None, None, w2_bias),


P1: Provide block shape for mxfp8 activation scales

The new mxfp4_w4a8_moe_quant_config returns a FusedMoEQuantConfig with activations typed as "mxfp8" but leaves block_shape unset. After the accompanying change in mxfp8_e4m3_quantize reshapes the activation scales into a (num_tokens, k_tiles) tensor, pplx_prepare_finalize.moe_kernel_quantize_input still sees block_shape=None and _validate_scale_shape asserts that non‑per‑token quantization must have exactly one scale. In practice the w4a8 path now produces multi‑element scales and will hit this assertion before routing experts. The config should describe the block quantization (e.g. [1, 32]) or otherwise mark per‑token quantization so downstream validation and buffer sizing accept the reshaped scales.


Contributor Author


Canonically, for mxfp4 we don't provide a block_shape: the block size of 32 is part of the mxfp4 standard itself, so it is implicit.

    and only if flashinfer autotune will run through without issues.
    """
    if (vllm_config.parallel_config.data_parallel_size > 1 and (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8)):
        return False
Contributor Author


The error I am seeing in this case:

(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:33 [parallel_state.py:1160] Adjusting world_size=2 rank=0 distributed_init_method=tcp://127.0.0.1:46347 for DP
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:33 [__init__.py:1147] Found nccl from library libnccl.so.2
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:33 [pynccl.py:108] vLLM is using nccl==2.27.3
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:35 [__init__.py:1147] Found nccl from library libnccl.so.2
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:35 [pynccl.py:108] vLLM is using nccl==2.27.3
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:35 [cuda_communicator.py:113] Using DeepEP High-Throughput all2all manager.
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:35 [parallel_state.py:1325] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:35 [topk_topp_sampler.py:58] Using FlashInfer for top-p & top-k sampling.
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:35 [gpu_model_runner.py:2849] Starting to load model openai/gpt-oss-20b...
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:35 [gpu_model_runner.py:2879] Loading model from scratch...
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:36 [cuda.py:382] Using FlashInfer backend with HND KV cache layout on V1 engine by default for Blackwell (SM 10.0) GPUs.
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:36 [mxfp4.py:97] Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, for high concurrency throughput workloads consider setting VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better performance
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:36 [layer.py:1140] [EP Rank 0/2] Expert parallelism is enabled. Expert placement strategy: linear. Local/global number of experts: 16/32. Experts local to global index map: 0->0, 1->1, 2->2, 3->3, 4->4, 5->5, 6->6, 7->7, 8->8, 9->9, 10->10, 11->11, 12->12, 13->13, 14->14, 15->15.
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:37 [weight_utils.py:419] Using model weights format ['*.safetensors']
(EngineCore_DP0 pid=2243834) 
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:40 [default_loader.py:314] Loading weights took 2.87 seconds
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:42 [gpu_model_runner.py:2911] Model loading took 8.7772 GiB and 6.134084 seconds
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:43 [gpu_worker.py:315] Available KV cache memory: 147.17 GiB
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:43 [kv_cache_utils.py:1199] GPU KV cache size: 3,214,864 tokens
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:43 [kv_cache_utils.py:1204] Maximum concurrency for 131,072 tokens per request: 46.12x
(EngineCore_DP0 pid=2243834) INFO 10-13 15:34:43 [utils.py:359] `_KV_CACHE_LAYOUT_OVERRIDE` variable detected. Setting KV cache layout to HND.
(EngineCore_DP0 pid=2243834) 2025-10-13 15:34:43,962 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
(EngineCore_DP0 pid=2243834) 2025-10-13 15:34:44,111 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790] EngineCore failed to start.
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790] Traceback (most recent call last):
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/v1/engine/core.py", line 777, in run_engine_core
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     engine_core = DPEngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/v1/engine/core.py", line 1072, in __init__
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     super().__init__(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/v1/engine/core.py", line 553, in __init__
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     super().__init__(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/v1/engine/core.py", line 110, in __init__
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/v1/engine/core.py", line 237, in _initialize_kv_caches
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     self.model_executor.initialize_from_config(kv_cache_configs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/v1/executor/abstract.py", line 78, in initialize_from_config
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     self.collective_rpc("compile_or_warm_up_model")
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/executor/uniproc_executor.py", line 74, in collective_rpc
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return [run_method(self.driver_worker, method, args, kwargs)]
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/utils/__init__.py", line 2977, in run_method
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return func(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/v1/worker/gpu_worker.py", line 358, in compile_or_warm_up_model
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     kernel_warmup(self)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/warmup/kernel_warmup.py", line 52, in kernel_warmup
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     flashinfer_autotune(worker.model_runner)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/warmup/kernel_warmup.py", line 97, in flashinfer_autotune
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     runner._dummy_run(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return func(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/v1/worker/gpu_model_runner.py", line 3456, in _dummy_run
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     outputs = self.model(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]               ^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self._call_impl(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return forward_call(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/models/gpt_oss.py", line 692, in forward
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/compilation/decorators.py", line 260, in __call__
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self.forward(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/models/gpt_oss.py", line 282, in forward
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     x, residual = layer(x, positions, residual)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self._call_impl(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return forward_call(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/models/gpt_oss.py", line 221, in forward
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     output = self.mlp(hidden_states)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]              ^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self._call_impl(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return forward_call(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/models/gpt_oss.py", line 179, in forward
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     x = self.experts(hidden_states=x, router_logits=g)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self._call_impl(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return forward_call(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/custom_op.py", line 46, in forward
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self._forward_method(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 2051, in forward_cuda
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self.forward_native(hidden_states, router_logits)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 2026, in forward_native
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     fused_output = torch.ops.vllm.moe_forward(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self._op(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 2367, in moe_forward
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self.forward_impl(hidden_states, router_logits)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 2239, in forward_impl
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     final_hidden_states = self.quant_method.apply(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]                           ^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/quantization/mxfp4.py", line 931, in apply
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self._route_and_experts(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/quantization/mxfp4.py", line 891, in _route_and_experts
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self.fused_experts(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return self._call_impl(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return forward_call(*args, **kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 1165, in forward
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     fused_out = self._fused_experts(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]                 ^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 1018, in _fused_experts
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     self.fused_experts.apply(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/fused_moe/trtllm_moe.py", line 162, in apply
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     trtllm_fp4_block_scale_routed_moe(**kwargs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/flashinfer/fused_moe/core.py", line 1845, in trtllm_fp4_block_scale_routed_moe
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/flashinfer/fused_moe/core.py", line 1346, in trtllm_fp4_block_scale_moe_op
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     _, tactic = tuner.choose_one(
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]                 ^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/flashinfer/autotuner.py", line 457, in choose_one
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     profiles = self._generate_optimization_profiles(tuning_config, inputs)
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/flashinfer/autotuner.py", line 643, in _generate_optimization_profiles
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]     assert len(opt_shapes) > 0, "Empty tuning buckets are not allowed"
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790]            ^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2243834) ERROR 10-13 15:34:44 [core.py:790] AssertionError: Empty tuning buckets are not allowed

My plan is to create an issue for this when this PR lands.
cc @mgoin

@varun-sundar-rabindranath
Contributor Author

cc @mgoin @bnellnm @zyongye PTAL! Thanks 🙌

Comment on lines +28 to +39
def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
    """
    Record known issues with vllm + flashinfer autotune here. Return True if
    and only if flashinfer autotune will run through without issues.
    """
    return not (
        vllm_config.parallel_config.data_parallel_size > 1
        and (
            envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
            or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
        )
    )
Member


We should add a skipped failing test case to tests/quantization/test_blackwell_moe.py to keep track of known failures.
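
For illustration, a skipped known-failure entry could look roughly like the following; the test name and body are assumptions, not the change that eventually landed.

import pytest

@pytest.mark.skip(
    reason="Known failure: flashinfer autotune asserts 'Empty tuning buckets are "
    "not allowed' with DP > 1 and VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8/BF16"
)
def test_mxfp4_mxfp8_trtllm_dp2_autotune_known_failure():
    # Would serve gpt-oss-20b with --data-parallel-size 2 and the flashinfer
    # MXFP4/MXFP8 MoE backend, exercising the autotune warmup path.
    ...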

Contributor Author


Done 👍 Updated blackwell tests to execute these cases. PTAL! Thanks!

@mergify

mergify bot commented Oct 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Oct 16, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
@mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) Oct 20, 2025
@github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Oct 20, 2025
@mgoin merged commit 5ff5d94 into vllm-project:main Oct 21, 2025
55 checks passed
Zhuul pushed a commit to Zhuul/vllm that referenced this pull request Oct 21, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
baonudesifeizhai pushed a commit to baonudesifeizhai/vllm that referenced this pull request Oct 21, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 23, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: Alberto Perdomo <[email protected]>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: 0xrushi <[email protected]>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: 0xrushi <[email protected]>
Chenyaaang pushed a commit to Chenyaaang/vllm that referenced this pull request Oct 28, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
        and (
            envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
            or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
        )
Contributor


@varun-sundar-rabindranath could we add this extra requirement?

and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"

In our testing without VLLM_ALL2ALL_BACKEND="deepep_high_throughput", GPT-OSS has no issues.
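
A sketch of the guard with that extra requirement folded in, mirroring the function quoted above; illustrative only, not the merged code, and assuming the module already imports VllmConfig and vllm.envs as envs.

def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
    """
    Record known issues with vllm + flashinfer autotune here. Return True if
    and only if flashinfer autotune will run through without issues.
    """
    return not (
        vllm_config.parallel_config.data_parallel_size > 1
        and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
        and (
            envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
            or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
        )
    )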

ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>

Labels

gpt-oss (Related to GPT-OSS models), ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

Development


4 participants