[Bugfix] Fix gpt-oss w4a8 DP/EP on B200 #26729
@@ -11,6 +11,7 @@
import torch

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.platforms import current_platform

@@ -24,6 +25,20 @@
logger = init_logger(__name__)


+def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
+    """
+    Record known issues with vllm + flashinfer autotune here. Return True if
+    and only if flashinfer autotune will run through without issues.
+    """
+    return not (
+        vllm_config.parallel_config.data_parallel_size > 1
+        and (
+            envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+            or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+        )
+    )
+
+
Comment on lines +28 to +39

Member: We should add a skipped failing test case to tests/quantization/test_blackwell_moe.py to keep track of known failures.

Contributor (Author): Done 👍 Updated the Blackwell tests to execute these cases. PTAL! Thanks!
def kernel_warmup(worker: "Worker"):
    # Deep GEMM warmup
    do_deep_gemm_warmup = (
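For illustration, here is a hypothetical sketch of the kind of case discussed in the review thread above: running the previously failing DP > 1 + MXFP4 FlashInfer MoE configurations in tests/quantization/test_blackwell_moe.py. The actual test file's fixtures, helper functions, and model setup are not shown in this PR, so the test name, the monkeypatch usage, and the stubbed body below are assumptions.

```python
# Hypothetical sketch only: exercises the DP > 1 + MXFP4 FlashInfer MoE
# combinations this PR fixes. The real test in
# tests/quantization/test_blackwell_moe.py may be structured differently.
import pytest

MXFP4_MOE_ENV_FLAGS = [
    "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
]


@pytest.mark.parametrize("flag", MXFP4_MOE_ENV_FLAGS)
def test_gpt_oss_mxfp4_moe_dp2(flag: str, monkeypatch: pytest.MonkeyPatch) -> None:
    """Run gpt-oss with data parallelism and the given MXFP4 MoE path enabled."""
    monkeypatch.setenv(flag, "1")
    # A real test would launch the model with data_parallel_size > 1 and check
    # a short generation; the harness is test-suite specific, so the body is
    # stubbed out here.
    ...
```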
@@ -37,7 +52,11 @@ def kernel_warmup(worker: "Worker"):
        deep_gemm_warmup(model, max_tokens)

    # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
-    if has_flashinfer() and current_platform.has_device_capability(90):
+    if (
+        has_flashinfer()
+        and current_platform.has_device_capability(90)
+        and flashinfer_autotune_supported(worker.vllm_config)
+    ):
        flashinfer_autotune(worker.model_runner)

    # FlashInfer attention warmup
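For reference, a minimal self-contained sketch (not part of the diff) of what the new guard evaluates. SimpleNamespace stands in for the real VllmConfig, and the env flags are read directly from os.environ as plain "1" strings, which is a simplification of how vllm.envs parses them.

```python
# Standalone sketch of the guard's logic; SimpleNamespace is a stand-in for
# the real VllmConfig so the check can be run without a full vLLM setup.
import os
from types import SimpleNamespace


def flashinfer_autotune_supported(vllm_config) -> bool:
    # Mirrors the guard added above: skip autotune only when data parallelism
    # is combined with one of the FlashInfer MXFP4 MoE paths.
    uses_mxfp4_moe = (
        os.environ.get("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16") == "1"
        or os.environ.get("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8") == "1"
    )
    return not (
        vllm_config.parallel_config.data_parallel_size > 1 and uses_mxfp4_moe
    )


os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"
dp2 = SimpleNamespace(parallel_config=SimpleNamespace(data_parallel_size=2))
dp1 = SimpleNamespace(parallel_config=SimpleNamespace(data_parallel_size=1))
print(flashinfer_autotune_supported(dp2))  # False -> autotune is skipped
print(flashinfer_autotune_supported(dp1))  # True  -> autotune still runs
```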
Comment: @varun-sundar-rabindranath could we add this extra requirement? In our testing without VLLM_ALL2ALL_BACKEND="deepep_high_throughput", GPT-OSS has no issues.
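As a sketch of the extra requirement asked about above (not necessarily how the discussion was resolved in the PR), the guard could additionally key on the all2all backend so that autotune is only skipped for the known-bad DP + MXFP4 + DeepEP high-throughput combination. The use of envs.VLLM_ALL2ALL_BACKEND to read the environment variable mentioned in the comment is an assumption.

```python
# Sketch only: the guard with the extra condition from the comment above.
# This is not part of this PR's diff as shown.
import vllm.envs as envs
from vllm.config import VllmConfig


def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
    """Return False only for the known-bad DP + MXFP4 + DeepEP-HT combination."""
    uses_mxfp4_moe = (
        envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
    )
    return not (
        vllm_config.parallel_config.data_parallel_size > 1
        and uses_mxfp4_moe
        and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
    )
```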