
Commit 6f43a35

mgoin, gemini-code-assist[bot], and Varun Sundar Rabindranath authored and committed
[UX] Speedup DeepGEMM warmup with heuristics (vllm-project#25619)
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent 092a9c0 commit 6f43a35

File tree (3 files changed: +95 -14 lines)

vllm/envs.py
vllm/model_executor/warmup/deep_gemm_warmup.py
vllm/model_executor/warmup/kernel_warmup.py


vllm/envs.py

Lines changed: 20 additions & 4 deletions
@@ -146,7 +146,11 @@
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = True
     VLLM_USE_DEEP_GEMM_E8M0: bool = True
-    VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
+    VLLM_DEEP_GEMM_WARMUP: Literal[
+        "skip",
+        "full",
+        "relax",
+    ] = "relax"
     VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
     VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
@@ -1088,9 +1092,21 @@ def get_vllm_port() -> int | None:
     # JIT all the required kernels before model execution so there is no
     # JIT'ing in the hot-path. However, this warmup increases the engine
     # startup time by a couple of minutes.
-    # Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup.
-    "VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool(
-        int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))
+    # Available options:
+    # - "skip" : Skip warmup.
+    # - "full" : Warmup deepgemm by running all possible gemm shapes the
+    #   engine could encounter.
+    # - "relax" : Select gemm shapes to run based on some heuristics. The
+    #   heuristic aims to have the same effect as running all possible gemm
+    #   shapes, but provides no guarantees.
+    "VLLM_DEEP_GEMM_WARMUP": env_with_choices(
+        "VLLM_DEEP_GEMM_WARMUP",
+        "relax",
+        [
+            "skip",
+            "full",
+            "relax",
+        ],
     ),
     # Whether to use fused grouped_topk used for MoE expert selection.
     "VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(

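Since this replaces the old boolean VLLM_SKIP_DEEP_GEMM_WARMUP with a three-way choice, a minimal usage sketch may help; the model name below is illustrative and not part of this commit:

import os

# Choose the warmup mode before the engine starts: "relax" (default) uses the
# heuristic, "full" warms every possible shape, "skip" disables the DeepGEMM warmup.
os.environ["VLLM_DEEP_GEMM_WARMUP"] = "skip"

from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V3")  # illustrative model only
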
vllm/model_executor/warmup/deep_gemm_warmup.py

Lines changed: 74 additions & 9 deletions
@@ -26,6 +26,55 @@
 from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
 
 
+def _generate_optimal_warmup_m_values(
+    max_tokens: int, n: int, device: torch.device
+) -> list[int]:
+    """
+    Generate M values that cover all possible DeepGEMM kernel configurations.
+    Reference: https://github.com/deepseek-ai/DeepGEMM/blob/79f48ee15a82dd5fad5cd9beaa393c1f755e6b55/csrc/jit_kernels/heuristics/common.hpp
+
+    Args:
+        max_tokens: Maximum number of tokens to warmup for
+        n: The actual N dimension from the weight tensor
+        device: The torch device to get properties from.
+    """
+
+    def ceil_div(a: int, b: int) -> int:
+        return (a + b - 1) // b
+
+    # DeepGEMM's possible block sizes
+    block_ms = [64, 128, 256]
+    block_ns = list(range(16, min(257, n + 1), 16))
+    num_sms = torch.cuda.get_device_properties(device).multi_processor_count
+
+    m_values = set()
+
+    # Always include small cases
+    m_values.update([1, 2, 4] + [i for i in range(8, 65, 8)])
+
+    # Collect M values where different wave patterns occur
+    for block_m in block_ms:
+        for block_n in block_ns:
+            if block_n > n:
+                continue
+
+            # Add key M boundaries for this block combination
+            for wave in range(1, 11):  # Up to 10 waves
+                # M where this block config transitions to next wave
+                target_blocks = wave * num_sms
+                m = target_blocks * block_m // ceil_div(n, block_n)
+                if 1 <= m <= max_tokens:
+                    m_values.add(m)
+
+        # Add block_m boundaries
+        for multiple in range(1, max_tokens // block_m + 1):
+            m = multiple * block_m
+            if m <= max_tokens:
+                m_values.add(m)
+
+    return sorted(m_values)
+
+
 def _extract_data_from_linear_base_module(
     m: torch.nn.Module,
 ) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
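
To make the wave heuristic above concrete, here is a small illustrative calculation outside the patch; the SM count (132, e.g. an H100) and N = 4096 are assumptions, not values taken from the diff:

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

num_sms, n = 132, 4096        # assumed GPU SM count and weight N dimension
block_m, block_n = 128, 128   # one of DeepGEMM's candidate tile shapes

# An (M, N) GEMM with this tile shape launches ceil(M/block_m) * ceil(N/block_n)
# blocks; one "wave" is num_sms blocks. The M near which the grid first fills
# one wave for this tile shape is therefore roughly:
m_boundary = 1 * num_sms * block_m // ceil_div(n, block_n)
print(m_boundary)  # 528 -> warming M=528 exercises the one-wave configuration

Warming only such boundary values (plus the small-M cases and block_m multiples) aims to trigger every JIT specialization DeepGEMM would select, without running all max_tokens shapes.
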
@@ -136,14 +185,27 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
     )
     out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
 
-    pbar = tqdm(total=max_tokens, desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
-    num_tokens = max_tokens
-    while num_tokens > 0:
+    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
+    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
+    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
+        m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
+        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
+    else:
+        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
+            "Expected "
+            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
+            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
+        )
+        m_values = list(range(1, max_tokens + 1))
+        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
+
+    pbar = tqdm(total=len(m_values), desc=desc)
+
+    for num_tokens in m_values:
         fp8_gemm_nt(
             (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
         )
         pbar.update(1)
-        num_tokens -= 1
 
     FP8_GEMM_NT_WARMUP_CACHE.add(w.size())

@@ -195,20 +257,23 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
         )
         out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
 
+        # Generate M values in block_m increments (already optimized for MoE)
+        m_values = list(range(block_m, MAX_M + 1, block_m))
+
         pbar = tqdm(
-            total=MAX_BLOCKS,
-            desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})",
+            total=len(m_values),
+            desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
+            f"[{len(m_values)} values, block_m={block_m}]",
         )
-        num_tokens = MAX_M
-        while num_tokens > 0:
+
+        for num_tokens in m_values:
             m_grouped_fp8_gemm_nt_contiguous(
                 (a1q[:num_tokens], a1q_scales[:num_tokens]),
                 (w, w_scale),
                 out[:num_tokens],
                 expert_ids[:num_tokens],
             )
             pbar.update(1)
-            num_tokens = num_tokens - block_m
 
     for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
         if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
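
For the grouped MoE case, the hunk above only rewrites the countdown loop as an explicit ascending list, so the same block_m-strided shapes are warmed and the progress-bar total is exact. A quick illustration with assumed sizes (block_m = 128 and MAX_M = 1024 are examples, not values from the patch):

block_m, MAX_M = 128, 1024  # assumed values for illustration
m_values = list(range(block_m, MAX_M + 1, block_m))
print(m_values)  # [128, 256, 384, 512, 640, 768, 896, 1024] - same shapes the old countdown hit
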

vllm/model_executor/warmup/kernel_warmup.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def kernel_warmup(worker: "Worker"):
     do_deep_gemm_warmup = (
         envs.VLLM_USE_DEEP_GEMM
         and is_deep_gemm_supported()
-        and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP
+        and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
     )
     if do_deep_gemm_warmup:
         model = worker.get_model()
