Skip to content

Commit 042eb8a

Browse files
mtp-batch (wip): merge mtp and model graph
1 parent 1318b2d commit 042eb8a

File tree

6 files changed

+70
-130
lines changed

6 files changed

+70
-130
lines changed

src/llama-context.cpp

Lines changed: 19 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,8 @@ bool llama_context::apply_adapter_cvec(
729729
return cvec.apply(model, data, len, n_embd, il_start, il_end);
730730
}
731731

732-
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
732+
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret,
733+
bool do_mtp_kv_update) {
733734
if (mctx && !mctx->apply()) {
734735
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
735736
ret = GGML_STATUS_FAILED;
@@ -741,7 +742,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
741742

742743
// the new graph parameters
743744
// in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
744-
const auto gparams = graph_params(res, ubatch, mctx, gtype);
745+
const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update);
745746

746747
if (!graph_reuse_disable && res->can_reuse(gparams)) {
747748
//LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
@@ -781,7 +782,15 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
781782
//LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
782783
}
783784

785+
const int64_t t_exec_start_us = ggml_time_us();
784786
const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
787+
const int64_t t_exec_end_us = ggml_time_us();
788+
LLAMA_LOG_INFO(
789+
"[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n",
790+
(t_exec_end_us - t_exec_start_us) / 1000.0,
791+
ubatch.n_tokens,
792+
do_mtp_kv_update ? "yes" : "no"
793+
);
785794
if (status != GGML_STATUS_SUCCESS) {
786795
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
787796
ret = status;
@@ -850,7 +859,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
850859
cparams.causal_attn = false;
851860

852861
ggml_status status;
853-
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
862+
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false);
854863

855864
cparams.causal_attn = causal_attn_org;
856865

@@ -1092,7 +1101,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
10921101
}
10931102

10941103
ggml_status status;
1095-
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
1104+
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update);
10961105

10971106
if (!res) {
10981107
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1130,39 +1139,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
11301139
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
11311140
//}
11321141

1133-
if (do_mtp_kv_update) {
1134-
LLAMA_LOG_INFO(
1135-
"[MTP BATCHING] Processando MTP KV update para um ubatch de %u tokens.\n",
1136-
ubatch.n_tokens
1137-
);
1138-
auto res_mtp = std::make_unique<llm_graph_result>(graph_max_nodes());
1139-
1140-
auto params_mtp = mtp_graph_params(res_mtp.get(), ubatch, mctx.get());
1141-
ggml_backend_sched_t sched_mtp = params_mtp.sched;
1142-
1143-
auto * gf_mtp = model.build_mtp_graph(params_mtp);
1144-
if (gf_mtp) {
1145-
ggml_backend_sched_alloc_graph(sched_mtp, gf_mtp);
1146-
1147-
ggml_tensor* prev_embedding_tensor = res->get_embd();
1148-
ggml_tensor* embd_input_mtp = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embeddings_batch_input");
1149-
1150-
// ggml_backend_tensor_set(embd_input_mtp, prev_embedding_tensor->data, 0, ggml_nbytes(prev_embedding_tensor));
1151-
ggml_backend_tensor_copy(prev_embedding_tensor, embd_input_mtp);
1152-
1153-
ggml_backend_sched_graph_compute(sched_mtp, gf_mtp);
1154-
1155-
if (ubatch.output[0]) {
1156-
struct ggml_tensor * logits_mtp = res_mtp->get_logits();
1157-
if (logits_mtp) {
1158-
float * logits_dest = logits + n_outputs_prev * n_vocab;
1159-
ggml_backend_tensor_get(logits_mtp, logits_dest, 0, ggml_nbytes(logits_mtp));
1160-
}
1161-
}
1162-
}
1163-
ggml_backend_sched_free(sched_mtp);
1164-
}
1165-
11661142
auto * t_logits = res->get_logits();
11671143
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
11681144
embd_tensor = res->get_embd();
@@ -1442,7 +1418,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
14421418

14431419
auto * res = gf_res_reserve.get();
14441420

1445-
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
1421+
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false);
14461422

14471423
res->reset();
14481424

