Skip to content

Commit 3c5f872

Browse files
ikawrakow (Iwan Kawrakow)
and co-authors authored
More Flash Attention improvements (#173)
* FA: slightly faster V*softmax(K*Q) on Zen4 * FA: it is also faster on AVX2 and ARM_NEON * Deleted forgotten commented out code * FA: slightly faster V*softmax(K*Q) also for fp16 K-cache * FA: slightly faster V*softmax(K*Q) on Zen4 We now get 130.9 t/s for a context of 32k tokens. * FA: don't store sum scaling factor in SIMD registers * FA: timing * FA: faster q8_0 cache via run-time-repacking On Zen4 q8_0 KV-cache now slightly outperforms BF16. We get 134 t/s for 32k tokens, which is ~30% better than the main branch, and ~18% better than the last commit. We simply repack the K-cache to q8_0_r4 before the K*Q multiplication and use the q8_0_r4 x q8_0_x4 matrix multiplication template. * FA: Fix AVX2 * FA: fix ARM_NEON * FA: vectorize q8_0 -> q8_0_r4 repacking also on NEON * FA: dedicated mat mul for D = 128 also for ARM_NEON * FA: turn off performance timer --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 0b74397 commit 3c5f872

2 files changed

Lines changed: 841 additions & 203 deletions

File tree

ggml/src/ggml.c

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17471,25 +17471,30 @@ static void ggml_compute_forward_flash_attn_ext_f16(
1747117471

1747217472
#if GGML_USE_IQK_MULMAT
1747317473
if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
17474-
int64_t work_per_slice = D*nek1*neq1;
17475-
int ntg = 1;
17474+
// I keep changing my mind what is the best strategy to split the threads when processing
17475+
// multiple heads. This is my current thinking, the commented out code below was the previous.
17476+
int ntg = nth/simple_gcd(neq2*neq3, nth);
17477+
int64_t neq1g = (neq1 + ntg - 1)/ntg;
17478+
//int64_t work_per_slice = D*nek1*neq1;
17479+
//int ntg = 1;
1747617480
//
1747717481
// When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
1747817482
// But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
1747917483
// the number of threads processing the (iq2, iq3) matrix.
1748017484
//
17481-
if (neq1 >= 8*nth) {
17482-
if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
17483-
else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
17484-
else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
17485-
}
17485+
//if (neq1 >= 8*nth) {
17486+
// if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
17487+
// else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
17488+
// else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
17489+
//}
1748617490
int counter = 0;
1748717491
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
1748817492
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
1748917493
if (counter++ % (nth/ntg) == ith/ntg) {
17490-
int iq1 = (ith%ntg)*neq1/ntg;
17494+
int iq1 = (ith%ntg)*neq1g;
17495+
int this_neq1 = MIN(neq1g, neq1-iq1);
1749117496
if (!iqk_flash_attn_noalibi(k->type, v->type,
17492-
D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
17497+
D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
1749317498
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
1749417499
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
1749517500
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),

0 commit comments

Comments
 (0)