Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -2043,6 +2043,10 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * sinks);

GGML_API void ggml_flash_attn_ext_add_bounds(
struct ggml_tensor * a,
struct ggml_tensor * bounds);

// TODO: needs to be adapted to ggml_flash_attn_ext
GGML_API struct ggml_tensor * ggml_flash_attn_back(
struct ggml_context * ctx,
Expand Down
21 changes: 20 additions & 1 deletion ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -8993,6 +8993,22 @@ void ggml_flash_attn_ext_add_sinks(
a->src[4] = sinks;
}

void ggml_flash_attn_ext_add_bounds(
struct ggml_tensor * a,
struct ggml_tensor * bounds) {
if (!bounds) {
a->src[5] = NULL;
return;
}

GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
GGML_ASSERT(bounds->type == GGML_TYPE_I32);
GGML_ASSERT(bounds->ne[0] == 2);
GGML_ASSERT(bounds->ne[1] >= a->src[0]->ne[1]);

a->src[5] = bounds;
}

// ggml_flash_attn_back

struct ggml_tensor * ggml_flash_attn_back(
Expand Down Expand Up @@ -18661,6 +18677,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const struct ggml_tensor * v = dst->src[2];
const struct ggml_tensor * mask = dst->src[3];
const struct ggml_tensor * sinks = dst->src[4];
const struct ggml_tensor * bounds= dst->src[5];

GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
Expand Down Expand Up @@ -18739,7 +18756,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
dst->ne[2], dst->ne[1], dst->nb[1],
k->type, v->type,
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
q->data, k->data, v->data, mask->data,
sinks ? sinks->data : NULL,
bounds ? bounds->data : NULL,
scale, softcap, (float *)dst->data,
params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;

Expand Down
68 changes: 50 additions & 18 deletions ggml/src/iqk/iqk_flash_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,27 @@ inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float
for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
}
}
inline std::pair<int, int> mask_range(int nek1, const uint16_t * umask) {
int first_k = 0, last_k = nek1;
for (; first_k < last_k; ++first_k) {
if (umask[first_k] == 0) break;
}
for (; last_k > first_k; --last_k) {
if (umask[last_k-1] == 0) break;
}
return { first_k, last_k };
}
inline bool reduce_k_range(int nek1, int& first_k, int& last_k) {
int nk = last_k - first_k;
if (nk >= nek1) return false;
if (nk%32) {
int nk32 = 32*((nk + 31)/32);
int diff = nk32 - nk;
first_k = std::max(0, first_k - diff);
last_k = first_k + nk32;
}
return last_k - first_k < nek1;
}
}

// TODO: get the ggml_type enum here without polution
Expand All @@ -66,7 +87,8 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
const void * sinks, // mask. If not null, assumed to be fp16. nq x nk elements
const void * sinks, // attention sinks
const void * bounds, // attention mask bounds
float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q))
Expand All @@ -80,22 +102,13 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int rk3 = neq3/nek3;
int rv3 = neq3/nev3;

int first_k = 0, last_k = nek1;
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
// This is a quick hack for SWA models.
// Given that the mask is the same for all layers, ideally we should determinbe the
// cache bounds once, and reuse for the whole graph. But even with this simple hack
// we get non-negligible performance gains for SWA models and long context.
auto umask = (const uint16_t *)mask;
for (; first_k < last_k; ++first_k) {
if (umask[first_k] == 0) break;
}
for (; last_k > first_k; --last_k) {
if (umask[last_k-1] == 0) break;
}
//printf("nek1 = %d, first = %d, last = %d\n", nek1, first, last);
if (last_k - first_k <= 3*nek1/4 && (last_k - first_k)%32 == 0) {
//printf("Reducing from %d to %d\n", nek1, last_k - first_k);
bool range_found = false;
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && bounds && nek1 > 32) {
range_found = true;
auto b = (const int32_t *)bounds;
int first_k = b[0];
int last_k = b[1];
if ((last_k - first_k)%32 == 0) { // why is this not better? : if (reduce_k_range(nek1, first_k, last_k)) {
k = (const void *)((const char *)k + first_k*stride_k);
v = (const void *)((const char *)v + first_k*stride_v);
mask = (const void *)((const uint16_t *)mask + first_k);
Expand All @@ -105,7 +118,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float

int int_type_k = int_type_k_in;
auto work_buffer = work_buffer_in;
if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) {
if (neq1 >= 8 || (false && rk2 >= 8 && nek2 > 1)) {
uint64_t row_size = 0;
work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size);
if (int_type_k != int_type_k_in) {
Expand Down Expand Up @@ -299,6 +312,25 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
if (counter++ % (nth/ntg) == ith/ntg) {
int iq1 = (ith%ntg)*neq1g;
int this_neq1 = std::min(neq1g, neq1-iq1);
if (bounds && !range_found) {
auto b = (const int32_t *)bounds + 2*iq1;
int kmin = nek1, kmax = 0;
for (int i = 0; i < this_neq1; ++i) {
kmin = std::min(kmin, b[2*i+0]);
kmax = std::max(kmax, b[2*i+1]);
}
if (reduce_k_range(nek1, kmin, kmax)) {
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, this_neq1, kmax-kmin, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
(const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3 + kmin*stride_k),
(const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3 + kmin*stride_v),
(const void *)((const char *)mask + iq1*stride_m + kmin*sizeof(uint16_t)), sinksf, 1,
scale, softcap,
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
continue;
}
}
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
Expand Down
3 changes: 2 additions & 1 deletion ggml/src/iqk/iqk_mul_mat.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
const void * sinks, // mask. If not null, assumed to be fp16. nq x nk elements
const void * sinks, // attention sinks
const void * bounds, // attention mask bounds
float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q))
Expand Down
Loading