[CUDA] Reduce the number of stream-k blocks to reduce the overhead of the flash_attn_stream_k_fixup kernel #21086
gaugarg-nv wants to merge 3 commits into ggml-org:master from …
Conversation
I don't think this is the correct logic. Asymptotically, for an infinitely deep context, it should always be worthwhile to run as many CUDA blocks as possible since the overhead becomes negligible. Intuitively I would expect something like this to be a better solution: if possible, always run at least 2 blocks / SM in order to keep the GPU busy when calling `flash_attn_stream_k_fixup`.
I agree with what you are saying. But in practice, we are seeing good speed-up even for models that have … I will also explore if there is a better way to reduce the overhead of `flash_attn_stream_k_fixup`.
An RTX 5090 and an RTX Pro 6000 both have 192 SMs. With Qwen 3 30b a3b with 4 KV heads, that will be on average a chunk of ~171 tokens / SM. The kernel will run with an internal batch size (…).

As for the stream-k fixup itself: if the occupancy is low anyway, we could maybe run the kernel with a number of CUDA blocks that is an exact multiple of …
I'm not sure if this will help in all cases. As you can see in the perf data, this change is helping even for 32K context length, where …
Thanks for the idea. I will look into it.
For GPUs with high SM counts, the number of stream-k blocks can be very high to fill the entire GPU. In such cases, `flash_attn_stream_k_fixup` takes significant time.

The fix is to reduce the number of stream-k blocks: for example, if `max_blocks_per_sm` is 2 or 4, reduce it by a factor of 2. This can reduce occupancy, but I am seeing positive gains with this change.

Future work: explore how to optimize `flash_attn_stream_k_fixup` for a large number of blocks.

Performance
This change is also helpful for Tensor parallelism (PR #19378), specifically for gpt-oss, which uses the stream-k path.
Tensor Parallelism Performance on 2x RTX Pro 6000 Blackwell with PR #19378