Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ggml/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ if (GGML_IQK_MUL_MAT)
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp
iqk/iqk_flash_attn.cpp
iqk/fa/iqk_fa_576_512.cpp
iqk/fa/iqk_fa_320_256.cpp
iqk/fa/iqk_fa_192_128.cpp
iqk/fa/iqk_fa_192_192.cpp
iqk/fa/iqk_fa_256_256.cpp
Expand Down
39 changes: 38 additions & 1 deletion ggml/src/ggml-cuda/fattn-new-mma.cu
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,38 @@ struct fattn_mma_f16_config<576, 512> {
}
};

// MMA flash-attention configuration for head sizes DKQ = 320 / DV = 256
// (MLA-style attention path; this head-size pair is routed through the same
// mla code path as the 576/512 DeepSeek configuration).
template <>
struct fattn_mma_f16_config<320, 256> {
    static constexpr int  nbatch_fa      = 32;    // K/V rows processed per iteration
    static constexpr int  nwarps_max     = 8;     // upper bound on warps per block
    static constexpr bool Q_in_reg      = false;  // Q tile kept in shared memory, not registers
    static constexpr int  nstages_target = 1;     // no software pipelining

    // K batch in 2-element units: 160 = 320/2, i.e. the whole K head per batch.
    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { return 160; }
    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/)  { return 160; }

    // V batch in 2-element units: 128 = 256/2, i.e. the whole V head per batch.
    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { return 128; }
    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/)  { return 128; }

    // Combine step processes the full V head (128 2-element units) at once.
    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { return 128; }
    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/)  { return 128; }
};

// ------------------------------------------------------------------------------------------------------------------

// The compiler is always able to unroll loops if they contain continue expressions.
Expand Down Expand Up @@ -1999,7 +2031,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ct
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;

constexpr bool mla = DKQ == 576;
constexpr bool mla = DKQ == 576 || DKQ == 320;

