Skip to content
Merged
119 changes: 94 additions & 25 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

// dedup helpers

static ggml_tensor * build_kq_mask(
static ggml_tensor * build_attn_inp_kq_mask(
ggml_context * ctx,
const llama_kv_cache_context * mctx,
const llama_ubatch & ubatch,
Expand All @@ -28,7 +28,11 @@ static ggml_tensor * build_kq_mask(
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_kq_mask");

return res;
}

static bool can_reuse_kq_mask(
Expand All @@ -52,6 +56,21 @@ static bool can_reuse_kq_mask(

// impl

static ggml_tensor * ggml_mul_mat_aux(
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * rot) {
const auto n = rot->ne[0];

ggml_tensor * res;

res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
res = ggml_mul_mat (ctx, rot, res);
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);

return res;
}

void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) {
const int64_t n_tokens = ubatch->n_tokens;
Expand Down Expand Up @@ -429,6 +448,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_v_idxs(self_v_idxs, ubatch);

mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);

if (self_k_rot) {
mctx->set_input_k_rot(self_k_rot);
}

if (self_v_rot) {
mctx->set_input_v_rot(self_v_rot);
}
}

bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
Expand Down Expand Up @@ -476,6 +503,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);

mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);

if (self_k_rot) {
mctx->get_base()->set_input_k_rot(self_k_rot);
}

if (self_v_rot) {
mctx->get_base()->set_input_v_rot(self_v_rot);
}
}

bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
Expand Down Expand Up @@ -532,6 +567,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {

mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);

if (inp_attn->self_k_rot) {
mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot);
}

if (inp_attn->self_v_rot) {
mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot);
}

const int64_t n_rs = mctx->get_recr()->get_n_rs();

if (inp_rs->s_copy) {
Expand Down Expand Up @@ -630,6 +673,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}

if (inp_attn->self_k_rot) {
attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
}

if (inp_attn->self_v_rot) {
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
}

const int64_t n_rs = mctx->get_recr()->get_n_rs();

if (inp_rs->s_copy) {
Expand Down Expand Up @@ -2002,13 +2053,13 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);

ggml_set_input(inp->self_kq_mask);

inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}

inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0);

return inp;
}

Expand All @@ -2034,6 +2085,15 @@ ggml_tensor * llm_graph_context::build_attn(
int il) const {
GGML_ASSERT(v_mla == nullptr);

if (inp->self_k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
}

if (inp->self_v_rot) {
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
}

// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
Expand Down Expand Up @@ -2061,6 +2121,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);

if (inp->self_v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
}

if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
Expand Down Expand Up @@ -2090,9 +2154,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(

inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);

inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);

inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}

Expand Down Expand Up @@ -2171,6 +2233,18 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v_mla,
float kq_scale,
int il) const {
if (inp->self_k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
if (k_cur) {
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
}
}
if (inp->self_v_rot) {
if (v_cur) {
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
}
}

// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
Expand Down Expand Up @@ -2211,6 +2285,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);

if (inp->self_v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
}

if (wo) {
cur = build_lora_mm(wo, cur);
}
Expand Down Expand Up @@ -2293,12 +2371,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
ggml_set_name(inp->self_kq_mask, "self_kq_mask");

inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
}

{
Expand All @@ -2307,14 +2381,13 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
ggml_set_input(inp->self_kq_mask_swa);
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");

inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
}

inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);

return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}

Expand Down Expand Up @@ -2473,19 +2546,15 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);

inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
ggml_set_input(inp_attn->self_kq_mask);

inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
}

{
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);

inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
ggml_set_input(inp_attn->self_kq_mask_swa);

inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
}

Expand Down
8 changes: 8 additions & 0 deletions src/llama-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ class llm_graph_input_attn_kv : public llm_graph_input_i {
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

// note: assumes v_rot^ == I
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;

// note: these have to be copies because in order to be able to reuse a graph, its inputs
// need to carry these parameters with them. otherwise, they can point to freed
// llm_graph_params from a previous batch, causing stack-use-after-return
Expand Down Expand Up @@ -384,6 +388,10 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

// note: using same rotation matrices for both base and swa cache
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;

const llama_hparams hparams;
const llama_cparams cparams;

Expand Down
Loading
Loading