Skip to content

Commit 906c2f5

Browse files
yzh119 and comaniac authored
feat: append attention kernels for fp8 kv-cache (#420)
This implementation does not rely on fp8 tensor cores; it uses fp16 tensor cores instead, and the fp8 kv-cache is dequantized on-the-fly. sm_89 and sm_90 append attention kernels that use native fp8 tensor cores will be available in later PRs. --------- Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
1 parent b781513 commit 906c2f5

24 files changed

Lines changed: 1670 additions & 822 deletions

CMakeLists.txt

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ set (IDTYPES "i32")
9191
if(FLASHINFER_ENABLE_FP8)
9292
list(APPEND DECODE_DTYPES "e4m3" "e5m2")
9393
list(APPEND DECODE_FP8_DTYPES "e4m3" "e5m2")
94+
list(APPEND PREFILL_FP8_DTYPES "e4m3" "e5m2")
9495
endif(FLASHINFER_ENABLE_FP8)
9596

9697
if(FLASHINFER_ENABLE_BF16)
@@ -194,7 +195,7 @@ foreach(head_dim IN LISTS HEAD_DIMS)
194195
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
195196
foreach(mask_mode IN LISTS MASK_MODES)
196197
foreach(dtype IN LISTS PREFILL_DTYPES)
197-
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu)
198+
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu)
198199
add_custom_command(
199200
OUTPUT ${generated_kernel_src}
200201
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src}
@@ -204,6 +205,18 @@ foreach(head_dim IN LISTS HEAD_DIMS)
204205
)
205206
list(APPEND single_prefill_kernels_src ${generated_kernel_src})
206207
endforeach(dtype)
208+
209+
foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES)
210+
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu)
211+
add_custom_command(
212+
OUTPUT ${generated_kernel_src}
213+
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src}
214+
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py
215+
COMMENT "Generating additional source file ${generated_kernel_src}"
216+
VERBATIM
217+
)
218+
list(APPEND single_prefill_kernels_src ${generated_kernel_src})
219+
endforeach(dtype_kv)
207220
endforeach(mask_mode)
208221
endforeach(allow_fp16_qk_reduction)
209222
endforeach(pos_encoding_mode)
@@ -216,9 +229,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
216229
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
217230
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
218231
foreach(mask_mode IN LISTS MASK_MODES)
219-
foreach(dtype IN LISTS PREFILL_DTYPES)
220-
foreach(idtype IN LISTS IDTYPES)
221-
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
232+
foreach(idtype IN LISTS IDTYPES)
233+
foreach(dtype IN LISTS PREFILL_DTYPES)
234+
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
222235
add_custom_command(
223236
OUTPUT ${generated_kernel_src}
224237
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
@@ -227,8 +240,20 @@ foreach(head_dim IN LISTS HEAD_DIMS)
227240
VERBATIM
228241
)
229242
list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
230-
endforeach(idtype)
231-
endforeach(dtype)
243+
endforeach(dtype)
244+
245+
foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES)
246+
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
247+
add_custom_command(
248+
OUTPUT ${generated_kernel_src}
249+
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
250+
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py
251+
COMMENT "Generating additional source file ${generated_kernel_src}"
252+
VERBATIM
253+
)
254+
list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
255+
endforeach(dtype_kv)
256+
endforeach(idtype)
232257
endforeach(mask_mode)
233258
endforeach(allow_fp16_qk_reduction)
234259
endforeach(pos_encoding_mode)
@@ -241,9 +266,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
241266
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
242267
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
243268
foreach(mask_mode IN LISTS MASK_MODES)
244-
foreach(dtype IN LISTS PREFILL_DTYPES)
245-
foreach(idtype IN LISTS IDTYPES)
246-
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
269+
foreach(idtype IN LISTS IDTYPES)
270+
foreach(dtype IN LISTS PREFILL_DTYPES)
271+
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
247272
add_custom_command(
248273
OUTPUT ${generated_kernel_src}
249274
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
@@ -252,8 +277,20 @@ foreach(head_dim IN LISTS HEAD_DIMS)
252277
VERBATIM
253278
)
254279
list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src})
255-
endforeach(idtype)
256-
endforeach(dtype)
280+
endforeach(dtype)
281+
282+
foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES)
283+
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
284+
add_custom_command(
285+
OUTPUT ${generated_kernel_src}
286+
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
287+
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py
288+
COMMENT "Generating additional source file ${generated_kernel_src}"
289+
VERBATIM
290+
)
291+
list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src})
292+
endforeach(dtype_kv)
293+
endforeach(idtype)
257294
endforeach(mask_mode)
258295
endforeach(allow_fp16_qk_reduction)
259296
endforeach(pos_encoding_mode)

0 commit comments

Comments
 (0)