@@ -91,6 +91,7 @@ set (IDTYPES "i32")
9191if (FLASHINFER_ENABLE_FP8)
9292 list (APPEND DECODE_DTYPES "e4m3" "e5m2" )
9393 list (APPEND DECODE_FP8_DTYPES "e4m3" "e5m2" )
94+ list (APPEND PREFILL_FP8_DTYPES "e4m3" "e5m2" )
9495endif (FLASHINFER_ENABLE_FP8 )
9596
9697if (FLASHINFER_ENABLE_BF16)
@@ -194,7 +195,7 @@ foreach(head_dim IN LISTS HEAD_DIMS)
194195 foreach (allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
195196 foreach (mask_mode IN LISTS MASK_MODES)
196197 foreach (dtype IN LISTS PREFILL_DTYPES)
197- set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated/single_prefill_head_${head_dim} _logitshook_${logits_post_hook} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_${mask_mode} _dtypein_ ${dtype} _dtypeout_${dtype} .cu)
198+ set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated/single_prefill_head_${head_dim} _logitshook_${logits_post_hook} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_${mask_mode} _dtypeq_ ${dtype} _dtypekv_ ${dtype} _dtypeout_${dtype} .cu)
198199 add_custom_command (
199200 OUTPUT ${generated_kernel_src}
200201 COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR} /python/generate_single_prefill_inst.py ${generated_kernel_src}
@@ -204,6 +205,18 @@ foreach(head_dim IN LISTS HEAD_DIMS)
204205 )
205206 list (APPEND single_prefill_kernels_src ${generated_kernel_src} )
206207 endforeach (dtype )
208+
209+ foreach (dtype_kv IN LISTS PREFILL_FP8_DTYPES)
210+ set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated/single_prefill_head_${head_dim} _logitshook_${logits_post_hook} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_${mask_mode} _dtypeq_f16_dtypekv_${dtype_kv} _dtypeout_f16.cu)
211+ add_custom_command (
212+ OUTPUT ${generated_kernel_src}
213+ COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR} /python/generate_single_prefill_inst.py ${generated_kernel_src}
214+ DEPENDS ${PROJECT_SOURCE_DIR} /python/generate_single_prefill_inst.py
215+ COMMENT "Generating additional source file ${generated_kernel_src} "
216+ VERBATIM
217+ )
218+ list (APPEND single_prefill_kernels_src ${generated_kernel_src} )
219+ endforeach (dtype_kv )
207220 endforeach (mask_mode )
208221 endforeach (allow_fp16_qk_reduction )
209222 endforeach (pos_encoding_mode )
@@ -216,9 +229,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
216229 foreach (pos_encoding_mode IN LISTS POS_ENCODING_MODES)
217230 foreach (allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
218231 foreach (mask_mode IN LISTS MASK_MODES)
219- foreach (dtype IN LISTS PREFILL_DTYPES )
220- foreach (idtype IN LISTS IDTYPES )
221- set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated/batch_paged_prefill_head_${head_dim} _logitshook_${logits_post_hook} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_${mask_mode} _dtypein_ ${dtype} _dtypeout_${dtype} _idtype_${idtype} .cu)
232+ foreach (idtype IN LISTS IDTYPES )
233+ foreach (dtype IN LISTS PREFILL_DTYPES )
234+ set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated/batch_paged_prefill_head_${head_dim} _logitshook_${logits_post_hook} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_${mask_mode} _dtypeq_ ${dtype} _dtypekv_ ${dtype} _dtypeout_${dtype} _idtype_${idtype} .cu)
222235 add_custom_command (
223236 OUTPUT ${generated_kernel_src}
224237 COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR} /python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
@@ -227,8 +240,20 @@ foreach(head_dim IN LISTS HEAD_DIMS)
227240 VERBATIM
228241 )
229242 list (APPEND batch_paged_prefill_kernels_src ${generated_kernel_src} )
230- endforeach (idtype )
231- endforeach (dtype )
243+ endforeach (dtype )
244+
245+ foreach (dtype_kv IN LISTS PREFILL_FP8_DTYPES)
246+ set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated/batch_paged_prefill_head_${head_dim} _logitshook_${logits_post_hook} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_${mask_mode} _dtypeq_f16_dtypekv_${dtype_kv} _dtypeout_f16_idtype_${idtype} .cu)
247+ add_custom_command (
248+ OUTPUT ${generated_kernel_src}
249+ COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR} /python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
250+ DEPENDS ${PROJECT_SOURCE_DIR} /python/generate_batch_paged_prefill_inst.py
251+ COMMENT "Generating additional source file ${generated_kernel_src} "
252+ VERBATIM
253+ )
254+ list (APPEND batch_paged_prefill_kernels_src ${generated_kernel_src} )
255+ endforeach (dtype_kv )
256+ endforeach (idtype )
232257 endforeach (mask_mode )
233258 endforeach (allow_fp16_qk_reduction )
234259 endforeach (pos_encoding_mode )
@@ -241,9 +266,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
241266 foreach (pos_encoding_mode IN LISTS POS_ENCODING_MODES)
242267 foreach (allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
243268 foreach (mask_mode IN LISTS MASK_MODES)
244- foreach (dtype IN LISTS PREFILL_DTYPES )
245- foreach (idtype IN LISTS IDTYPES )
246- set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated/batch_ragged_prefill_head_${head_dim} _logitshook_${logits_post_hook} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_${mask_mode} _dtypein_ ${dtype} _dtypeout_${dtype} _idtype_${idtype} .cu)
269+ foreach (idtype IN LISTS IDTYPES )
270+ foreach (dtype IN LISTS PREFILL_DTYPES )
271+ set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated/batch_ragged_prefill_head_${head_dim} _logitshook_${logits_post_hook} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_${mask_mode} _dtypeq_ ${dtype} _dtypekv_ ${dtype} _dtypeout_${dtype} _idtype_${idtype} .cu)
247272 add_custom_command (
248273 OUTPUT ${generated_kernel_src}
249274 COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR} /python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
@@ -252,8 +277,20 @@ foreach(head_dim IN LISTS HEAD_DIMS)
252277 VERBATIM
253278 )
254279 list (APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src} )
255- endforeach (idtype )
256- endforeach (dtype )
280+ endforeach (dtype )
281+
282+ foreach (dtype_kv IN LISTS PREFILL_FP8_DTYPES)
283+ set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated/batch_ragged_prefill_head_${head_dim} _logitshook_${logits_post_hook} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_${mask_mode} _dtypeq_f16_dtypekv_${dtype_kv} _dtypeout_f16_idtype_${idtype} .cu)
284+ add_custom_command (
285+ OUTPUT ${generated_kernel_src}
286+ COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR} /python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
287+ DEPENDS ${PROJECT_SOURCE_DIR} /python/generate_batch_ragged_prefill_inst.py
288+ COMMENT "Generating additional source file ${generated_kernel_src} "
289+ VERBATIM
290+ )
291+ list (APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src} )
292+ endforeach (dtype_kv )
293+ endforeach (idtype )
257294 endforeach (mask_mode )
258295 endforeach (allow_fp16_qk_reduction )
259296 endforeach (pos_encoding_mode )
0 commit comments