const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
Expand Down Expand Up @@ -2172,6 +2204,11 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst);
return;
}
if (Q->ne[0] == 320 && K->ne[0] == 320 && V->ne[0] == 256) {
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 16>(ctx, dst);
return;
}
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
if (gqa_ratio == 20 && Q->ne[1] <= 4 && K->ne[1] >= 2048) {
if (ggml_cuda_info().devices[ctx.device].cc >= CC_ADA_LOVELACE) {
Expand Down
10 changes: 7 additions & 3 deletions ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
// so no other implementation works.
//

if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
if (new_mma_available(cc) &&
((K->ne[0] == 576 && V->ne[0] == 512) ||
(K->ne[0] == 320 && V->ne[0] == 256) ||
(K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
//printf("Using ggml_cuda_flash_attn_ext_mma_new\n");
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
return;
Expand Down Expand Up @@ -190,8 +193,9 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te
return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst);
}

if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
if (Q->ne[0] == 576) {
if (new_mma_available(cc) &&
(Q->ne[0] == 576 || Q->ne[0] == 320 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
if (Q->ne[0] == 576 || Q->ne[0] == 320) {
int gqa_ratio = Q->ne[2]/K->ne[2];
return (gqa_ratio % 4) == 0;
}
Expand Down
124 changes: 124 additions & 0 deletions ggml/src/iqk/fa/iqk_fa_320_256.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#include "iqk/iqk_config.h"

#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION

#include "iqk/fa/iqk_fa_templates.h"

namespace {

// Run flash attention for a Dk=320/Dv=256 head over nq1 query rows, carving the
// rows into the largest available compile-time tile heights (16, then 8, then 4,
// then a 3/2/1 tail) so each chunk uses a FlashAttn instantiation tuned for it.
// Strides are in elements of the respective buffers; M/S (row max / row sum for
// split-K combination) are advanced only when both are provided.
template <int step_k, typename KHelper, typename VHelper>
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
        int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
        const float * q, const char * mask, float scale, float softcap, float * qkv,
        const float * sinkf, float * M, float * S) {
    // Advance every per-row pointer past the n rows just processed.
    // Returns true when no rows remain.
    auto advance = [&](int n) {
        nq1 -= n;
        if (nq1 == 0) return true;
        q    += n*stride_q;
        mask += n*stride_m;
        qkv  += n*stride_qkv;
        if (M && S) { M += n; S += n; }
        return false;
    };
    if (nq1 >= 16) {
        const int n = 16*(nq1/16);
        FlashAttn<320, 256, 16, step_k> fa(scale, softcap, sinkf);
        fa.compute(kh, vh, n, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
        if (advance(n)) return;
    }
    if (nq1 >= 8) {
        const int n = 8*(nq1/8);
        FlashAttn<320, 256, 8, step_k> fa(scale, softcap, sinkf);
        fa.compute(kh, vh, n, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
        if (advance(n)) return;
    }
    if (nq1 >= 4) {
        const int n = 4*(nq1/4);
        FlashAttn<320, 256, 4, step_k> fa(scale, softcap, sinkf);
        fa.compute(kh, vh, n, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
        if (advance(n)) return;
    }
    // Tail: at most 3 rows left (a single row is also the fallback path).
    switch (nq1) {
        case 3: {
            FlashAttn<320, 256, 3, step_k> fa(scale, softcap, sinkf);
            fa.compute(kh, vh, 3, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
        } break;
        case 2: {
            FlashAttn<320, 256, 2, step_k> fa(scale, softcap, sinkf);
            fa.compute(kh, vh, 2, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
        } break;
        default: {
            FlashAttn<320, 256, 1, step_k> fa(scale, softcap, sinkf);
            fa.compute(kh, vh, 1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
        } break;
    }
}

// Dispatch on the K-cache storage type: construct the matching K/V helper pair and
// forward to the tiled compute overload above. Returns false when the type is not
// supported by this head-size combination.
// NOTE(review): the V cache is assumed to use the same type as K (or plain Q8_0
// when K is Q8_0_R8) — this invariant is enforced by the caller, iqk_fa_320_256.
template <int step_k>
inline bool iqk_deepseek_helper(ggml_type type_k,
        int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
        const float * q, const char * k, const char * v, const char * mask,
        float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
    if (type_k == GGML_TYPE_Q8_0) {
        HelperQ80 kh((const char *)k, stride_k);
        HelperQ80 vh((const char *)v, stride_v);
        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
    if (type_k == GGML_TYPE_Q8_0_R8) {
        // Row-interleaved Q8_0 K cache (presumably 8 interleaved rows, per the _R8
        // suffix — confirm against HelperQ80R8); V remains plain Q8_0.
        HelperQ80R8<320> kh((const char *)k, stride_k);
        HelperQ80 vh((const char *)v, stride_v);
        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
    if (type_k == GGML_TYPE_Q6_0) {
        HelperQ60 kh((const char *)k, stride_k);
        HelperQ60 vh((const char *)v, stride_v);
        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
// NOTE(review): plain #if (not #ifdef) — relies on GGML_IQK_FA_ALL_QUANTS being
// defined as 0/1 by the build system (an undefined macro evaluates as 0 here).
#if GGML_IQK_FA_ALL_QUANTS
    if (type_k == GGML_TYPE_Q8_KV) {
        // Helpers are parameterized on head size: 320 for K, 256 for V.
        HelperQ8KV<320> kh((const char *)k, stride_k);
        HelperQ8KV<256> vh((const char *)v, stride_v);
        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
#endif
    if (type_k == GGML_TYPE_F16) {
        HelperF16 kh((const char *)k, stride_k);
        HelperF16 vh((const char *)v, stride_v);
        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
#ifdef __AVX512BF16__
    // BF16 bypasses the tile-size cascade: FlashAttnBF16 is invoked directly with an
    // 8-row tile when nq1 divides evenly, otherwise one row at a time.
    if (type_k == GGML_TYPE_BF16) {
        HelperBF16<320, step_k> kh((const char *)k, stride_k);
        HelperBF16<256, step_k> vh((const char *)v, stride_v);
        if (nq1 % 8 == 0) {
            FlashAttnBF16<320, 256, 8, step_k> fa(scale, softcap, sinkf);
            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
        } else {
            FlashAttnBF16<320, 256, 1, step_k> fa(scale, softcap, sinkf);
            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
        }
        return true;
    }
#endif
    // Unsupported K-cache type for the 320/256 head.
    return false;
}

}

// Entry point (declared via IQK_FA_CASE in iqk_fa_templates.h) for the
// Dk=320/Dv=256 flash-attention head. Validates the K/V type pairing and
// forwards to the type-dispatch helper with a K step of 32.
IQK_FA_CASE(iqk_fa_320_256) {

    auto type_k = ggml_type(int_type_k);
    auto type_v = ggml_type(int_type_v);

    // Accept identical K/V types, or row-interleaved Q8_0_R8 K with plain Q8_0 V.
    const bool supported_pair = (type_k == type_v) ||
                                (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0);
    if (!supported_pair) {
        return false;
    }

    stride_q /= sizeof(float); // q stride as float

    return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                                   q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, sinkf, M, S);
}

#endif
1 change: 1 addition & 0 deletions ggml/src/iqk/fa/iqk_fa_templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -2238,6 +2238,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
float * qkv, const float * sinkf, float * M, float * S)

IQK_FA_CASE(iqk_fa_576_512);
IQK_FA_CASE(iqk_fa_320_256);
IQK_FA_CASE(iqk_fa_192_128);
IQK_FA_CASE(iqk_fa_192_192);
IQK_FA_CASE(iqk_fa_256_256);
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,10 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 320 && Dv == 256) {
return iqk_fa_320_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}

if (Dk == 192 && Dv == 128) {
return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_SEED_OSS, "seed_oss" },
{ LLM_ARCH_STEP35, "step35" },
{ LLM_ARCH_GLM_DSA, "glm-dsa" },
{ LLM_ARCH_MISTRAL4, "mistral4" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ enum llm_arch {
LLM_ARCH_SEED_OSS,
LLM_ARCH_STEP35,
LLM_ARCH_GLM_DSA,
LLM_ARCH_MISTRAL4,
LLM_ARCH_UNKNOWN,
};

Expand Down
3 changes: 2 additions & 1 deletion src/llama-build-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ ggml_cgraph * llm_build_context::build_k_shift() {
? LLAMA_ROPE_TYPE_NEOX
: hparams.rope_type;

const float yarn_attn_factor_shift = model.arch == LLM_ARCH_DEEPSEEK2
const float yarn_attn_factor_shift = model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_MISTRAL4
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
: cparams.yarn_attn_factor;

Expand Down Expand Up @@ -9891,6 +9891,7 @@ ggml_cgraph * llm_build_context::llama_build_graph(
} break;
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_GLM_DSA:
case LLM_ARCH_MISTRAL4:
{
result = llm.build_deepseek2();
} break;
Expand Down
8 changes: 6 additions & 2 deletions src/llama-hparams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,11 +837,14 @@ void llm_load_hparams(
model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_MISTRAL4:
case LLM_ARCH_DEEPSEEK2:
{
int expected_head_size_k = model.arch == LLM_ARCH_DEEPSEEK2 ? 576 : 320;
int expected_head_size_v = model.arch == LLM_ARCH_DEEPSEEK2 ? 512 : 256;
if (hparams.n_head_kv() == 1) {
int n_nead_kv = hparams.n_gqa();
if (n_nead_kv%4 != 0 || hparams.n_embd_head_k != 576 || hparams.n_embd_head_v != 512 ||
if (n_nead_kv%4 != 0 || hparams.n_embd_head_k != expected_head_size_k || hparams.n_embd_head_v != expected_head_size_v ||
hparams.n_rot != 64) {
printf("==========================================================================\n");
printf("Detected incompatible DeepSeek model without a known way to fix it.\n");
Expand All @@ -858,7 +861,7 @@ void llm_load_hparams(
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v);
}
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26) || (hparams.n_layer == 48 && hparams.n_vocab == 128256);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
if (!is_lite) {
Expand Down Expand Up @@ -888,6 +891,7 @@ void llm_load_hparams(

switch (hparams.n_layer) {
case 27: model.type = e_model::MODEL_16B; break;
case 36: model.type = e_model::MODEL_119B_A6B; break;
case 47: model.type = e_model::MODEL_30B_A3B; break; // GLM-4.7-Flash
case 60: model.type = e_model::MODEL_236B; break;
case 61: model.type = e_model::MODEL_671B; break;
Expand Down
1 change: 1 addition & 0 deletions src/llama-load-tensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3794,6 +3794,7 @@ bool create_tensors_helper::create_tensors() {
case LLM_ARCH_ARCTIC:
use_mmap_buffer = create_arctix_tensors(tn); break;
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_MISTRAL4:
use_mmap_buffer = create_deepseek2_tensors(tn); break;
case LLM_ARCH_GLM_DSA:
use_mmap_buffer = create_glm_dsa_tensors(tn); break;
Expand Down
37 changes: 36 additions & 1 deletion src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,41 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_MISTRAL4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" },
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KQ_A_MQA, "blk.%d.attn_kq_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_UP_EXPS, "blk.%d.ffn_gate_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
Expand Down Expand Up @@ -1775,6 +1809,7 @@ const char * llama_model_type_name(e_model type) {
case MODEL_80B_A13B: return "80B.A13B";
case MODEL_100B_A6B: return "100B.A6B";
case MODEL_106B_A12B: return "106B.A12B";
case MODEL_119B_A6B: return "119B.A6B";
case MODEL_122B_A10B: return "122B.A10B";
case MODEL_230B_A10B: return "230B.A10B";
case MODEL_235B_A22B: return "235B.A22B";
Expand Down
1 change: 1 addition & 0 deletions src/llama-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ enum e_model {
MODEL_80B_A13B,
MODEL_100B_A6B,
MODEL_106B_A12B,
MODEL_119B_A6B,
MODEL_122B_A10B,
MODEL_230B_A10B, // Minimax M2
MODEL_235B_A22B,
Expand Down
Loading