Skip to content

Commit 1318b2d

Browse files
mtp-batch (wip): move mtp execution to batch format
1 parent c6237c7 commit 1318b2d

File tree

8 files changed

+166
-129
lines changed

8 files changed

+166
-129
lines changed

common/speculative.cpp

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -374,47 +374,54 @@ llama_token mtp_speculative_gen_draft(
374374
return -1;
375375
}
376376

377-
llama_batch batch = llama_batch_init(1, 0, 1);
378-
common_batch_add(batch, id_last, n_past, {0}, true);
377+
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
378+
common_batch_add(mtp_batch, id_last, n_past, {0}, true);
379+
mtp_batch.update_mtp_kv = true;
379380

380-
llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
381+
llama_decode(ctx, mtp_batch);
382+
llama_batch_free(mtp_batch);
381383

382384
const llama_model * model = llama_get_model(ctx);
383385
const llama_vocab * vocab = llama_model_get_vocab(model);
384386
const int n_vocab = llama_n_vocab(vocab);
385-
386387
llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
387-
388388
cur_p->size = n_vocab;
389389
for (int i = 0; i < n_vocab; ++i) {
390390
cur_p->data[i].id = i;
391-
cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
391+
cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // TODO: check if position 0 is the right
392392
}
393393
cur_p->sorted = false;
394-
395394
common_sampler_apply_chain(smpl, cur_p);
396-
397-
const llama_token id = cur_p->data[0].id;
398-
399-
llama_batch_free(batch);
400-
401-
return id;
395+
396+
return cur_p->data[0].id;
402397
}
403398

404399

405400
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
406-
mtp_kv_update_data token;
407-
401+
if (tokens.empty()) {
402+
tokens.clear();
403+
return;
404+
}
408405
if (n_tokens < 0) {
409406
n_tokens = tokens.size();
410407
}
408+
const size_t n_to_process = std::min((size_t)tokens.size(), n_tokens);
409+
410+
LOG_DBG(
411+
"[MTP BATCHING] mtp_update_kv_cache call for %zu tokens.\n",
412+
n_to_process
413+
);
414+
llama_batch mtp_batch = llama_batch_init(n_to_process, 0, 1);
415+
416+
for (size_t i = 0; i < n_to_process; ++i) {
417+
const mtp_kv_update_data& token_data = tokens[i];
418+
common_batch_add(mtp_batch, token_data.id, token_data.n_past, {0}, false);
419+
}
411420

412-
for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
413-
token = tokens[i];
414-
//fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
421+
mtp_batch.update_mtp_kv = true;
415422

416-
mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
417-
}
423+
llama_decode(ctx, mtp_batch);
418424

425+
llama_batch_free(mtp_batch);
419426
tokens.clear();
420427
}

include/llama.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ extern "C" {
230230
int32_t * n_seq_id;
231231
llama_seq_id ** seq_id;
232232
int8_t * logits; // TODO: rename this to "output"
233+
bool update_mtp_kv;
233234
} llama_batch;
234235

235236
enum llama_model_kv_override_type {
@@ -1454,8 +1455,8 @@ extern "C" {
14541455
ggml_opt_epoch_callback callback_train,
14551456
ggml_opt_epoch_callback callback_eval);
14561457

1457-
LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
1458-
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
1458+
// LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
1459+
// const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
14591460

14601461
#ifdef __cplusplus
14611462
}

