Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -21771,15 +21771,14 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
if (gcd_k > 1) {
int nth_k = n_tasks/gcd_k;
int rk2 = q->ne[2]/k->ne[2];
if (rk2%nth_k == 0) {
size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
if (ggml_is_quantized(k->type)) {
enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
size += q->ne[2]*row_size;
}
cur = MAX(cur, size);
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
if (ggml_is_quantized(k->type)) {
enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
size += q->ne[2]*row_size;
}
cur = MAX(cur, size);
}
}
#endif
Expand Down
60 changes: 41 additions & 19 deletions ggml/src/iqk/iqk_flash_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,40 +64,63 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
int gcd_k = simple_gcd(nstep_k, nth);
if (gcd_k >= 1) {
int nth_k = nth/gcd_k;
if (rk2%nth_k == 0) {
int ith_k = ith%gcd_k;
int ith_q = ith/gcd_k;
int ith_k = ith%gcd_k;
int ith_q = ith/gcd_k;
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
if (nq_per_thread > 1) {
int ith_mid = nth_k;
int nq_this_thread = nq_per_thread;
if (nq_per_thread*nth_k > rk2) {
ith_mid = rk2 - nth_k*(nq_per_thread - 1);
if (ith_q >= ith_mid) --nq_this_thread;
}
int j_mid = ith_mid*nq_per_thread;
auto work = (char *)work_buffer;
auto size_thread = (Dv + 16)*nq_per_thread*sizeof(float);
auto result_buffer = work;

auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k;
auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v;
auto qth = (const char *)q + ith_q*(rk2/nth_k)*nbq2;
auto q_offset = ith_q < ith_mid ? ith_q*nq_per_thread*nbq2 : (ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread)*nbq2;
auto qth = (const char *)q + q_offset;
auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here
auto work = (char *)work_buffer;

// Each thread will produce a result of size Dv*(rk2/nth_k)*sizeof(float)
// In addition, we need M, S for the rk2/nth_k rows the thread is processing
// => (Dv + 2)*rk2/nth_k*sizeof(float). We use (Dv + 16) instead to make sure threads are not
// Each thread will produce a result of size Dv*nq_this_thread*sizeof(float)
// In addition, we need M, S for the nq_this_thread rows the thread is processing
// => at most (Dv + 2)*nq_per_thread*sizeof(float), since nq_this_thread <= nq_per_thread. We use (Dv + 16) instead to make sure threads are not
// writing onto the same cache line.
auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float);
auto result_buffer = work;
auto work_this_thread = (float *)(result_buffer + ith*size_thread);
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, rk2/nth_k, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth,
scale, softcap,
work_this_thread, work_this_thread + (Dv+0)*rk2/nth_k, work_this_thread + (Dv+1)*rk2/nth_k)) return false;
work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false;

barrier(barrier_data);

// There are gcd_k contributions for each j (one per k/v slice of nek1/gcd_k rows) that we need to sum up
// Thread i processed k/v slice i%gcd_k for q-row group i/gcd_k (nq_per_thread rows for the first ith_mid groups, nq_per_thread-1 rows for the rest) and stored its results at offset i*size_thread

// TODO: simdify this
// TODO: if nth > rk2, have threads process portions of the rows instead of entire rows as it is now
for (int j = ith; j < rk2; j += nth) {
auto Racc = qkv + j*nb1/sizeof(float);
float M = -INFINITY, S = 0;
int jth_q = j/(rk2/nth_k);
int jj = j%(rk2/nth_k);
for (int j1 = 0; j1 < rk2/nth_k; ++j1) {
auto R = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread);
auto Mj = R + Dv*rk2/nth_k;
auto Sj = Mj + rk2/nth_k;
int jth_first, jj, nq_this_j;
if (j < j_mid) {
jth_first = j/nq_per_thread;
jj = j%nq_per_thread;
nq_this_j = nq_per_thread;
} else {
jth_first = ith_mid + (j - j_mid)/(nq_per_thread-1);
jj = (j - j_mid)%(nq_per_thread-1);
nq_this_j = nq_per_thread - 1;
}
jth_first *= gcd_k;
for (int jth = jth_first; jth < jth_first + gcd_k; ++jth) {
auto R = (const float *)(result_buffer + jth*size_thread);
auto Mj = R + Dv*nq_this_j;
auto Sj = Mj + nq_this_j;
R += jj*Dv;
if (Mj[jj] == -INFINITY) continue;
if (Mj[jj] > M) {
Expand All @@ -120,7 +143,6 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
}
return true;

}
}
}
Expand Down