Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions ggml/src/iqk/iqk_flash_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,29 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int rk3 = neq3/nek3;
int rv3 = neq3/nev3;

int first_k = 0, last_k = nek1;
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
// This is a quick hack for SWA models.
// Given that the mask is the same for all layers, ideally we should determine the
// cache bounds once, and reuse them for the whole graph. But even with this simple hack
// we get non-negligible performance gains for SWA models and long context.
auto umask = (const uint16_t *)mask;
for (; first_k < last_k; ++first_k) {
if (umask[first_k] == 0) break;
}
for (; last_k > first_k; --last_k) {
if (umask[last_k-1] == 0) break;
}
//printf("nek1 = %d, first = %d, last = %d\n", nek1, first_k, last_k);
if (last_k - first_k <= 3*nek1/4 && (last_k - first_k)%32 == 0) {
//printf("Reducing from %d to %d\n", nek1, last_k - first_k);
k = (const void *)((const char *)k + first_k*stride_k);
v = (const void *)((const char *)v + first_k*stride_v);
mask = (const void *)((const uint16_t *)mask + first_k);
nek1 = last_k - first_k;
}
}

int int_type_k = int_type_k_in;
auto work_buffer = work_buffer_in;
if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) {
Expand Down