src/llama-batch.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -834,13 +834,14 @@ struct llama_batch llama_batch_get_one(
834834

835835
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
836836
llama_batch batch = {
837-
/*n_tokens =*/ 0,
838-
/*tokens =*/ nullptr,
839-
/*embd =*/ nullptr,
840-
/*pos =*/ nullptr,
841-
/*n_seq_id =*/ nullptr,
842-
/*seq_id =*/ nullptr,
843-
/*logits =*/ nullptr,
837+
/*n_tokens =*/ 0,
838+
/*tokens =*/ nullptr,
839+
/*embd =*/ nullptr,
840+
/*pos =*/ nullptr,
841+
/*n_seq_id =*/ nullptr,
842+
/*seq_id =*/ nullptr,
843+
/*logits =*/ nullptr,
844+
/*update_mtp_kv =*/ false,
844845
};
845846

846847
if (embd) {

src/llama-context.cpp

Lines changed: 93 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
10701070
};
10711071

10721072
int64_t n_outputs_prev = 0;
1073+
const bool do_mtp_kv_update = batch_inp.update_mtp_kv;
10731074

10741075
do {
10751076
const auto & ubatch = mctx->get_ubatch();
@@ -1129,6 +1130,39 @@ int llama_context::decode(const llama_batch & batch_inp) {
11291130
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
11301131
//}
11311132

1133+
if (do_mtp_kv_update) {
1134+
LLAMA_LOG_INFO(
1135+
"[MTP BATCHING] Processando MTP KV update para um ubatch de %u tokens.\n",
1136+
ubatch.n_tokens
1137+
);
1138+
auto res_mtp = std::make_unique<llm_graph_result>(graph_max_nodes());
1139+
1140+
auto params_mtp = mtp_graph_params(res_mtp.get(), ubatch, mctx.get());
1141+
ggml_backend_sched_t sched_mtp = params_mtp.sched;
1142+
1143+
auto * gf_mtp = model.build_mtp_graph(params_mtp);
1144+
if (gf_mtp) {
1145+
ggml_backend_sched_alloc_graph(sched_mtp, gf_mtp);
1146+
1147+
ggml_tensor* prev_embedding_tensor = res->get_embd();
1148+
ggml_tensor* embd_input_mtp = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embeddings_batch_input");
1149+
1150+
// ggml_backend_tensor_set(embd_input_mtp, prev_embedding_tensor->data, 0, ggml_nbytes(prev_embedding_tensor));
1151+
ggml_backend_tensor_copy(prev_embedding_tensor, embd_input_mtp);
1152+
1153+
ggml_backend_sched_graph_compute(sched_mtp, gf_mtp);
1154+
1155+
if (ubatch.output[0]) {
1156+
struct ggml_tensor * logits_mtp = res_mtp->get_logits();
1157+
if (logits_mtp) {
1158+
float * logits_dest = logits + n_outputs_prev * n_vocab;
1159+
ggml_backend_tensor_get(logits_mtp, logits_dest, 0, ggml_nbytes(logits_mtp));
1160+
}
1161+
}
1162+
}
1163+
ggml_backend_sched_free(sched_mtp);
1164+
}
1165+
11321166
auto * t_logits = res->get_logits();
11331167
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
11341168
embd_tensor = res->get_embd();
@@ -2995,79 +3029,79 @@ void llama_opt_epoch(
29953029
callback_eval);
29963030
}
29973031

2998-
void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
2999-
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
3032+
// void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
3033+
// const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
30003034

3001-
const auto * model = llama_get_model(ctx);
3035+
// const auto * model = llama_get_model(ctx);
30023036

3003-
auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
3004-
std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);
3037+
// auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
3038+
// std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);
30053039

3006-
std::vector<uint32_t> idxs;
3007-
idxs.push_back(n_past);
3008-
llama_kv_cache_unified::slot_info sinfo = {
3009-
/*.s0 =*/ 0,
3010-
/*.s1 =*/ 0,
3011-
/*.strm =*/ { 0 },
3012-
/*.idxs =*/ { idxs },
3013-
};
3014-
llama_kv_cache_unified::slot_info_vec_t sinfos;
3015-
sinfos.push_back(sinfo);
3040+
// std::vector<uint32_t> idxs;
3041+
// idxs.push_back(n_past);
3042+
// llama_kv_cache_unified::slot_info sinfo = {
3043+
// /*.s0 =*/ 0,
3044+
// /*.s1 =*/ 0,
3045+
// /*.strm =*/ { 0 },
3046+
// /*.idxs =*/ { idxs },
3047+
// };
3048+
// llama_kv_cache_unified::slot_info_vec_t sinfos;
3049+
// sinfos.push_back(sinfo);
30163050

3017-
static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_sinfos(sinfos);
3018-
const auto& ubatch_mtp = mctx->get_ubatch();
3051+
// static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_sinfos(sinfos);
3052+
// const auto& ubatch_mtp = mctx->get_ubatch();
30193053

3020-
//llama_ubatch ubatch_mtp;
3021-
//ubatch_mtp.n_tokens = 1;
3022-
//ubatch_mtp.pos = &n_past;
3054+
// //llama_ubatch ubatch_mtp;
3055+
// //ubatch_mtp.n_tokens = 1;
3056+
// //ubatch_mtp.pos = &n_past;
30233057

3024-
auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
3025-
ggml_backend_sched_t sched = params_mtp->sched;
3058+
// auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
3059+
// ggml_backend_sched_t sched = params_mtp->sched;
30263060

3027-
auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
3061+
// auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
30283062

3029-
//if (mctx && !mctx->set_n_kv()) {
3030-
// LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
3031-
//}
3032-
static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_n_kv();
3063+
// //if (mctx && !mctx->set_n_kv()) {
3064+
// // LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
3065+
// //}
3066+
// static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_n_kv();
30333067

3034-
auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
3068+
// auto * gf = model->build_mtp_graph(*params_mtp);
30353069

3036-
if (!gf) {
3037-
LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
3038-
if (sched) ggml_backend_sched_free(sched);
3039-
return;
3040-
}
3070+
// if (!gf) {
3071+
// LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
3072+
// if (sched) ggml_backend_sched_free(sched);
3073+
// return;
3074+
// }
30413075

3042-
ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
3043-
ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
3076+
// ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
3077+
// ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
30443078

3045-
ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
3046-
ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
3079+
// ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
3080+
// ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
30473081

3048-
ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
3049-
ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
3082+
// ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
3083+
// ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
30503084

3051-
ggml_backend_sched_graph_compute(sched, gf); // execute the graph
3085+
// ggml_backend_sched_graph_compute(sched, gf); // execute the graph
30523086

3053-
struct ggml_tensor * logits_mtp = res_mtp->get_logits();
3087+
// struct ggml_tensor * logits_mtp = res_mtp->get_logits();
30543088

3055-
if (logits_mtp) {
3056-
float * logits_dest = ctx->get_logits_ith(last_tok_idx);
3057-
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
3058-
if (backend_res) {
3059-
// ggml_backend_tensor_get is the function for GPU->CPU copies.
3060-
// We are copying a single 32-bit integer.
3061-
ggml_backend_tensor_get(logits_mtp,
3062-
logits_dest, // Pointer to our C++ variable
3063-
0, // Starting offset in bytes
3064-
ggml_nbytes(logits_mtp)); // Number of bytes to copy
3065-
} else {
3066-
LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__);
3067-
}
3068-
} else {
3069-
LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
3070-
}
3089+
// if (logits_mtp) {
3090+
// float * logits_dest = ctx->get_logits_ith(last_tok_idx);
3091+
// ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
3092+
// if (backend_res) {
3093+
// // ggml_backend_tensor_get is the function for GPU->CPU copies.
3094+
// // We are copying a single 32-bit integer.
3095+
// ggml_backend_tensor_get(logits_mtp,
3096+
// logits_dest, // Pointer to our C++ variable
3097+
// 0, // Starting offset in bytes
3098+
// ggml_nbytes(logits_mtp)); // Number of bytes to copy
3099+
// } else {
3100+
// LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__);
3101+
// }
3102+
// } else {
3103+
// LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
3104+
// }
30713105

