26 | 26 | from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
27 | 27 |
28 | 28 |
| 29 | +def _generate_optimal_warmup_m_values(
| 30 | +    max_tokens: int, n: int, device: torch.device
| 31 | +) -> list[int]:
| 32 | +    """
| 33 | +    Generate M values that cover all possible DeepGEMM kernel configurations.
| 34 | +    Reference: https://github.com/deepseek-ai/DeepGEMM/blob/79f48ee15a82dd5fad5cd9beaa393c1f755e6b55/csrc/jit_kernels/heuristics/common.hpp
| 35 | +
| 36 | +    Args:
| 37 | +        max_tokens: Maximum number of tokens to warm up for.
| 38 | +        n: The actual N dimension from the weight tensor.
| 39 | +        device: The torch device to get properties from.
| 40 | +    """
| 41 | +
| 42 | +    def ceil_div(a: int, b: int) -> int:
| 43 | +        return (a + b - 1) // b
| 44 | +
| 45 | +    # DeepGEMM's possible block sizes
| 46 | +    block_ms = [64, 128, 256]
| 47 | +    block_ns = list(range(16, min(257, n + 1), 16))
| 48 | +    num_sms = torch.cuda.get_device_properties(device).multi_processor_count
| 49 | +
| 50 | +    m_values = set()
| 51 | +
| 52 | +    # Always include small cases
| 53 | +    m_values.update([1, 2, 4] + [i for i in range(8, 65, 8)])
| 54 | +
| 55 | +    # Collect M values where different wave patterns occur
| 56 | +    for block_m in block_ms:
| 57 | +        for block_n in block_ns:
| 58 | +            if block_n > n:
| 59 | +                continue
| 60 | +
| 61 | +            # Add key M boundaries for this block combination
| 62 | +            for wave in range(1, 11):  # Up to 10 waves
| 63 | +                # M where this block config transitions to the next wave
| 64 | +                target_blocks = wave * num_sms
| 65 | +                m = target_blocks * block_m // ceil_div(n, block_n)
| 66 | +                if 1 <= m <= max_tokens:
| 67 | +                    m_values.add(m)
| 68 | +
| 69 | +        # Add block_m boundaries
| 70 | +        for multiple in range(1, max_tokens // block_m + 1):
| 71 | +            m = multiple * block_m
| 72 | +            if m <= max_tokens:
| 73 | +                m_values.add(m)
| 74 | +
| 75 | +    return sorted(m_values)
| 76 | + |
| 77 | + |
29 | 78 | def _extract_data_from_linear_base_module( |
30 | 79 |     m: torch.nn.Module,
31 | 80 | ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: |
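For intuition, the wave-boundary computation in _generate_optimal_warmup_m_values above can be sketched in isolation. The snippet below is illustrative only and not part of the patch; the SM count, N, and block sizes are assumed values.

    # A GEMM grid launches ceil(m / block_m) * ceil(n / block_n) blocks, and
    # each "wave" occupies num_sms of them; the helper records the M values
    # where the grid crosses into the next wave. All numbers are assumptions.
    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    num_sms = 132                          # hypothetical SM count (e.g. an H100)
    n, block_m, block_n = 4096, 128, 128   # hypothetical problem/block sizes
    blocks_per_row = ceil_div(n, block_n)  # 32 column blocks

    for wave in range(1, 4):
        # Largest M whose grid still fits within `wave` full waves of SMs.
        m = wave * num_sms * block_m // blocks_per_row
        print(f"wave {wave}: boundary near M = {m}")  # 528, 1056, 1584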
@@ -136,14 +185,27 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
136 | 185 |     )
137 | 186 |     out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
138 | 187 |
139 | | -    pbar = tqdm(total=max_tokens, desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
140 | | -    num_tokens = max_tokens
141 | | -    while num_tokens > 0:
| 188 | +    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
| 189 | +    # Otherwise, warm up all token sizes to avoid JIT compilation in the hot path.
| 190 | +    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
| 191 | +        m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
| 192 | +        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
| 193 | +    else:
| 194 | +        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
| 195 | +            "Expected "
| 196 | +            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
| 197 | +            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
| 198 | +        )
| 199 | +        m_values = list(range(1, max_tokens + 1))
| 200 | +        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
| 201 | +
| 202 | +    pbar = tqdm(total=len(m_values), desc=desc)
| 203 | +
| 204 | +    for num_tokens in m_values:
142 | 205 |         fp8_gemm_nt(
143 | 206 |             (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
144 | 207 |         )
145 | 208 |         pbar.update(1)
146 | | -        num_tokens -= 1
147 | 209 |
148 | 210 |     FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
149 | 211 |
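Assuming the helper is importable and a CUDA device is present, one way to see the payoff of the relaxed mode is to compare how many kernel launches each mode performs for one weight shape. This usage sketch is not part of the patch, and max_tokens/n are made-up sizes.

    import torch

    max_tokens, n = 8192, 4096
    full_mode = list(range(1, max_tokens + 1))       # "full": every M value
    relax_mode = _generate_optimal_warmup_m_values(  # "relax": boundary Ms only
        max_tokens, n, torch.device("cuda")
    )
    # "full" warms 8192 shapes; "relax" typically warms a few hundred.
    print(len(full_mode), len(relax_mode))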
@@ -195,20 +257,23 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor): |
195 | 257 |         )
196 | 258 |         out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
197 | 259 |
| 260 | +        # Generate M values in block_m increments (already optimized for MoE)
| 261 | +        m_values = list(range(block_m, MAX_M + 1, block_m))
| 262 | +
198 | 263 |         pbar = tqdm(
199 | | -            total=MAX_BLOCKS,
200 | | -            desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})",
| 264 | +            total=len(m_values),
| 265 | +            desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
| 266 | +            f"[{len(m_values)} values, block_m={block_m}]",
201 | 267 |         )
202 | | -        num_tokens = MAX_M
203 | | -        while num_tokens > 0:
| 268 | +
| 269 | +        for num_tokens in m_values:
204 | 270 |             m_grouped_fp8_gemm_nt_contiguous(
205 | 271 |                 (a1q[:num_tokens], a1q_scales[:num_tokens]),
206 | 272 |                 (w, w_scale),
207 | 273 |                 out[:num_tokens],
208 | 274 |                 expert_ids[:num_tokens],
209 | 275 |             )
210 | 276 |             pbar.update(1)
211 | | -            num_tokens = num_tokens - block_m
212 | 277 |
213 | 278 |     for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
214 | 279 |         if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
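The grouped-GEMM loop already stepped in block_m increments, so this rewrite changes iteration order and progress reporting rather than the set of shapes warmed. A small sketch, with assumed MAX_M and block_m, showing that the old countdown and the new ascending range visit the same M values:

    MAX_M, block_m = 4096, 128

    old = []
    num_tokens = MAX_M
    while num_tokens > 0:      # the removed countdown loop
        old.append(num_tokens)
        num_tokens -= block_m

    new = list(range(block_m, MAX_M + 1, block_m))  # the added range
    assert sorted(old) == new  # same warmup shapes, clearer progress bar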