Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

// log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
// by the VKQ accumulators is effectively being shifted up by a factor of 8.
// This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
#define FATTN_KQ_MAX_OFFSET 0.6931f

typedef void (* fattn_kernel_t)(
const char * __restrict__ Q,
const char * __restrict__ K,
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]);
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
}
Expand Down Expand Up @@ -585,7 +585,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
// Turing + Volta:
KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]);
KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/fattn-tile.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;

KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET);
}
}

Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/fattn-vec.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ static __global__ void flash_attn_ext_vec(
sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
}

KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);

if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
KQ_reg[j] = sum;
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/fattn-wmma-f16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ static __global__ void flash_attn_ext_f16(

KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
__half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET);
}
KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);

Expand Down
Loading