From 38a69358d63ecc589841f91b647758278850e0dd Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sat, 28 Mar 2026 03:15:07 +0530 Subject: [PATCH 1/3] Reduce the number of stream-k blocks to reduce the overhead of the flash_attn_stream_k_fixup kernel --- ggml/src/ggml-cuda/fattn-common.cuh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index c59a4db3999..5d8a7f27bd8 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -980,6 +980,16 @@ void launch_fattn( const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; + //Todo: need to find a thresold based on tuning + constexpr int thr_blocks_stream_k = 16; + + // try reducing the number of stream-k blocks as + // flash_attn_stream_k_fixup has a non-negligible overhead for large number of stream-k blocks + // make sure to reduce only when more than 1 block per SM is used + if(use_stream_k && nblocks_stream_k / ntiles_dst > thr_blocks_stream_k && max_blocks_per_sm > 1) { + nblocks_stream_k = nblocks_stream_k / 2; + } + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst; blocks_num.y = 1; blocks_num.z = 1; From 244f50d5fd9c0a808f2a505cb61b623322b157fb Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sat, 28 Mar 2026 03:49:51 +0530 Subject: [PATCH 2/3] Fix compilation error --- ggml/src/ggml-cuda/fattn-common.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 5d8a7f27bd8..779883e7d17 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -976,7 +976,7 @@ void launch_fattn( const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks; const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves); - const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst); + int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst); const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; From 4a2f0179204366ada5820d509ac994fbbbc23132 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sat, 28 Mar 2026 03:53:54 +0530 Subject: [PATCH 3/3] Remove trailing whitespace --- ggml/src/ggml-cuda/fattn-common.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 779883e7d17..ea71cbad8b2 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -983,7 +983,7 @@ void launch_fattn( //Todo: need to find a thresold based on tuning constexpr int thr_blocks_stream_k = 16; - // try reducing the number of stream-k blocks as + // try reducing the number of stream-k blocks as // flash_attn_stream_k_fixup has a non-negligible overhead for large number of stream-k blocks // make sure to reduce only when more than 1 block per SM is used if(use_stream_k && nblocks_stream_k / ntiles_dst > thr_blocks_stream_k && max_blocks_per_sm > 1) {