Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ggml/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ if (GGML_IQK_MUL_MAT)
iqk/iqk_flash_attn.cpp
iqk/fa/iqk_fa_576_512.cpp
iqk/fa/iqk_fa_192_128.cpp
iqk/fa/iqk_fa_192_192.cpp
iqk/fa/iqk_fa_256_256.cpp
iqk/fa/iqk_fa_128_128.cpp
iqk/fa/iqk_fa_96_96.cpp
Expand Down
5 changes: 4 additions & 1 deletion ggml/src/ggml-cuda/fattn-mma-f16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
case 128:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
break;
case 192:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, ncols2>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
break;
Expand Down Expand Up @@ -88,5 +91,5 @@ bool ggml_cuda_fattn_mma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_con
auto K = dst->src[1];
// BUG FIX: V is dst->src[2] (ggml flash-attn convention: src[0]=Q, src[1]=K,
// src[2]=V, src[3]=mask). The original `auto V = dst->src[1];` aliased V to K,
// making the head-size equality guard below a tautology, so kernels with
// Dk != Dv (e.g. 192/128) were incorrectly reported as supported here.
auto V = dst->src[2];
// Reject mismatched K/V head sizes — this path only handles square head dims.
if (K->ne[0] != V->ne[0]) return false;
// Supported head sizes for the mma f16 path; 192 was added alongside the
// new <192,192> kernel config.
return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 192 || K->ne[0] == 256;
}
38 changes: 38 additions & 0 deletions ggml/src/ggml-cuda/fattn-new-mma.cu
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,38 @@ struct fattn_mma_f16_config<192, 128> {
}
};

// Tiling configuration for the mma flash-attention kernel with head sizes
// Dk = 192 and Dv = 192 (square-head variant; the <192, 128> specialization
// above covers the MLA-style asymmetric case).
// NOTE(review): constant meanings are defined by the kernel template, which is
// not visible in this hunk; values mirror the sibling specializations — confirm
// against fattn_mma_f16_config's primary-template documentation.
template <>
struct fattn_mma_f16_config<192, 192> {
static constexpr int nbatch_fa = 64;      // presumably KV positions per FA tile iteration — TODO confirm
static constexpr int nwarps_max = 4;      // upper bound on warps launched per block
static constexpr bool Q_in_reg = true;    // keep the Q tile resident in registers
static constexpr int nstages_target = 1;  // target software-pipelining depth

// Host/device pair must return the same K batch value; both yield 64 here.
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}

static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 64;
}

// Host/device pair for the V batch size; both yield 64.
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}

static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 64;
}

// Host/device pair for the combine-stage batch size; both yield 32.
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}

static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 32;
}
};

template <>
struct fattn_mma_f16_config<576, 512> {
static constexpr int nbatch_fa = 32;
Expand Down Expand Up @@ -2119,6 +2151,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
return;
}
if (K->ne[0] == 192 && V->ne[0] == 192) {
GGML_ASSERT(Q->ne[0] == 192);
GGML_ASSERT(gqa_ratio == 1);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst);
return;
}
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
Expand Down
45 changes: 45 additions & 0 deletions ggml/src/iqk/fa/iqk_fa_192_192.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "iqk/iqk_config.h"

#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION

#include "iqk/fa/iqk_fa_templates.h"

// CPU flash-attention entry point for head sizes Dk = 192, Dv = 192,
// dispatched from iqk_flash_attn_impl (iqk_mul_mat.cpp).
// The parameter list (int_type_k/int_type_v, nq, nk, strides, q/k/v/mask
// pointers, scale, softcap, qkv, sinkf, M, S) is supplied by the IQK_FA_CASE
// macro declared in iqk_fa_templates.h — same pattern as the sibling files
// (iqk_fa_192_128.cpp, iqk_fa_256_256.cpp, ...).
IQK_FA_CASE(iqk_fa_192_192) {

auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);

stride_q /= sizeof(float); // q stride as float
auto ck = (const char *)k;
auto cv = (const char *)v;
auto cm = (const char *)mask;

#ifdef __AVX512BF16__
// Fast path: a bf16 K cache uses the compile-time-typed helper overload
// (template args <Dk, Dv, k_step>, no runtime type parameters). Picks the
// largest k-step (64, else 32) that divides nk.
if (type_k == GGML_TYPE_BF16) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<192, 192, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
iqk_flash_helper_T<192, 192, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
#endif

// Generic path: runtime-typed helper overload (takes type_k/type_v); choose
// the largest k-step among 128/64/32 that divides nk. The final 32-step call
// is the unconditional fallback; its bool result propagates type-support
// failures to the caller.
if (nk%128 == 0) {
return iqk_flash_helper_T<192, 192, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<192, 192, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}

return iqk_flash_helper_T<192, 192, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);

}

#endif
1 change: 1 addition & 0 deletions ggml/src/iqk/fa/iqk_fa_templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -2235,6 +2235,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,

IQK_FA_CASE(iqk_fa_576_512);
IQK_FA_CASE(iqk_fa_192_128);
IQK_FA_CASE(iqk_fa_192_192);
IQK_FA_CASE(iqk_fa_256_256);
IQK_FA_CASE(iqk_fa_128_128);
IQK_FA_CASE(iqk_fa_96_96);
Expand Down
5 changes: 5 additions & 0 deletions ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,11 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}

if (Dk == 192 && Dv == 192) {
return iqk_fa_192_192(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}

if (Dk == 256 && Dv == 256) {
return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
Expand Down
2 changes: 2 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },

{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
Expand Down
2 changes: 2 additions & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ enum llm_kv {
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_OUTPUT_SCALE,
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,

LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
Expand Down
2 changes: 1 addition & 1 deletion src/llama-build-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5931,7 +5931,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
// mutable variable, needed during the last layer of the computation to skip unused tokens
int32_t n_tokens = this->n_tokens;

bool is_lite = (hparams.n_layer == 27);
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);

// We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
Expand Down
4 changes: 3 additions & 1 deletion src/llama-hparams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,8 +762,10 @@ void llm_load_hparams(
for (auto& item : hparams.n_head_kv_arr) item = n_nead_kv;
hparams.n_embd_head_k = 192;
hparams.n_embd_head_v = 128;
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v);
}
bool is_lite = (hparams.n_layer == 27);
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
if (!is_lite) {
Expand Down
2 changes: 1 addition & 1 deletion src/llama-load-tensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1617,7 +1617,7 @@ bool create_tensors_helper::create_arctix_tensors(const LLM_TN & tn) {
bool create_tensors_helper::create_deepseek2_tensors(const LLM_TN & tn) {
LOADING_PRELUDE

const bool is_lite = (hparams.n_layer == 27);
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);

const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
Expand Down