@@ -1070,6 +1070,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
10701070 };
10711071
10721072 int64_t n_outputs_prev = 0 ;
1073+ const bool do_mtp_kv_update = batch_inp.update_mtp_kv ;
10731074
10741075 do {
10751076 const auto & ubatch = mctx->get_ubatch ();
@@ -1129,6 +1130,39 @@ int llama_context::decode(const llama_batch & batch_inp) {
11291130 // ggml_graph_dump_dot(gf, NULL, "llama.dot");
11301131 // }
11311132
1133+ if (do_mtp_kv_update) {
1134+ LLAMA_LOG_INFO (
1135+ " [MTP BATCHING] Processando MTP KV update para um ubatch de %u tokens.\n " ,
1136+ ubatch.n_tokens
1137+ );
1138+ auto res_mtp = std::make_unique<llm_graph_result>(graph_max_nodes ());
1139+
1140+ auto params_mtp = mtp_graph_params (res_mtp.get (), ubatch, mctx.get ());
1141+ ggml_backend_sched_t sched_mtp = params_mtp.sched ;
1142+
1143+ auto * gf_mtp = model.build_mtp_graph (params_mtp);
1144+ if (gf_mtp) {
1145+ ggml_backend_sched_alloc_graph (sched_mtp, gf_mtp);
1146+
1147+ ggml_tensor* prev_embedding_tensor = res->get_embd ();
1148+ ggml_tensor* embd_input_mtp = ggml_get_tensor (res_mtp->get_ctx (), " mtp_prev_embeddings_batch_input" );
1149+
1150+ // ggml_backend_tensor_set(embd_input_mtp, prev_embedding_tensor->data, 0, ggml_nbytes(prev_embedding_tensor));
1151+ ggml_backend_tensor_copy (prev_embedding_tensor, embd_input_mtp);
1152+
1153+ ggml_backend_sched_graph_compute (sched_mtp, gf_mtp);
1154+
1155+ if (ubatch.output [0 ]) {
1156+ struct ggml_tensor * logits_mtp = res_mtp->get_logits ();
1157+ if (logits_mtp) {
1158+ float * logits_dest = logits + n_outputs_prev * n_vocab;
1159+ ggml_backend_tensor_get (logits_mtp, logits_dest, 0 , ggml_nbytes (logits_mtp));
1160+ }
1161+ }
1162+ }
1163+ ggml_backend_sched_free (sched_mtp);
1164+ }
1165+
11321166 auto * t_logits = res->get_logits ();
11331167 auto * t_embd = cparams.embeddings ? res->get_embd () : nullptr ;
11341168 embd_tensor = res->get_embd ();
@@ -2995,79 +3029,79 @@ void llama_opt_epoch(
29953029 callback_eval);
29963030}
29973031
2998- void llama_build_and_execute_mtp_graph (struct llama_context * ctx,
2999- const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
3032+ // void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
3033+ // const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
30003034
3001- const auto * model = llama_get_model (ctx);
3035+ // const auto * model = llama_get_model(ctx);
30023036
3003- auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes ());
3004- std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch (batch_inp);
3037+ // auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
3038+ // std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);
30053039
3006- std::vector<uint32_t > idxs;
3007- idxs.push_back (n_past);
3008- llama_kv_cache_unified::slot_info sinfo = {
3009- /* .s0 =*/ 0 ,
3010- /* .s1 =*/ 0 ,
3011- /* .strm =*/ { 0 },
3012- /* .idxs =*/ { idxs },
3013- };
3014- llama_kv_cache_unified::slot_info_vec_t sinfos;
3015- sinfos.push_back (sinfo);
3040+ // std::vector<uint32_t> idxs;
3041+ // idxs.push_back(n_past);
3042+ // llama_kv_cache_unified::slot_info sinfo = {
3043+ // /*.s0 =*/ 0,
3044+ // /*.s1 =*/ 0,
3045+ // /*.strm =*/ { 0 },
3046+ // /*.idxs =*/ { idxs },
3047+ // };
3048+ // llama_kv_cache_unified::slot_info_vec_t sinfos;
3049+ // sinfos.push_back(sinfo);
30163050
3017- static_cast <llama_kv_cache_unified_context*>(mctx.get ())->set_sinfos (sinfos);
3018- const auto & ubatch_mtp = mctx->get_ubatch ();
3051+ // static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_sinfos(sinfos);
3052+ // const auto& ubatch_mtp = mctx->get_ubatch();
30193053
3020- // llama_ubatch ubatch_mtp;
3021- // ubatch_mtp.n_tokens = 1;
3022- // ubatch_mtp.pos = &n_past;
3054+ // //llama_ubatch ubatch_mtp;
3055+ // //ubatch_mtp.n_tokens = 1;
3056+ // //ubatch_mtp.pos = &n_past;
30233057
3024- auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params (res_mtp.get (), ubatch_mtp, mctx.get ()));
3025- ggml_backend_sched_t sched = params_mtp->sched ;
3058+ // auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
3059+ // ggml_backend_sched_t sched = params_mtp->sched;
30263060
3027- auto * last_embd = ctx->get_embeddings_ith (last_tok_idx);
3061+ // auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
30283062
3029- // if (mctx && !mctx->set_n_kv()) {
3030- // LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
3031- // }
3032- static_cast <llama_kv_cache_unified_context*>(mctx.get ())->set_n_kv ();
3063+ // //if (mctx && !mctx->set_n_kv()) {
3064+ // // LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
3065+ // //}
3066+ // static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_n_kv();
30333067
3034- auto * gf = model->build_mtp_graph (*params_mtp, last_token_id, n_past );
3068+ // auto * gf = model->build_mtp_graph(*params_mtp);
30353069
3036- if (!gf) {
3037- LLAMA_LOG_ERROR (" %s: ERROR - The construction of the MTP graph failed (returned null)." , __func__);
3038- if (sched) ggml_backend_sched_free (sched);
3039- return ;
3040- }
3070+ // if (!gf) {
3071+ // LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
3072+ // if (sched) ggml_backend_sched_free(sched);
3073+ // return;
3074+ // }
30413075
3042- ggml_backend_sched_reset (sched); // clear the allocation of the previous graph
3043- ggml_backend_sched_alloc_graph (sched, gf); // explicitly allocate the new graph but do not execute it
3076+ // ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
3077+ // ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
30443078
3045- ggml_tensor * mtp_token_id_input = ggml_get_tensor (res_mtp->get_ctx (), " mtp_token_id_input" );
3046- ggml_backend_tensor_set (mtp_token_id_input, &last_token_id, 0 , sizeof (last_token_id)); // copy data to the newly allocated graph tensors
3079+ // ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
3080+ // ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
30473081
3048- ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor (res_mtp->get_ctx (), " mtp_prev_embedding_input" );
3049- ggml_backend_tensor_set (mtp_prev_embedding_input, last_embd, 0 , ggml_nbytes (mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
3082+ // ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
3083+ // ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
30503084
3051- ggml_backend_sched_graph_compute (sched, gf); // execute the graph
3085+ // ggml_backend_sched_graph_compute(sched, gf); // execute the graph
30523086
3053- struct ggml_tensor * logits_mtp = res_mtp->get_logits ();
3087+ // struct ggml_tensor * logits_mtp = res_mtp->get_logits();
30543088
3055- if (logits_mtp) {
3056- float * logits_dest = ctx->get_logits_ith (last_tok_idx);
3057- ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend (sched, logits_mtp);
3058- if (backend_res) {
3059- // ggml_backend_tensor_get is the function for GPU->CPU copies.
3060- // We are copying a single 32-bit integer.
3061- ggml_backend_tensor_get (logits_mtp,
3062- logits_dest, // Pointer to our C++ variable
3063- 0 , // Starting offset in bytes
3064- ggml_nbytes (logits_mtp)); // Number of bytes to copy
3065- } else {
3066- LLAMA_LOG_ERROR (" %s: ERROR - Could not obtain the backend for the logits tensor." , __func__);
3067- }
3068- } else {
3069- LLAMA_LOG_WARN (" %s: WARNING - The MTP graph did not produce a logit tensor." , __func__);
3070- }
3089+ // if (logits_mtp) {
3090+ // float * logits_dest = ctx->get_logits_ith(last_tok_idx);
3091+ // ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
3092+ // if (backend_res) {
3093+ // // ggml_backend_tensor_get is the function for GPU->CPU copies.
3094+ // // We are copying a single 32-bit integer.
3095+ // ggml_backend_tensor_get(logits_mtp,
3096+ // logits_dest, // Pointer to our C++ variable
3097+ // 0, // Starting offset in bytes
3098+ // ggml_nbytes(logits_mtp)); // Number of bytes to copy
3099+ // } else {
3100+ // LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__);
3101+ // }
3102+ // } else {
3103+ // LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
3104+ // }
30713105
3072- ggml_backend_sched_free (sched);
3073- }
3106+ // ggml_backend_sched_free(sched);
3107+ // }
0 commit comments