Skip to content

Commit f2d315b

Browse files
authored
Avoid rebuild of GGML graph for each token (ikawrakow#98)
Introduces caching of the GGML graph to avoid an unnecessary full rebuild between tokens. KV cache parameters, which change with each token, are updated directly in the cached GGML graph. Can be disabled with the GGML_DISABLE_GRAPH_CACHING environment variable.
1 parent afbf2ef commit f2d315b

4 files changed

Lines changed: 161 additions & 13 deletions

File tree

ggml/include/ggml-backend.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,12 @@ extern "C" {
232232
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
233233
GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
234234

235+
// Utility to query whether cached GGML graph is in use
236+
GGML_API bool ggml_use_cached_graph(ggml_backend_sched_t sched);
237+
238+
// Set whether or not to use GGML graph caching
239+
GGML_API void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value);
240+
235241

236242
#ifdef __cplusplus
237243
}

ggml/include/ggml.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,13 @@ extern "C" {
597597
GGML_TENSOR_FLAG_PARAM = 4,
598598
};
599599

600+
// Flag (used on GGML_OP_CPY nodes) on whether node is associated with K or V cache
601+
enum ggml_kv_cache_flag {
602+
GGML_KV_CACHE_FLAG_NONE = 0,
603+
GGML_KV_CACHE_FLAG_K = 1,
604+
GGML_KV_CACHE_FLAG_V = 2
605+
};
606+
600607
// ggml object
601608
struct ggml_object {
602609
size_t offs;

ggml/src/ggml-backend.c

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,13 @@ struct ggml_backend_sched_split {
10401040
struct ggml_cgraph graph;
10411041
};
10421042

1043+
// Object to facilitate GGML graph caching
1044+
struct ggml_cached_graph {
1045+
bool is_active;
1046+
ggml_backend_t input_backend;
1047+
struct ggml_tensor * input_cpy[GGML_SCHED_MAX_SPLIT_INPUTS];
1048+
};
1049+
10431050
struct ggml_backend_sched {
10441051
bool is_reset; // true if the scheduler has been reset since the last graph split
10451052
bool is_alloc;
@@ -1085,6 +1092,8 @@ struct ggml_backend_sched {
10851092
size_t context_buffer_size;
10861093

10871094
bool debug;
1095+
1096+
struct ggml_cached_graph cached_graph;
10881097
};
10891098

10901099
#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
@@ -1762,6 +1771,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
17621771
struct ggml_tensor * input = split->inputs[j];
17631772
struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
17641773

1774+
if (!sched->cached_graph.is_active) {
1775+
sched->cached_graph.input_backend = input_backend;
1776+
sched->cached_graph.input_cpy[j] = input_cpy;
1777+
} else {
1778+
input_backend = sched->cached_graph.input_backend;
1779+
input_cpy = sched->cached_graph.input_cpy[j];
1780+
}
1781+
17651782
if (input->flags & GGML_TENSOR_FLAG_INPUT) {
17661783
// inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
17671784
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
@@ -1893,6 +1910,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
18931910

18941911
ggml_backend_sched_reset(sched);
18951912

1913+
sched->cached_graph.is_active = false;
1914+
18961915
return sched;
18971916
}
18981917

@@ -1969,16 +1988,16 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st
19691988
}
19701989

19711990
enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
1972-
if (!sched->is_reset && !sched->is_alloc) {
1973-
ggml_backend_sched_reset(sched);
1974-
}
1975-
1976-
if (!sched->is_alloc) {
1977-
if (!ggml_backend_sched_alloc_graph(sched, graph)) {
1978-
return GGML_STATUS_ALLOC_FAILED;
1991+
if(!sched->cached_graph.is_active) {
1992+
if (!sched->is_reset && !sched->is_alloc) {
1993+
ggml_backend_sched_reset(sched);
1994+
}
1995+
if (!sched->is_alloc) {
1996+
if (!ggml_backend_sched_alloc_graph(sched, graph)) {
1997+
return GGML_STATUS_ALLOC_FAILED;
1998+
}
19791999
}
19802000
}
1981-
19822001
return ggml_backend_sched_compute_splits(sched);
19832002
}
19842003

@@ -2243,3 +2262,13 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
22432262

22442263
return true;
22452264
}
2265+
2266+
bool ggml_use_cached_graph(ggml_backend_sched_t sched) {
2267+
return sched->cached_graph.is_active;
2268+
}
2269+
2270+
void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value) {
2271+
sched->cached_graph.is_active = set_value;
2272+
}
2273+
2274+

src/llama.cpp

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "ggml.h"
99
#include "ggml-alloc.h"
1010
#include "ggml-backend.h"
11+
#include "../ggml/src/ggml-impl.h"
1112

1213
#ifdef GGML_USE_RPC
1314
# include "ggml-rpc.h"
@@ -2659,6 +2660,17 @@ struct llama_model {
26592660
}
26602661
};
26612662

