
Commit fe2baf5

SamuelOliveiradsF1LM1 authored and committed
Squashed commit of the following:
commit 912ed2cd9339d1b2875d98744ca5b51fa62e581e (samuel <[email protected]>, Sun Dec 7 23:00:29 2025 -0300)
    speculative (feat): implement recursive MTP drafting for GLM-4.5
commit bdf72d9552e3da64ffc85f175664713388752914 (samuel <[email protected]>, Sat Dec 6 16:10:16 2025 -0300)
    sampling (feat): optimize speculative drafting with fast-path selection
commit a91980a8f3475a6bbac0a64d8be06dd4b613020e (samuel <[email protected]>, Sat Dec 6 15:18:19 2025 -0300)
    mtp (chore): clean old code
commit 6de0ecf55db8567db4faa99b0152b72c9e854548 (samuel <[email protected]>, Sat Dec 6 14:40:13 2025 -0300)
    mtp (feat): add mtp arg
commit ea77394183b8e6c368af969b8274039a54b11486 (samuel <[email protected]>, Sat Dec 6 13:47:54 2025 -0300)
    mtp-graph (fix): move llama_get_logits_ith outside the loop
commit 15dff208958fb66802f20ec53ce5fcaff133edb7 (merge of 171346c74 and cae85fe53; samuel <[email protected]>, Thu Oct 16 13:44:41 2025 -0300)
    Merge branch 'glm4-mtp-batch' of https://github.com/SamuelOliveirads/llama.cpp into glm4-mtp-graph-cache
commit cae85fe531876762ee02524fc4c3f6c5e7824c63 (samuel <[email protected]>, Thu Oct 16 13:42:31 2025 -0300)
    mtp-batch (fix): avoid logits for mtp kv cache operations
commit 171346c742c310bbcfbd786b61250638ccf8b44d (samuel <[email protected]>, Sun Oct 12 16:33:01 2025 -0300)
    mtp-graph (feat): Reactivate graph reuse only for main model path
commit 0127c6beeb384ec3abbc18b22dbe830f22fcf4b4 (samuel <[email protected]>, Sat Oct 11 22:20:54 2025 -0300)
    mtp-batch (chore): Remove final MTP debug logs and dead code
commit 4bcc9e261ef57ee4cfaa65d06bcd0fcdeacf7797 (samuel <[email protected]>, Sat Oct 11 18:51:22 2025 -0300)
    mtp-batch (fix): Correctly advance cache head and add MTP documentation
commit b4cbe030ac25056717763b812d1dd89681c08522 (samuel <[email protected]>, Sat Oct 11 18:37:40 2025 -0300)
    mtp-batch (chore): Fix logit flags for speculative sampling and remove debug logs
commit a99709d0c1401d0b447dce1bd0101fb56390f50e (samuel <[email protected]>, Fri Oct 10 17:24:34 2025 -0300)
    mtp-batch (refactor): Extract decode context and MTP input logic into helper methods
commit 913af8f48d2dab1d9e907cf6c48c921a229a295c (samuel <[email protected]>, Fri Oct 10 16:44:28 2025 -0300)
    mtp-batch (refactor): Replace MTP boolean flags with an explicit operation enum
commit 6f74ba38070d62d37bc0fb71ce9871e1a4ffabcc (samuel <[email protected]>, Thu Oct 9 22:27:18 2025 -0300)
    mtp-batch (fix): prevent mtp draft from polluting the cache
commit 5e1d719beffccf8c22784c24b52ff6f5ab56b9ff (samuel <[email protected]>, Thu Oct 9 15:21:23 2025 -0300)
    mtp-batch (feat): Create and manage sinfo for MTP
commit febd8235d27fe9174ee4b54ea7a10e630939fee0 (samuel <[email protected]>, Sun Oct 5 14:43:40 2025 -0300)
    mtp-batch (wip): fix how to warmup kv cache for MTP
commit 67c6c069e0a5496adfd7d8aa6ca7514db5a6f437 (samuel <[email protected]>, Sat Sep 27 19:42:32 2025 -0300)
    mtp-batch (wip): Isolate MTP graph to prevent host embedding buffer corruption