@@ -1462,8 +1438,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
14621438
llm_graph_params llama_context::graph_params(
14631439
llm_graph_result * res,
14641440
const llama_ubatch & ubatch,
1465-
const llama_memory_context_i * mctx,
1466-
llm_graph_type gtype) const {
1441+
const llama_memory_context_i * mctx,
1442+
llm_graph_type gtype,
1443+
bool update_mtp_kv) const {
14671444
return {
14681445
/*.arch =*/ model.arch,
14691446
/*.hparams =*/ model.hparams,
@@ -1476,36 +1453,13 @@ llm_graph_params llama_context::graph_params(
14761453
/*.loras =*/ &loras,
14771454
/*.mctx =*/ mctx,
14781455
/*.cross =*/ &cross,
1456+
/*.update_mtp_kv =*/ update_mtp_kv,
14791457
/*.n_outputs =*/ n_outputs,
14801458
/*.cb =*/ graph_get_cb(),
14811459
/*.res =*/ res,
14821460
};
14831461
}
14841462

1485-
llm_graph_params llama_context::mtp_graph_params(
1486-
llm_graph_result * res,
1487-
const llama_ubatch& ubatch,
1488-
const llama_memory_context_i * mctx) {
1489-
size_t n_nodes = std::max<uint32_t>(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
1490-
ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
1491-
return {
1492-
/*.arch =*/ model.arch,
1493-
/*.hparams =*/ model.hparams,
1494-
/*.cparams =*/ cparams,
1495-
/*.ubatch =*/ ubatch,
1496-
/*.gtype =*/ LLM_GRAPH_TYPE_DECODER,
1497-
/*.sched =*/ temp_sched,
1498-
/*.backend_cpu =*/ backend_cpu,
1499-
/*.cvec =*/ &cvec,
1500-
/*.loras =*/ &loras,
1501-
/*.mctx =*/ mctx,
1502-
/*.cross =*/ &cross,
1503-
/*.n_outputs =*/ 1,
1504-
/*.cb =*/ graph_get_cb(temp_sched),
1505-
/*.res =*/ res,
1506-
};
1507-
}
1508-
15091463
std::unique_ptr<llama_memory_context_i> llama_context::mtp_memory_batch(const llama_batch& batch_inp) {
15101464
const auto& vocab = model.vocab;
15111465
const auto& hparams = model.hparams;
@@ -2240,7 +2194,7 @@ void llama_context::opt_epoch_iter(
22402194

22412195
auto * res = gf_res_prev.get();
22422196

2243-
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
2197+
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false);
22442198

22452199
res->reset();
22462200

src/llama-context.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ struct llama_context {
9999
const llama_ubatch & ubatch,
100100
llm_graph_type gtype,
101101
llama_memory_context_i * mctx,
102-
ggml_status & ret);
102+
ggml_status & ret,
103+
const bool do_mtp_kv_update);
103104

104105
int encode(const llama_batch & batch_inp);
105106
int decode(const llama_batch & batch_inp);
@@ -200,8 +201,6 @@ struct llama_context {
200201
// reserve a graph with a dummy ubatch of the specified size
201202
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
202203

203-
llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx);
204-
205204
void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);
206205

207206
ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);
@@ -213,7 +212,8 @@ struct llama_context {
213212
llm_graph_result * res,
214213
const llama_ubatch & ubatch,
215214
const llama_memory_context_i * mctx,
216-
llm_graph_type gtype) const;
215+
llm_graph_type gtype,
216+
bool update_mtp_kv) const;
217217

218218
llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;
219219

src/llama-graph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ struct llm_graph_params {
402402
const llama_adapter_loras * loras;
403403
const llama_memory_context_i * mctx;
404404
const llama_cross * cross;
405+
bool update_mtp_kv;
405406

406407
uint32_t n_outputs;
407408

src/llama-model.cpp

Lines changed: 40 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -13787,7 +13787,8 @@ struct llm_build_glm4 : public llm_graph_context {
1378713787
};
1378813788

1378913789
struct llm_build_glm4_moe : public llm_graph_context {
13790-
llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
13790+
llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params, bool build_mtp_path)
13791+
: llm_graph_context(params) {
1379113792
const int64_t n_embd_head = hparams.n_embd_head_v;
1379213793

1379313794
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13932,68 +13933,57 @@ struct llm_build_glm4_moe : public llm_graph_context {
1393213933
cur = inpL;
1393313934
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
1393413935

13935-
cb(cur, "result_norm", -1);
13936+
// cb(cur, "result_norm", -1);
1393613937
res->t_embd = cur;
1393713938

13938-
// lm_head
13939-
cur = build_lora_mm(model.output, cur);
13940-
13941-
cb(cur, "result_output", -1);
13942-
res->t_logits = cur;
13939+
if (build_mtp_path) {
13940+
const int il_mtp = hparams.n_layer - 1;
13941+
const auto & mtp_layer = model.layers[il_mtp];
13942+
13943+
ggml_tensor * mtp_logits = build_mtp_tail(mtp_layer, cur, n_embd_head);
13944+
res->t_logits = mtp_logits;
13945+
} else {
13946+
// lm_head
13947+
cur = build_lora_mm(model.output, cur);
13948+
res->t_logits = cur;
13949+
}
1394313950

13944-
ggml_build_forward_expand(gf, cur);
13951+
ggml_build_forward_expand(gf, res->t_logits);
1394513952
}
13946-
};
13947-
13948-
struct llm_build_glm4_moe_mtp : public llm_graph_context {
13949-
llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
13950-
13951-
const int64_t n_embd_head = hparams.n_embd_head_v;
13952-
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
1395313953

13954+
private:
13955+
ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings,
13956+
int64_t n_embd_head
13957+
) {
1395413958
const int il = hparams.n_layer - 1;
13955-
const auto & mtp_layer = model.layers[il];
1395613959

1395713960
ggml_tensor * inp_pos = build_inp_pos();
1395813961
auto * inp_attn = build_attn_inp_kv_unified();
13959-
13960-
ggml_tensor* prev_embeddings_batch = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_embd, n_tokens);
13961-
ggml_set_name(prev_embeddings_batch, "mtp_prev_embeddings_batch_input");
13962-
ggml_set_input(prev_embeddings_batch);
13963-
1396413962
ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens);
1396513963