2663+
// Object used to allow caching of GGML graph between tokens where possible.
2664+
struct ggml_cached_graph {
2665+
bool is_active = false;
2666+
ggml_cgraph * gf;
2667+
size_t n;
2668+
ggml_backend_t backend_res;
2669+
ggml_backend_t backend_embd;
2670+
struct ggml_tensor * res;
2671+
struct ggml_tensor * embd;
2672+
};
2673+
26622674
struct llama_context {
26632675
llama_context(const llama_model & model)
26642676
: model(model)
@@ -2759,6 +2771,8 @@ struct llama_context {
27592771
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
27602772
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
27612773
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
2774+
2775+
struct ggml_cached_graph cached_graph;
27622776
};
27632777

27642778
struct llama_lora_weight {
@@ -14877,11 +14891,44 @@ static int llama_decode_internal(
1487714891
ggml_backend_sched_reset(lctx.sched);
1487814892
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
1487914893

14880-
ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
14894+
ggml_cgraph * gf;
14895+
// the output is always the last tensor in the graph
14896+
struct ggml_tensor * res;
14897+
struct ggml_tensor * embd;
14898+
14899+
bool n_has_changed_since_last_token = false;
14900+
if(lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
14901+
lctx.cached_graph.n = kv_self.n;
14902+
14903+
// Re-build graph only if graph caching is not possible
14904+
if(!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
14905+
14906+
gf = llama_build_graph(lctx, u_batch, false);
14907+
14908+
// Set whether GGML graph caching is in use within GGML module, based on
14909+
// whether caching was activated here during the previous token
14910+
ggml_set_cached_graph(lctx.sched,lctx.cached_graph.is_active);
14911+
14912+
// Disable future graph caching in presence of env var,
14913+
// if there are multiple devices, if batch size is greater than 1,
14914+
// or if nsplits is not 2.
14915+
// TODO: enable graph caching for these cases
14916+
bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
14917+
|| (llama_get_device_count(model) > 1)
14918+
|| (ggml_backend_sched_get_n_splits(lctx.sched) != 2);
14919+
for (int i = 0 ; i < gf->n_nodes; i++) {
14920+
if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
14921+
disable_cached_ggml_graph = true;
14922+
break;
14923+
}
14924+
}
14925+
14926+
// Set whether graph caching should be used for future tokens
14927+
lctx.cached_graph.is_active=!disable_cached_ggml_graph;
1488114928

1488214929
// the output is always the last tensor in the graph
14883-
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
14884-
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
14930+
res = gf->nodes[gf->n_nodes - 1];
14931+
embd = gf->nodes[gf->n_nodes - 2];
1488514932

1488614933
if (lctx.n_outputs == 0) {
1488714934
// no output
@@ -14901,9 +14948,58 @@ static int llama_decode_internal(
1490114948
embd = nullptr; // do not extract embeddings when not needed
1490214949
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
1490314950
}
14951+
lctx.cached_graph.res = res;
14952+
lctx.cached_graph.embd = embd;
1490414953
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
1490514954

1490614955
ggml_backend_sched_alloc_graph(lctx.sched, gf);
14956+
}
14957+
else {
14958+
gf = lctx.cached_graph.gf;
14959+
res = lctx.cached_graph.res;
14960+
embd = lctx.cached_graph.embd;
14961+
}
14962+
lctx.cached_graph.gf = gf;
14963+
14964+
// Update K and V cache parameters in cached graph.
14965+
if(gf != nullptr && gf->nodes != nullptr && ggml_use_cached_graph(lctx.sched)) {
14966+
14967+
const struct llama_hparams & hparams = model.hparams;
14968+
const int64_t kv_head = kv_self.head;
14969+
14970+
for (int i = 0; i < gf->n_nodes; i++) {
14971+
ggml_tensor * node = gf->nodes[i];
14972+
if (node->op == GGML_OP_CPY) {
14973+
14974+
// K cache
14975+
const char* k_prefix = "k_cache_view-";
14976+
if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
14977+
int il = atoi(node->src[1]->name + strlen(k_prefix)); // Layer index from name
14978+
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
14979+
ggml_tensor * tmp_tensor = kv_self.k_l[il];
14980+
size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
14981+
node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
14982+
}
14983+
14984+
// V cache
14985+
const char* v_prefix = "v_cache_view-";
14986+
if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
14987+
int il = atoi(node->src[1]->name + strlen(v_prefix)); // Layer index from name
14988+
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
14989+
ggml_tensor * tmp_tensor = kv_self.v_l[il];
14990+
size_t tmp_offset;
14991+
if (cparams.flash_attn) {
14992+
tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
14993+
} else {
14994+
tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
14995+
}
14996+
node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
14997+
}
14998+
14999+
}
15000+
}
15001+
15002+
}
1490715003

1490815004
llama_set_inputs(lctx, u_batch);
1490915005

@@ -14927,12 +15023,18 @@ static int llama_decode_internal(
1492715023
// extract logits
1492815024
if (res) {
1492915025
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
14930-
GGML_ASSERT(backend_res != nullptr);
14931-
GGML_ASSERT(lctx.logits != nullptr);
1493215026

1493315027
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
1493415028
const int32_t n_outputs_new = lctx.n_outputs;
1493515029

15030+
if(!ggml_use_cached_graph(lctx.sched))
15031+
lctx.cached_graph.backend_res = backend_res;
15032+
else
15033+
backend_res = lctx.cached_graph.backend_res;
15034+
15035+
GGML_ASSERT(backend_res != nullptr);
15036+
GGML_ASSERT(lctx.logits != nullptr);
15037+
1493615038
if (n_outputs_new) {
1493715039
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
1493815040
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
@@ -14943,6 +15045,10 @@ static int llama_decode_internal(
1494315045
// extract embeddings
1494415046
if (embd) {
1494515047
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
15048+
if(!ggml_use_cached_graph(lctx.sched))
15049+
lctx.cached_graph.backend_embd = backend_embd;
15050+
else
15051+
backend_embd = lctx.cached_graph.backend_embd;
1494615052
GGML_ASSERT(backend_embd != nullptr);
1494715053

1494815054
switch (cparams.pooling_type) {

0 commit comments

Comments
 (0)