commit 75dc25e6fe781c1b65038d69390fb778d760e3a1 (samuel <[email protected]>, Sat Sep 27 17:17:00 2025 -0300)
    mtp-batch (wip): organize batch for mtp cache
commit 3da7e7f3309dbb576538850c92c1cbf8fdc6d6ee (samuel <[email protected]>, Tue Sep 23 22:45:11 2025 -0300)
    mtp-batch (fix): warm mtp cache for small batch size
commit df64508b937784112168aa099644b60fef015f05 (samuel <[email protected]>, Sun Sep 21 21:55:41 2025 -0300)
    mtp-batch (wip): merge glm graphs
commit 042eb8a829876ed175320df9c8133bcea0c40460 (samuel <[email protected]>, Sun Sep 21 21:29:00 2025 -0300)
    mtp-batch (wip): merge mtp and model graph
commit 1318b2de82716710b9853e07bd640443a5a025bb (samuel <[email protected]>, Sun Sep 14 10:22:59 2025 -0300)
    mtp-batch (wip): move mtp execution to batch format
commit c6237c71ffd4485df1c35829c380b63e472fc5dd (merge of 9fab53e43 and 8742ce0e3; Aaron Lee <[email protected]>, Sat Sep 13 02:57:01 2025 -0400)
    Merge pull request #1 from SamuelOliveirads/glm4-moe-mtp (feat: implemented sampling for MTP)
commit 8742ce0e39823eeb101bb5b6099ff4ca7be10c6e (samuel <[email protected]>, Sat Sep 6 00:21:18 2025 -0300)
    feat: apply logits + greedy sampler
commit 5a5bce85777041d841393b4396e28f8e3065bb10 (samuel <[email protected]>, Wed Sep 3 17:56:14 2025 -0300)
    fix: add sample acceptance
commit 07670a22c63b1fa335d6ec1c4a1e4255a920848c (samuel <[email protected]>, Wed Sep 3 13:25:21 2025 -0300)
    feat: implemented sampling for MTP
commit 9fab53e4388c20aef497efd82e86dcb99ca58064 (Aaron Lee <[email protected]>, Tue Sep 2 17:14:09 2025 -0400)
    fixed mtp kv cache update step in cases where prompt size > n_batch and n_ubatch
commit 98bc0c6bf223f425f4ecea14f13fc46101f1b44a (Aaron Lee <[email protected]>, Tue Aug 26 01:26:51 2025 -0400)
    replace standard sampler with greedy sampler for mtp draft
commit 471e026327cca9f6f58aeefe32129a6cb9390f4f (Aaron Lee <[email protected]>, Tue Aug 19 23:10:56 2025 -0400)
    fixed vram leak
commit d72f9d5691054958cd1b139f228e5e588d3974cf (Aaron Lee <[email protected]>, Tue Aug 19 01:50:34 2025 -0400)
    kludge-y kv cache management of mtp layer
commit 382135aa3619294ab8bf87b0de4b1255ab7942f0 (Aaron Lee <[email protected]>, Sun Aug 17 21:54:45 2025 -0400)
    fixed mtp kv cache update sequencing after prompt processing
commit 6870f9790c1bb1d0254241267b1a6c8a7fc82830 (Aaron Lee <[email protected]>, Sun Aug 17 04:59:36 2025 -0400)
    added proper KV cache management for MTP layers and slightly refactored
commit 6e9bafc7a738b4c99f9440c0ec461e08cf6ce702 (Aaron Lee <[email protected]>, Fri Aug 15 23:13:56 2025 -0400)
    failed attempt to implement MTP; outputs tokens but KV cache management is unreasonable
commit cf0f7c0448c2c1736588673114558e5829db7879 (Aaron Lee <[email protected]>, Wed Aug 13 02:21:17 2025 -0400)
    broad thrust of the mtp implementation
commit 03231da69eec20677e25e2307d4fe31ac2ede034 (Aaron Lee <[email protected]>, Tue Aug 12 01:03:59 2025 -0400)
    add model member function to build mtp graph, to be called from speculative.cpp
commit 1f477b375504aa557ed21066aa6783b11781a179 (Aaron Lee <[email protected]>, Mon Aug 11 20:54:45 2025 -0400)
    make nextn weights loadable without a crash
