Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 90 additions & 16 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7909,10 +7909,10 @@ void ggml_compute_forward_argsort(

// ggml_compute_forward_flash_attn_ext

static void ggml_compute_forward_flash_attn_ext_f16(
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const ggml_compute_params * params,
ggml_tensor * dst) {

ggml_tensor * dst,
int ir0, int ir1) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
Expand All @@ -7928,9 +7928,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

const int ith = params->ith;
const int nth = params->nth;

const int64_t DK = nek0;
const int64_t DV = nev0;
const int64_t N = neq1;
Expand Down Expand Up @@ -7964,16 +7961,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(

// parallelize by q rows using ggml_vec_dot_f32

// total rows in q
const int nr = neq1*neq2*neq3;

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
Expand All @@ -8000,6 +7987,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");

int ith = params->ith;

// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
Expand Down Expand Up @@ -8147,6 +8136,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
}

static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];

GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

const int64_t DK = nek0;
const int64_t DV = nev0;
const int64_t N = neq1;

GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N);

// input tensor rows must be contiguous
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));

GGML_ASSERT(neq0 == DK);
GGML_ASSERT(nek0 == DK);
GGML_ASSERT(nev0 == DV);

GGML_ASSERT(neq1 == N);

// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);

// parallelize by q rows using ggml_vec_dot_f32

// total rows in q
const int64_t nr = neq1*neq2*neq3;

// rows per thread
const int ith = params->ith;
const int nth = params->nth;

// disable for NUMA
const bool disable_chunking = ggml_is_numa();

// 4x chunks per thread
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;

if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}

if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
ggml_threadpool_chunk_set(params->threadpool, nth);
}

ggml_barrier(params->threadpool);

// The number of elements in each chunk
const int64_t dr = (nr + nchunk - 1) / nchunk;

// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;

while (current_chunk < nchunk) {
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);

ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);

current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
}

void ggml_compute_forward_flash_attn_ext(
const ggml_compute_params * params,
ggml_tensor * dst) {
Expand Down