Skip to content

Commit 1c5724e

Browse files
committed
kv-cache : support V-less cache
1 parent 080b161 commit 1c5724e

File tree

4 files changed

+189
-17
lines changed

4 files changed

+189
-17
lines changed

src/llama-graph.cpp

Lines changed: 111 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
407407
return res;
408408
}
409409

410+
void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
411+
mctx->set_input_k_idxs(self_k_idxs, ubatch);
412+
413+
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
414+
}
415+
416+
bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
417+
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
418+
419+
this->mctx = mctx;
420+
421+
bool res = true;
422+
423+
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
424+
425+
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
426+
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
427+
428+
return res;
429+
}
430+
410431
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
411432
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
412433
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@@ -1596,11 +1617,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
15961617
v = ggml_transpose(ctx0, v);
15971618
}
15981619

1599-
// TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
1600-
if (v_mla) {
1601-
v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
1602-
}
1603-
16041620
// this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
16051621
if (k->type == GGML_TYPE_F32) {
16061622
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
@@ -1823,9 +1839,11 @@ ggml_tensor * llm_graph_context::build_attn(
18231839
ggml_tensor * v_cur,
18241840
ggml_tensor * kq_b,
18251841
ggml_tensor * sinks,
1826-
ggml_tensor * v_mla,
1842+
ggml_tensor * v_mla, // TODO: remove
18271843
float kq_scale,
18281844
int il) const {
1845+
GGML_ASSERT(v_mla == nullptr);
1846+
18291847
// these nodes are added to the graph together so that they are not reordered
18301848
// by doing so, the number of splits in the graph is reduced
18311849
// expand k later to enable rope fusion which directly writes into k-v cache
@@ -1868,6 +1886,93 @@ ggml_tensor * llm_graph_context::build_attn(
18681886
return cur;
18691887
}
18701888

1889+
static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
1890+
ggml_context * ctx0,
1891+
const llama_ubatch & ubatch,
1892+
const llama_hparams & hparams,
1893+
const llama_cparams & cparams,
1894+
const llama_kv_cache_context * mctx_cur) {
1895+
1896+
auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
1897+
1898+
{
1899+
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1900+
1901+
const auto n_kv = mctx_cur->get_n_kv();
1902+
const auto n_tokens = ubatch.n_tokens;
1903+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1904+
1905+
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1906+
1907+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
1908+
ggml_set_input(inp->self_kq_mask);
1909+
1910+
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1911+
}
1912+
1913+
return inp;
1914+
}
1915+
1916+
llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
1917+
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1918+
1919+
auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1920+
1921+
return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
1922+
}
1923+
1924+
ggml_tensor * llm_graph_context::build_attn(
1925+
llm_graph_input_attn_k * inp,
1926+
ggml_tensor * wo,
1927+
ggml_tensor * wo_b,
1928+
ggml_tensor * q_cur,
1929+
ggml_tensor * k_cur,
1930+
ggml_tensor * v_cur,
1931+
ggml_tensor * kq_b,
1932+
ggml_tensor * sinks,
1933+
ggml_tensor * v_mla,
1934+
float kq_scale,
1935+
int il) const {
1936+
// these nodes are added to the graph together so that they are not reordered
1937+
// by doing so, the number of splits in the graph is reduced
1938+
// expand k later to enable rope fusion which directly writes into k-v cache
1939+
ggml_build_forward_expand(gf, q_cur);
1940+
ggml_build_forward_expand(gf, v_cur);
1941+
ggml_build_forward_expand(gf, k_cur);
1942+
1943+
const auto * mctx_cur = inp->mctx;
1944+
1945+
// store to KV cache
1946+
{
1947+
const auto & k_idxs = inp->get_k_idxs();
1948+
1949+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1950+
}
1951+
1952+
const auto & kq_mask = inp->get_kq_mask();
1953+
1954+
ggml_tensor * q = q_cur;
1955+
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1956+
ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
1957+
1958+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1959+
cb(cur, "kqv_out", il);
1960+
1961+
if (wo) {
1962+
cur = build_lora_mm(wo, cur);
1963+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
1964+
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
1965+
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1966+
}
1967+
}
1968+
1969+
if (wo_b) {
1970+
cur = ggml_add(ctx0, cur, wo_b);
1971+
}
1972+
1973+
return cur;
1974+
}
1975+
18711976
ggml_tensor * llm_graph_context::build_attn(
18721977
llm_graph_input_attn_kv_iswa * inp,
18731978
ggml_tensor * wo,

src/llama-graph.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,37 @@ class llm_graph_input_attn_kv : public llm_graph_input_i {
317317
const llama_kv_cache_context * mctx;
318318
};
319319

320+
class llm_graph_input_attn_k : public llm_graph_input_i {
321+
public:
322+
llm_graph_input_attn_k(
323+
const llama_hparams & hparams,
324+
const llama_cparams & cparams,
325+
const llama_kv_cache_context * mctx) :
326+
hparams(hparams),
327+
cparams(cparams),
328+
mctx(mctx) {
329+
}
330+
~llm_graph_input_attn_k() = default;
331+
332+
void set_input(const llama_ubatch * ubatch) override;
333+
334+
bool can_reuse(const llm_graph_params & params) override;
335+
336+
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
337+
338+
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
339+
340+
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
341+
342+
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
343+
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
344+
345+
const llama_hparams hparams;
346+
const llama_cparams cparams;
347+
348+
const llama_kv_cache_context * mctx;
349+
};
350+
320351
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
321352
public:
322353
llm_graph_input_attn_kv_iswa(
@@ -833,6 +864,21 @@ struct llm_graph_context {
833864
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
834865
ggml_tensor * kq_b,
835866
ggml_tensor * sinks, // [n_head_q]
867+
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
868+
float kq_scale,
869+
int il) const;
870+
871+
llm_graph_input_attn_k * build_attn_inp_k () const;
872+
873+
ggml_tensor * build_attn(
874+
llm_graph_input_attn_k * inp,
875+
ggml_tensor * wo,
876+
ggml_tensor * wo_b,
877+
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
878+
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
879+
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
880+
ggml_tensor * kq_b,
881+
ggml_tensor * sinks, // [n_head_q]
836882
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
837883
float kq_scale,
838884
int il) const;

src/llama-kv-cache.cpp

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ llama_kv_cache::llama_kv_cache(
9797
__func__, hparams.n_embd_v_gqa_max());
9898
}
9999

100+
const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
101+
100102
for (uint32_t il = 0; il < hparams.n_layer; il++) {
101103
if (!hparams.has_kv(il)) {
102104
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
@@ -130,18 +132,21 @@ llama_kv_cache::llama_kv_cache(
130132
throw std::runtime_error("failed to create ggml context for kv cache");
131133
}
132134

133-
ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
134-
ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
135+
const bool has_k = true;
136+
const bool has_v = !is_mla;
137+
138+
ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
139+
ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
135140

136-
ggml_format_name(k, "cache_k_l%d", il);
137-
ggml_format_name(v, "cache_v_l%d", il);
141+
has_k && ggml_format_name(k, "cache_k_l%d", il);
142+
has_v && ggml_format_name(v, "cache_v_l%d", il);
138143

139144
std::vector<ggml_tensor *> k_stream;
140145
std::vector<ggml_tensor *> v_stream;
141146

142147
for (uint32_t s = 0; s < n_stream; ++s) {
143-
k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
144-
v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
148+
k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
149+
v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
145150
}
146151

147152
map_layer_ids[il] = layers.size();
@@ -647,7 +652,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co
647652
const auto & layer = layers[il];
648653

649654
ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
650-
ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
655+
656+
if (layer.v_stream[ssrc]) {
657+
ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
658+
}
651659
}
652660
}
653661
}
@@ -1516,7 +1524,7 @@ size_t llama_kv_cache::size_v_bytes() const {
15161524
size_t size_v_bytes = 0;
15171525

15181526
for (const auto & layer : layers) {
1519-
size_v_bytes += ggml_nbytes(layer.v);
1527+
size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0;
15201528
}
15211529

15221530
return size_v_bytes;
@@ -1798,6 +1806,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
17981806
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
17991807

18001808
auto * v = layer.v_stream[cr.strm];
1809+
if (!v) {
1810+
continue;
1811+
}
18011812

18021813
// Write value type
18031814
const int32_t v_type_i = (int32_t) v->type;
@@ -1824,6 +1835,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
18241835
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
18251836

18261837
auto * v = layer.v_stream[cr.strm];
1838+
if (!v) {
1839+
continue;
1840+
}
18271841

18281842
// Write value type
18291843
const int32_t v_type_i = (int32_t) v->type;
@@ -2027,6 +2041,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
20272041
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
20282042

20292043
auto * v = layer.v_stream[strm];
2044+
if (!v) {
2045+
continue;
2046+
}
20302047

20312048
// Read type of value
20322049
int32_t v_type_i_ref;
@@ -2068,6 +2085,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
20682085
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
20692086

20702087
auto * v = layer.v_stream[strm];
2088+
if (!v) {
2089+
continue;
2090+
}
20712091

20722092
// Read type of value
20732093
int32_t v_type_i_ref;

src/models/deepseek2.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
4343
// inp_pos - contains the positions
4444
ggml_tensor * inp_pos = build_inp_pos();
4545

46-
auto * inp_attn = build_attn_inp_kv();
46+
auto * inp_attn_kv = !is_mla ? build_attn_inp_kv() : nullptr;
47+
auto * inp_attn_k = is_mla ? build_attn_inp_k() : nullptr;
4748

4849
ggml_tensor * inp_out_ids = build_inp_out_ids();
4950

@@ -145,7 +146,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
145146
}
146147

147148
// note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
148-
cur = build_attn(inp_attn,
149+
cur = build_attn(inp_attn_k,
149150
model.layers[il].wo, NULL,
150151
Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
151152
} else {
@@ -182,7 +183,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
182183
}
183184

184185
// note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
185-
cur = build_attn(inp_attn,
186+
cur = build_attn(inp_attn_kv,
186187
model.layers[il].wo, NULL,
187188
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
188189
}

0 commit comments

Comments
 (0)