commit e434f87cc739a1901931d88e33f777170a4e18e7 (Aaron Lee <[email protected]>, Mon Aug 11 01:21:47 2025 -0400)
    some work towards building mtp layer graph
commit db60623e7926fb151b3cc63f029929122cac342a (Aaron Lee <[email protected]>, Sun Aug 10 23:52:54 2025 -0400)
    added getter for nextn layer count and server slot has_mtp property
1 parent e1f15b4 commit fe2baf5

18 files changed: +1037 additions, -280 deletions

common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -3214,6 +3214,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.cache_type_k = kv_cache_type_from_str(value);
         }
     ).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT"));
+    add_opt(common_arg(
+        {"-mtp", "--multi-token-prediction"},
+        string_format("Activate multi-token-prediction (if supported) (default: %s)", params.mtp ? "true" : "false"),
+        [](common_params & params) {
+            params.mtp = true;
+        }
+    ));
     add_opt(common_arg(
         {"-ctvd", "--cache-type-v-draft"}, "TYPE",
         string_format(

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -430,6 +430,7 @@ struct common_params {
     bool no_op_offload   = false; // globally disable offload host tensor operations to device
     bool no_extra_bufts  = false; // disable extra buffer types (used for weight repacking)
     bool no_host         = false; // bypass host buffer allowing extra buffers to be used
+    bool mtp             = false; // use mtp if supported

     bool single_turn     = false; // single turn chat conversation

common/sampling.cpp

Lines changed: 39 additions & 0 deletions
@@ -666,3 +666,42 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri

     return samplers;
 }
+
+/**
+ * Specialized sampling for speculative drafting.
+ *
+ * Prioritizes performance by using a direct ArgMax loop (greedy) when no
+ * penalties (repetition, frequency, presence, DRY) are configured.
+ * Falls back to the full sampler chain if penalties are active, to prevent
+ * generative loops and adhere to constraints.
+ */
+llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
+    const auto & params = gsmpl->params;
+
+    bool use_heavy_sampler =
+        (params.penalty_last_n > 0 && (
+            params.penalty_repeat  != 1.0f ||
+            params.penalty_freq    != 0.0f ||
+            params.penalty_present != 0.0f
+        )) ||
+        (params.dry_allowed_length > 0 && params.dry_multiplier != 0.0f);
+
+    if (use_heavy_sampler) {
+        return common_sampler_sample(gsmpl, ctx, idx, false);
+    }
+
+    float * logits = llama_get_logits_ith(ctx, idx);
+    const int n_vocab = llama_n_vocab(llama_model_get_vocab(llama_get_model(ctx)));
+
+    int best_id = 0;
+    float max_val = logits[0];
+
+    for (int i = 1; i < n_vocab; ++i) {
+        if (logits[i] > max_val) {
+            max_val = logits[i];
+            best_id = i;
+        }
+    }
+
+    return best_id;
+}

common/speculative.cpp

Lines changed: 113 additions & 0 deletions
@@ -359,3 +359,116 @@ llama_tokens common_speculative_gen_draft(
     }
     return result;
 }
+
+llama_tokens mtp_speculative_gen_draft(
+        struct common_sampler * smpl,
+        struct llama_context * ctx,
+        struct common_speculative_params params,
+        llama_token id_last,
+        int32_t n_past,
+        llama_seq_id seq_id) {
+
+    int n_draft = params.n_draft;
+
+    llama_tokens drafts;
+    drafts.reserve(n_draft);
+
+    if (!smpl) return drafts;
+
+    llama_batch mtp_batch = llama_batch_init(1, 0, 1);
+    mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;
+
+    llama_token current_input_id = id_last;
+    int32_t current_n_past = n_past;
+
+    for (int i = 0; i < n_draft; ++i) {
+        mtp_batch.n_tokens = 0;
+        common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
+
+        // Perform the MTP draft generation decode. This writes the MTP layer's
+        // KV state for the draft token into the cache.
+        if (llama_decode(ctx, mtp_batch) != 0) {
+            break;
+        }
+
+        llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0);
+
+        // Drafting stops if token probability drops below `p_min` to save compute.
+        const auto * cur_p = common_sampler_get_candidates(smpl, true);
+        if (cur_p && cur_p->size > 0) {
+            float prob = cur_p->data[0].p;
+
+            if (prob < params.p_min) {
+                drafts.push_back(id_next);
+                current_n_past++;
+                break;
+            }
+        }
+
+        drafts.push_back(id_next);
+
+        current_input_id = id_next;
+        current_n_past++;
+    }
+    llama_batch_free(mtp_batch);
+
+    // CRITICAL: Purge the metadata for the draft tokens we just wrote.
+    // This makes the physical cells available again for the main model's validation pass,
+    // preventing a cache state corruption where two cells map to the same logical position.
+    if (!drafts.empty()) {
+        llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past);
+    }
+
+    return drafts;
+}
+
+void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch & batch, bool is_prompt_warmup) {
+    if (batch.n_tokens == 0) {
+        return;
+    }
+
+    LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
+
+    llama_batch mtp_batch = batch;
+    if (is_prompt_warmup) {
+        mtp_batch.mtp_params.op_type = MTP_OP_WARMUP;
+    } else {
+        mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
+    }
+
+    for (int i = 0; i < mtp_batch.n_tokens; ++i) {
+        mtp_batch.logits[i] = true;
+    }
+    llama_decode(ctx, mtp_batch);
+}
+
+void mtp_accept_tokens(
+        struct llama_context * ctx,
+        const std::vector<llama_token> & ids,
+        int32_t n_past_base,
+        llama_seq_id seq_id
+) {
+    if (ids.empty()) {
+        return;
+    }
+
+    // Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
+    // This sets up the context for a "forced sinfo" decode.
+    if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
+        return;
+    }
+
+    // Build a new batch containing only the accepted tokens.
+    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
+    for (size_t i = 0; i < ids.size(); ++i) {
+        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
+    }
+
+    mtp_update_kv_cache(ctx, accepted_batch, false);
+
+    // Clean up the forced state so it does not affect subsequent, normal decode calls.
+    llama_mtp_cancel_sinfo_update(ctx);
+
+    llama_batch_free(accepted_batch);
+}

