
Commit 79bba4f

kaixih authored and mgoin committed

[NVIDIA] Add support for cudnn fp4 gemm via flashinfer (vllm-project#26107)

Signed-off-by: kaixih <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: 0xrushi <[email protected]>

1 parent dc6a8bc commit 79bba4f

3 files changed: 57 additions & 38 deletions


vllm/envs.py

Lines changed: 11 additions & 6 deletions
@@ -191,6 +191,7 @@
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
     VLLM_USE_TRTLLM_ATTENTION: str | None = None
+    VLLM_NVFP4_GEMM_BACKEND: str | None = None
     VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
@@ -1292,11 +1293,15 @@ def get_vllm_port() -> int | None:
     # If set, it means we pre-downloaded cubin files and flashinfer will
     # read the cubin files directly.
     "VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
-    # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
-    # Otherwise, uses the first available of: flashinfer cutlass GEMM,
-    # vllm cutlass GEMM, marlin GEMM.
-    "VLLM_USE_TRTLLM_FP4_GEMM": lambda: bool(
-        int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))
+    # Supported options:
+    # - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
+    # - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
+    # - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
+    # - <none>: automatically pick an available backend
+    "VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
+        "VLLM_NVFP4_GEMM_BACKEND",
+        None,
+        ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
     ),
     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
@@ -1492,7 +1497,6 @@ def compute_hash() -> str:
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
         "VLLM_USE_DEEP_GEMM_E8M0",
-        "VLLM_USE_TRTLLM_FP4_GEMM",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
         "VLLM_USE_FLASHINFER_MOE_FP16",
         "VLLM_USE_FLASHINFER_MOE_FP8",
@@ -1524,6 +1528,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
         "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
         "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
+        "VLLM_NVFP4_GEMM_BACKEND",
         "VLLM_USE_FBGEMM",
     ]
     for key in environment_variables_to_hash:
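
With this change the FP4 GEMM backend is selected through VLLM_NVFP4_GEMM_BACKEND (e.g. VLLM_NVFP4_GEMM_BACKEND=flashinfer-cudnn before launching vLLM) rather than the removed VLLM_USE_TRTLLM_FP4_GEMM flag. A quick way to see how the new entry resolves, assuming vllm is importable and relying on envs.py's lazy attribute lookup:

    import os

    # Request the new cuDNN path added by this commit; leaving the
    # variable unset means auto-pick.
    os.environ["VLLM_NVFP4_GEMM_BACKEND"] = "flashinfer-cudnn"

    import vllm.envs as envs

    # env_with_choices validates the value against the three flashinfer-*
    # options, so an unsupported string is rejected here rather than deep
    # inside a kernel call.
    print(envs.VLLM_NVFP4_GEMM_BACKEND)  # -> "flashinfer-cudnn"

The variable is also added to environment_variables_to_hash above, so it participates in vLLM's environment hash used for cache keys.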

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py

Lines changed: 25 additions & 15 deletions
@@ -14,7 +14,10 @@
 from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
     run_nvfp4_emulations,
 )
-from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported,
+    swizzle_blockscale,
+)
 from vllm.model_executor.parameter import (
     GroupQuantScaleParameter,
     ModelWeightParameter,
@@ -29,10 +32,12 @@
 
 class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
     def __init__(self):
-        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
-            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
-            self.backend = "flashinfer-trtllm"
-            logger.info_once("Using flashinfer-trtllm for FP4")
+        self.backend = "none"
+        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
+            if has_flashinfer():
+                self.backend = "flashinfer-cutlass"
+            elif cutlass_fp4_supported():
+                self.backend = "cutlass"
         elif envs.VLLM_USE_FBGEMM:
             self.backend = "fbgemm"
             try:
@@ -42,12 +47,17 @@ def __init__(self):
                     "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
                     "Please install with: pip install fbgemm-gpu-genai"
                 ) from exc
-            logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
-        elif has_flashinfer():
-            self.backend = "flashinfer-cutlass"
-            logger.info_once("Using flashinfer-cutlass for FP4")
-        else:
-            self.backend = "cutlass"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
+            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
+            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+
+        if self.backend == "none":
+            raise ValueError(
+                "No valid NVFP4 GEMM backend found. "
+                "Please check your platform capability."
+            )
+
+        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
         self.group_size = 16
 
     @classmethod
@@ -184,10 +194,9 @@ def apply_weights(
             layer.alpha,
             output_dtype,
         )
-        if self.backend == "flashinfer-trtllm":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
-        elif self.backend == "flashinfer-cutlass":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
+        if self.backend.startswith("flashinfer-"):
+            backend_name = self.backend[len("flashinfer-") :]
+            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
         elif self.backend == "fbgemm":
            out = torch.ops.fbgemm.f4f4bf16(
                 x_fp4,
@@ -198,6 +207,7 @@
                 use_mx=False,
             ).to(output_dtype)
         else:
+            assert self.backend == "cutlass"
             out = cutlass_scaled_fp4_mm(*mm_args)
 
         if bias is not None:
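
Consolidated, the constructor logic above amounts to the following selection sketch (a simplified standalone helper, not code from the diff; the fbgemm branch is omitted, and has_flashinfer / cutlass_fp4_supported are the helpers this file already imports):

    def _pick_nvfp4_backend(requested: str | None) -> str:
        # requested is envs.VLLM_NVFP4_GEMM_BACKEND; None means auto-pick.
        if requested is None:
            if has_flashinfer():
                return "flashinfer-cutlass"
            if cutlass_fp4_supported():
                return "cutlass"
        elif requested.startswith("flashinfer-"):
            # Explicit flashinfer-cudnn / flashinfer-trtllm / flashinfer-cutlass request.
            assert has_flashinfer(), f"FlashInfer is required for {requested}"
            return requested
        raise ValueError(
            "No valid NVFP4 GEMM backend found. Please check your platform capability."
        )

Note that the auto-pick path never selects flashinfer-cudnn or flashinfer-trtllm; those backends are only used when requested explicitly through the environment variable.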

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 21 additions & 17 deletions
@@ -926,22 +926,26 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
     def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
         self.quant_config = quant_config
 
-        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
-            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
-            self.backend = "flashinfer-trtllm"
-        elif has_flashinfer():
-            self.backend = "flashinfer-cutlass"
-        elif cutlass_fp4_supported():
-            self.backend = "cutlass"
-        elif is_fp4_marlin_supported():
-            self.backend = "marlin"
-        else:
+        self.backend = "none"
+        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
+            if has_flashinfer():
+                self.backend = "flashinfer-cutlass"
+            elif cutlass_fp4_supported():
+                self.backend = "cutlass"
+            elif is_fp4_marlin_supported():
+                self.backend = "marlin"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
+            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
+            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+
+        if self.backend == "none":
             raise ValueError(
-                "Current platform does not support NVFP4"
-                " quantization. Please use Blackwell and"
-                " above."
+                "No valid NVFP4 GEMM backend found. "
+                "Please check your platform capability."
             )
 
+        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -1109,11 +1113,11 @@ def apply(
             layer.alpha,
             output_dtype,
         )
-        if self.backend == "flashinfer-trtllm":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
-        elif self.backend == "flashinfer-cutlass":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
+        if self.backend.startswith("flashinfer-"):
+            backend_name = self.backend[len("flashinfer-") :]
+            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
         else:
+            assert self.backend == "cutlass"
             out = cutlass_scaled_fp4_mm(*mm_args)
 
         if bias is not None:
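
Both apply paths now share the same dispatch: the configured backend string is mapped onto flashinfer_scaled_fp4_mm's backend argument by stripping the "flashinfer-" prefix. A minimal sketch of that mapping (the helper name is illustrative; mm_args stands for the argument tuple already built in the diff):

    def _run_nvfp4_mm(backend: str, mm_args: tuple):
        if backend.startswith("flashinfer-"):
            # "flashinfer-cudnn" -> "cudnn" (new in this commit),
            # "flashinfer-trtllm" -> "trtllm", "flashinfer-cutlass" -> "cutlass"
            return flashinfer_scaled_fp4_mm(*mm_args, backend=backend[len("flashinfer-") :])
        # Anything else reaching this point is served by the vLLM cutlass kernel.
        assert backend == "cutlass"
        return cutlass_scaled_fp4_mm(*mm_args)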
