From 70f43a7afd816edc828970c68a5a2af6b37e7fca Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 17 Mar 2026 11:38:33 +0000 Subject: [PATCH 1/3] WIP: mistral4 --- src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + src/llama-build-context.cpp | 3 ++- src/llama-hparams.cpp | 8 ++++++-- src/llama-load-tensors.cpp | 1 + src/llama-model.cpp | 37 ++++++++++++++++++++++++++++++++++++- src/llama-model.h | 1 + src/llama.cpp | 13 +++++++------ 8 files changed, 55 insertions(+), 10 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 24948c2df2..18a298157b 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -76,6 +76,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SEED_OSS, "seed_oss" }, { LLM_ARCH_STEP35, "step35" }, { LLM_ARCH_GLM_DSA, "glm-dsa" }, + { LLM_ARCH_MISTRAL4, "mistral4" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; diff --git a/src/llama-arch.h b/src/llama-arch.h index b2c03f7b15..a757a71e58 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -75,6 +75,7 @@ enum llm_arch { LLM_ARCH_SEED_OSS, LLM_ARCH_STEP35, LLM_ARCH_GLM_DSA, + LLM_ARCH_MISTRAL4, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index b07b15a4dc..010f5d4e07 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -133,7 +133,7 @@ ggml_cgraph * llm_build_context::build_k_shift() { ? LLAMA_ROPE_TYPE_NEOX : hparams.rope_type; - const float yarn_attn_factor_shift = model.arch == LLM_ARCH_DEEPSEEK2 + const float yarn_attn_factor_shift = model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_MISTRAL4 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; @@ -9891,6 +9891,7 @@ ggml_cgraph * llm_build_context::llama_build_graph( } break; case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_GLM_DSA: + case LLM_ARCH_MISTRAL4: { result = llm.build_deepseek2(); } break; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 20d9694aad..7c093400d4 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -837,11 +837,14 @@ void llm_load_hparams( model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_MISTRAL4: case LLM_ARCH_DEEPSEEK2: { + int expected_head_size_k = model.arch == LLM_ARCH_DEEPSEEK2 ? 576 : 320; + int expected_head_size_v = model.arch == LLM_ARCH_DEEPSEEK2 ? 512 : 256; if (hparams.n_head_kv() == 1) { int n_nead_kv = hparams.n_gqa(); - if (n_nead_kv%4 != 0 || hparams.n_embd_head_k != 576 || hparams.n_embd_head_v != 512 || + if (n_nead_kv%4 != 0 || hparams.n_embd_head_k != expected_head_size_k || hparams.n_embd_head_v != expected_head_size_v || hparams.n_rot != 64) { printf("==========================================================================\n"); printf("Detected incompatible DeepSeek model without a known way to fix it.\n"); @@ -858,7 +861,7 @@ void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v); } - bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26) || (hparams.n_layer == 48 && hparams.n_vocab == 128256); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); if (!is_lite) { @@ -888,6 +891,7 @@ void llm_load_hparams( switch (hparams.n_layer) { case 27: model.type = e_model::MODEL_16B; break; + case 36: model.type = e_model::MODEL_119B_A6B; break; case 47: model.type = e_model::MODEL_30B_A3B; break; // GLM-4.7-Flash case 60: model.type = e_model::MODEL_236B; break; case 61: model.type = e_model::MODEL_671B; break; diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index 20ddf94398..f950143f99 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -3794,6 +3794,7 @@ bool create_tensors_helper::create_tensors() { case LLM_ARCH_ARCTIC: use_mmap_buffer = create_arctix_tensors(tn); break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_MISTRAL4: use_mmap_buffer = create_deepseek2_tensors(tn); break; case LLM_ARCH_GLM_DSA: use_mmap_buffer = create_glm_dsa_tensors(tn); break; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d55a1b1568..328220d718 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -930,7 +930,41 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + }, + }, + { + LLM_ARCH_MISTRAL4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KQ_A_MQA, "blk.%d.attn_kq_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_UP_EXPS, "blk.%d.ffn_gate_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, }, }, { @@ -1775,6 +1809,7 @@ const char * llama_model_type_name(e_model type) { case MODEL_80B_A13B: return "80B.A13B"; case MODEL_100B_A6B: return "100B.A6B"; case MODEL_106B_A12B: return "106B.A12B"; + case MODEL_119B_A6B: return "119B.A6B"; case MODEL_122B_A10B: return "122B.A10B"; case MODEL_230B_A10B: return "230B.A10B"; case MODEL_235B_A22B: return "235B.A22B"; diff --git a/src/llama-model.h b/src/llama-model.h index b776df26e2..2e81dbe8ce 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -113,6 +113,7 @@ enum e_model { MODEL_80B_A13B, MODEL_100B_A6B, MODEL_106B_A12B, + MODEL_119B_A6B, MODEL_122B_A10B, MODEL_230B_A10B, // Minimax M2 MODEL_235B_A22B, diff --git a/src/llama.cpp b/src/llama.cpp index 75347f3385..a2f8408762 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -754,7 +754,7 @@ static bool llama_kv_cache_init( } } - bool is_mla_attn = model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA; + bool is_mla_attn = model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA || model.arch == LLM_ARCH_MISTRAL4; bool split_cache = false; if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && !is_mla_attn && offload) { @@ -1607,7 +1607,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { // general kv LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); - if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA) { + if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA || model.arch == LLM_ARCH_MISTRAL4) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); @@ -1652,7 +1652,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { } static void llm_prepare_mla(llama_model & model, int mla) { - if (model.arch != LLM_ARCH_DEEPSEEK2 && model.arch != LLM_ARCH_GLM_DSA) return; + if (model.arch != LLM_ARCH_DEEPSEEK2 && model.arch != LLM_ARCH_GLM_DSA && model.arch != LLM_ARCH_MISTRAL4) return; const auto& hparams = model.hparams; const int n_layer = model.layers.size(); int n_to_compute = 0; @@ -2294,7 +2294,7 @@ static bool llm_load_tensors( } } - if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA) { + if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA || model.arch == LLM_ARCH_MISTRAL4) { llm_prepare_mla(model, mla_attn); } @@ -4905,7 +4905,7 @@ struct llama_context * llama_init_from_model( params.seed = time(NULL); } - if (model->arch != LLM_ARCH_DEEPSEEK2 && model->arch != LLM_ARCH_GLM_DSA && cparams.mla_attn != 0) { + if (model->arch != LLM_ARCH_DEEPSEEK2 && model->arch != LLM_ARCH_GLM_DSA && model->arch != LLM_ARCH_MISTRAL4 && cparams.mla_attn != 0) { cparams.mla_attn = 0; } if (model->arch == LLM_ARCH_OPENAI_MOE && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) { @@ -4928,7 +4928,7 @@ struct llama_context * llama_init_from_model( LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); - if (model->arch == LLM_ARCH_DEEPSEEK2 || model->arch == LLM_ARCH_GLM_DSA) { + if (model->arch == LLM_ARCH_DEEPSEEK2 || model->arch == LLM_ARCH_GLM_DSA || model->arch == LLM_ARCH_MISTRAL4) { LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); } LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch); @@ -5384,6 +5384,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_SMOLLM3: case LLM_ARCH_MISTRAL3: case LLM_ARCH_GLM_DSA: + case LLM_ARCH_MISTRAL4: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 From 32691900004a36f5964126370a186a98e46a4af2 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 17 Mar 2026 11:55:43 +0000 Subject: [PATCH 2/3] CPU FA --- ggml/src/CMakeLists.txt | 1 + ggml/src/iqk/fa/iqk_fa_320_256.cpp | 124 +++++++++++++++++++++++++++++ ggml/src/iqk/fa/iqk_fa_templates.h | 1 + ggml/src/iqk/iqk_mul_mat.cpp | 4 + 4 files changed, 130 insertions(+) create mode 100644 ggml/src/iqk/fa/iqk_fa_320_256.cpp diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index f12f87fc5c..4cf521b1b1 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -264,6 +264,7 @@ if (GGML_IQK_MUL_MAT) set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp iqk/fa/iqk_fa_576_512.cpp + iqk/fa/iqk_fa_320_256.cpp iqk/fa/iqk_fa_192_128.cpp iqk/fa/iqk_fa_192_192.cpp iqk/fa/iqk_fa_256_256.cpp diff --git a/ggml/src/iqk/fa/iqk_fa_320_256.cpp b/ggml/src/iqk/fa/iqk_fa_320_256.cpp new file mode 100644 index 0000000000..95c92ba64c --- /dev/null +++ b/ggml/src/iqk/fa/iqk_fa_320_256.cpp @@ -0,0 +1,124 @@ +#include "iqk/iqk_config.h" + +#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION + +#include "iqk/fa/iqk_fa_templates.h" + +namespace { + +template +inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, + int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, + const float * q, const char * mask, float scale, float softcap, float * qkv, + const float * sinkf, float * M, float * S) { + auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { + nq1 -= n; + if (nq1 == 0) return true; + q += n*stride_q; + mask += n*stride_m; + qkv += n*stride_qkv; + if (M && S) { M += n; S += n; } + return false; + }; + if (nq1 >= 16) { + int n_step = nq1/16; + FlashAttn<320, 256, 16, step_k> fa(scale, softcap, sinkf); + fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(16*n_step)) return; + } + if (nq1 >= 8) { + int n_step = nq1/8; + FlashAttn<320, 256, 8, step_k> fa(scale, softcap, sinkf); + fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(8*n_step)) return; + } + if (nq1 >= 4) { + int n_step = nq1/4; + FlashAttn<320, 256, 4, step_k> fa(scale, softcap, sinkf); + fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(4*n_step)) return; + } + if (nq1 == 3) { + FlashAttn<320, 256, 3, step_k> fa(scale, softcap, sinkf); + fa.compute(kh, vh, 3, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + } + else if (nq1 == 2) { + FlashAttn<320, 256, 2, step_k> fa(scale, softcap, sinkf); + fa.compute(kh, vh, 2, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + } else { + FlashAttn<320, 256, 1, step_k> fa(scale, softcap, sinkf); + fa.compute(kh, vh, 1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + } +} + +template +inline bool iqk_deepseek_helper(ggml_type type_k, + int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, + const float * q, const char * k, const char * v, const char * mask, + float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) { + if (type_k == GGML_TYPE_Q8_0) { + HelperQ80 kh((const char *)k, stride_k); + HelperQ80 vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); + return true; + } + if (type_k == GGML_TYPE_Q8_0_R8) { + HelperQ80R8<320> kh((const char *)k, stride_k); + HelperQ80 vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); + return true; + } + if (type_k == GGML_TYPE_Q6_0) { + HelperQ60 kh((const char *)k, stride_k); + HelperQ60 vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); + return true; + } +#if GGML_IQK_FA_ALL_QUANTS + if (type_k == GGML_TYPE_Q8_KV) { + HelperQ8KV<320> kh((const char *)k, stride_k); + HelperQ8KV<256> vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); + return true; + } +#endif + if (type_k == GGML_TYPE_F16) { + HelperF16 kh((const char *)k, stride_k); + HelperF16 vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); + return true; + } +#ifdef __AVX512BF16__ + if (type_k == GGML_TYPE_BF16) { + HelperBF16<320, step_k> kh((const char *)k, stride_k); + HelperBF16<256, step_k> vh((const char *)v, stride_v); + if (nq1 % 8 == 0) { + FlashAttnBF16<320, 256, 8, step_k> fa(scale, softcap, sinkf); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + } else { + FlashAttnBF16<320, 256, 1, step_k> fa(scale, softcap, sinkf); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + } + return true; + } +#endif + return false; +} + +} + +IQK_FA_CASE(iqk_fa_320_256) { + + auto type_k = ggml_type(int_type_k); + auto type_v = ggml_type(int_type_v); + + if (!(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0))) { + return false; + } + stride_q /= sizeof(float); // q stride as float + return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, + q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, sinkf, M, S); + +} + +#endif diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index 0bf7557c4e..648e7308f7 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -2238,6 +2238,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, float * qkv, const float * sinkf, float * M, float * S) IQK_FA_CASE(iqk_fa_576_512); +IQK_FA_CASE(iqk_fa_320_256); IQK_FA_CASE(iqk_fa_192_128); IQK_FA_CASE(iqk_fa_192_192); IQK_FA_CASE(iqk_fa_256_256); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 2f482524b7..6892c0c29a 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1348,6 +1348,10 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } + if (Dk == 320 && Dv == 256) { + return iqk_fa_320_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); + } if (Dk == 192 && Dv == 128) { return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, From cd115cbdd0a00fda96dbb99fa1013759f46384e2 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 17 Mar 2026 13:53:30 +0000 Subject: [PATCH 3/3] CUDA FA 320, 256 --- ggml/src/ggml-cuda/fattn-new-mma.cu | 39 ++++++++++++++++++++++++++++- ggml/src/ggml-cuda/fattn.cu | 10 +++++--- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index a13fe5659e..4b933ba01d 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -378,6 +378,38 @@ struct fattn_mma_f16_config<576, 512> { } }; +template <> +struct fattn_mma_f16_config<320, 256> { + static constexpr int nbatch_fa = 32; + static constexpr int nwarps_max = 8; + static constexpr bool Q_in_reg = false; + static constexpr int nstages_target = 1; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 160; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 160; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 128; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 128; + } +}; + // ------------------------------------------------------------------------------------------------------------------ // The compiler is always able to unroll loops if they contain continue expressions. @@ -1999,7 +2031,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ct constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; - constexpr bool mla = DKQ == 576; + constexpr bool mla = DKQ == 576 || DKQ == 320; const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols); @@ -2172,6 +2204,11 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst); return; } + if (Q->ne[0] == 320 && K->ne[0] == 320 && V->ne[0] == 256) { + GGML_ASSERT(gqa_ratio % 16 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 16>(ctx, dst); + return; + } GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512); if (gqa_ratio == 20 && Q->ne[1] <= 4 && K->ne[1] >= 2048) { if (ggml_cuda_info().devices[ctx.device].cc >= CC_ADA_LOVELACE) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 4b982effa1..b534719434 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -119,7 +119,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // so no other implementation works. // - if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) { + if (new_mma_available(cc) && + ((K->ne[0] == 576 && V->ne[0] == 512) || + (K->ne[0] == 320 && V->ne[0] == 256) || + (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) { //printf("Using ggml_cuda_flash_attn_ext_mma_new\n"); ggml_cuda_flash_attn_ext_mma_new(ctx, dst); return; @@ -190,8 +193,9 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst); } - if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) { - if (Q->ne[0] == 576) { + if (new_mma_available(cc) && + (Q->ne[0] == 576 || Q->ne[0] == 320 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) { + if (Q->ne[0] == 576 || Q->ne[0] == 320) { int gqa_ratio = Q->ne[2]/K->ne[2]; return (gqa_ratio % 4) == 0; }