common/speculative.h

Lines changed: 43 additions & 4 deletions
@@ -12,6 +12,12 @@ struct common_speculative_params {
     float p_min = 0.75f; // min probability required to accept a token in the draft
 };

+struct mtp_kv_update_data {
+    llama_token id;
+    int32_t     n_past;
+    int32_t     tok_idx;
+};
+
 struct common_speculative * common_speculative_init(
         struct llama_context * ctx_tgt,
         struct llama_context * ctx_dft
@@ -29,7 +35,40 @@ void common_speculative_add_replacement_tgt_dft(

 // sample up to n_draft tokens and add them to the batch using the draft model
 llama_tokens common_speculative_gen_draft(
-    struct common_speculative * spec,
-    struct common_speculative_params params,
-    const llama_tokens & prompt,
-    llama_token id_last);
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt,
+        llama_token id_last);
+
+/**
+ * @brief Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture.
+ *
+ * This function performs a recursive generation loop using the MTP head (e.g., Eagle/NextN).
+ * It uses the fixed hidden state from the main model's last step and updates the MTP layer's
+ * internal KV cache autoregressively.
+ *
+ * @param smpl    The sampler instance.
+ * @param ctx     The llama context (shared between the main model and MTP).
+ * @param params  Speculative parameters (n_draft, p_min).
+ * @param id_last The last confirmed token ID from the main model.
+ * @param n_past  The number of tokens in the validated past (start position for drafting).
+ * @param seq_id  The sequence ID to use for drafting.
+ *
+ * @return std::vector<llama_token> The generated draft tokens.
+ */
+llama_tokens mtp_speculative_gen_draft(
+        struct common_sampler * smpl,
+        struct llama_context * ctx,
+        struct common_speculative_params params,
+        llama_token id_last,
+        int32_t n_past,
+        llama_seq_id seq_id);
+
+void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch & batch, bool is_prompt_warmup);
+
+void mtp_accept_tokens(
+        struct llama_context * ctx,
+        const std::vector<llama_token> & ids,
+        int32_t n_past_base,
+        llama_seq_id seq_id);

include/llama.h