3072-
ggml_backend_sched_free(sched);
3073-
}
3106+
// ggml_backend_sched_free(sched);
3107+
// }

src/llama-graph.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,26 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
10741074
return cur;
10751075
}
10761076

1077+
1078+
ggml_tensor * llm_graph_context::build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const {
1079+
auto inp = std::make_unique<llm_graph_input_embd>();
1080+
ggml_tensor * cur = nullptr;
1081+
1082+
if (ubatch.token) {
1083+
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
1084+
ggml_set_name(inp->tokens, "mtp_inp_tokens");
1085+
ggml_set_input(inp->tokens);
1086+
1087+
cur = ggml_get_rows(ctx0, mtp_tok_embd, inp->tokens);
1088+
} else {
1089+
GGML_ABORT("fatal error: MTP update expects token IDs, not embeddings");
1090+
}
1091+
1092+
cb(cur, "mtp_inp_embd", -1);
1093+
res->add_input(std::move(inp));
1094+
return cur;
1095+
}
1096+
10771097
ggml_tensor * llm_graph_context::build_inp_pos() const {
10781098
auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
10791099

src/llama-graph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,7 @@ struct llm_graph_context {
664664
//
665665

666666
ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
667+
ggml_tensor * build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const;
667668
ggml_tensor * build_inp_pos() const;
668669
ggml_tensor * build_inp_attn_scale() const;
669670
ggml_tensor * build_inp_out_ids() const;

0 commit comments

Comments
 (0)