1396613964
ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
13967-
ggml_tensor * hidden_state_norm = build_norm(prev_embeddings_batch, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
13968-
13965+
ggml_tensor * hidden_state_norm = build_norm(prev_embeddings, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
13966+
1396913967
ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
13970-
1397113968
ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined);
1397213969

1397313970
// now proceed through last layer (skipped in main model)
1397413971
ggml_tensor * inpSA = cur;
13975-
1397613972
// Pre-attention norm for the MTP block
13977-
ggml_tensor* attn_inp = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
13973+
cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
1397813974

1397913975
// self-attention
1398013976
{
1398113977
ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur);
13982-
if (mtp_layer.bq) {
13983-
Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
13984-
}
13978+
if (mtp_layer.bq) Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
1398513979
cb(Qcur, "Qcur", il);
1398613980

1398713981
ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur);
13988-
if (mtp_layer.bk) {
13989-
Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
13990-
}
13982+
if (mtp_layer.bk) Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
1399113983
cb(Kcur, "Kcur", il);
1399213984

1399313985
ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur);
13994-
if (mtp_layer.bv) {
13995-
Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
13996-
}
13986+
if (mtp_layer.bv) Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
1399713987
cb(Vcur, "Vcur", il);
1399813988

1399913989
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -14025,10 +14015,10 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
1402514015
cb(Qcur, "Qcur", il);
1402614016
cb(Kcur, "Kcur", il);
1402714017
cb(Vcur, "Vcur", il);
14028-
14018+
1402914019
cur = build_attn(inp_attn,
14030-
mtp_layer.wo, NULL,
14031-
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
14020+
mtp_layer.wo, NULL,
14021+
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
1403214022
}
1403314023

1403414024
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -14068,9 +14058,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
1406814058
cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
1406914059
cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
1407014060

14071-
res->t_logits = cur;
14072-
14073-
ggml_build_forward_expand(gf, res->t_logits);
14061+
return cur;
1407414062
}
1407514063
};
1407614064

@@ -18299,8 +18287,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1829918287
}
1830018288

1830118289
ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
18290+
const int64_t t_start_us = ggml_time_us();
18291+
1830218292
std::unique_ptr<llm_graph_context> llm;
1830318293

18294+
const bool build_mtp = params.update_mtp_kv;
18295+
1830418296
switch (arch) {
1830518297
case LLM_ARCH_LLAMA:
1830618298
{
@@ -18519,7 +18511,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
1851918511
} break;
1852018512
case LLM_ARCH_GLM4_MOE:
1852118513
{
18522-
llm = std::make_unique<llm_build_glm4_moe>(*this, params);
18514+
llm = std::make_unique<llm_build_glm4_moe>(*this, params, build_mtp);
1852318515
} break;
1852418516
case LLM_ARCH_BITNET:
1852518517
{
@@ -18660,22 +18652,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
1866018652

1866118653
// add on pooling layer
1866218654
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
18663-
18664-
return llm->res->get_gf();
18665-
}
18666-
18667-
ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params) const {
18668-
std::unique_ptr<llm_graph_context> llm;
18669-
18670-
switch (arch) {
18671-
case LLM_ARCH_GLM4_MOE:
18672-
{
18673-
llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params);
18674-
} break;
18675-
default:
18676-
GGML_ABORT("fatal error");
18677-
}
18678-
18655+
const int64_t t_end_us = ggml_time_us(); // end of timer
18656+
LLAMA_LOG_INFO(
18657+
"[PERF] Graph build time: %.2f ms (MTP path: %s)\n",
18658+
(t_end_us - t_start_us) / 1000.0,
18659+
build_mtp ? "yes" : "no"
18660+
);
1867918661
return llm->res->get_gf();
1868018662
}
1868118663

src/llama-model.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,6 @@ struct llama_model {
475475

476476
// TODO: move this to new llm_arch_model_i interface
477477
ggml_cgraph * build_graph(const llm_graph_params & params) const;
478-
ggml_cgraph * build_mtp_graph(const llm_graph_params& params) const;
479478

480479
private:
481480
struct impl;

0 commit comments

Comments
 (0)