Lines changed: 46 additions & 0 deletions
@@ -228,6 +228,17 @@ extern "C" {
     // - if not: only the last token is output
     // )
     //
+    typedef enum {
+        MTP_OP_NONE,
+        MTP_OP_WARMUP,
+        MTP_OP_UPDATE_ACCEPTED,
+        MTP_OP_DRAFT_GEN,
+    } llama_mtp_op_type;
+
+    typedef struct llama_mtp_params {
+        llama_mtp_op_type op_type;
+    } llama_mtp_params;
+
     typedef struct llama_batch {
         int32_t n_tokens;

@@ -237,6 +248,7 @@ extern "C" {
         int32_t      * n_seq_id;
         llama_seq_id ** seq_id;
         int8_t       * logits; // TODO: rename this to "output"
+        llama_mtp_params mtp_params;
     } llama_batch;

     enum llama_model_kv_override_type {
@@ -536,6 +548,8 @@ extern "C" {

     LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);

+    LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model);
+
     // Functions to access the model's GGUF metadata scalar values
     // - The functions return the length of the string on success, or -1 on failure
     // - The output string is always null-terminated and cleared on failure
@@ -1442,6 +1456,38 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);

+    //
+    // MTP
+    //
+
+    LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
+
+    /**
+     * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
+     * This is used after speculative validation when only a subset of draft tokens are accepted.
+     * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
+     * @return true on success.
+     */
+    LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
+
+    /**
+     * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode.
+     * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned.
+     * @return true on success.
+     */
+    LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
+
+    /**
+     * @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo.
+     */
+    LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
+
+    /**
+     * @brief Removes KV cache metadata for a specified sequence and token range.
+     * This makes the physical cells logically available again without deleting the tensor data.
+     */
+    LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
+
 #ifdef __cplusplus
 }
 #endif

src/llama-arch.cpp

Lines changed: 7 additions & 6 deletions
@@ -2370,12 +2370,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_VISEXP_FFN_UP,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     // NextN/MTP tensors are currently ignored (reserved for future MTP support)
     // These tensors only exist in the last layer(s) and are treated as output tensors
-    {LLM_TENSOR_NEXTN_EH_PROJ,          {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_NEXTN_EMBED_TOKENS,     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_NEXTN_ENORM,            {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_NEXTN_HNORM,            {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
-    {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
+    // Changed to LLM_TENSOR_LAYER_REPEATING because we saved these under a blk with a non-negative id
+    {LLM_TENSOR_NEXTN_EH_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_NEXTN_EMBED_TOKENS,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_NEXTN_ENORM,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_NEXTN_HNORM,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };

 LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

src/llama-batch.cpp

Lines changed: 18 additions & 17 deletions
@@ -301,17 +301,17 @@ bool llama_batch_allocr::init(
             ok = false;
         }

-        if (!ok) {
-            LLAMA_LOG_ERROR(
-                "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
-                " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
-                " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
-                " it is required that the sequence positions remain consecutive: Y = X + 1\n",
-                __func__, s, s, p0, s, seq_pos_min(s));
+        // if (!ok) {
+        //     LLAMA_LOG_ERROR(
+        //         "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+        //         " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+        //         " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+        //         " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+        //         __func__, s, s, p0, s, seq_pos_min(s));

-            return false;
-        }
-    }
+        //     return false;
+        // }
+        }

         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
             LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
@@ -874,13 +874,14 @@ struct llama_batch llama_batch_get_one(

 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = {
-        /*n_tokens   =*/ 0,
-        /*tokens     =*/ nullptr,
-        /*embd       =*/ nullptr,
-        /*pos        =*/ nullptr,
-        /*n_seq_id   =*/ nullptr,
-        /*seq_id     =*/ nullptr,
-        /*logits     =*/ nullptr,
+        /*n_tokens   =*/ 0,
+        /*tokens     =*/ nullptr,
+        /*embd       =*/ nullptr,
+        /*pos        =*/ nullptr,
+        /*n_seq_id   =*/ nullptr,
+        /*seq_id     =*/ nullptr,
+        /*logits     =*/ nullptr,
+        /*mtp_params =*/ { MTP_OP_NONE },
     };

     if (embd) {