Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
ad61684
wip: port MTP architecture
SamuelOliveirads Dec 28, 2025
b75f70e
Refactors `server_slot` to support generic speculative decoding (MTP …
SamuelOliveirads Dec 29, 2025
f9c4f6c
core: enable hybrid outputs (logits + embeddings) for MTP support
SamuelOliveirads Dec 29, 2025
b61daeb
fix(mtp): correct KV-cache slot finding for updates
SamuelOliveirads Jan 1, 2026
c03ae51
fix(mtp): persist hidden states to prevent context corruption during …
SamuelOliveirads Jan 2, 2026
ab6f4bb
refactor(mtp): clean unused code
SamuelOliveirads Feb 5, 2026
ec2d1a0
fix(mtp): update server to new functions name
SamuelOliveirads Feb 7, 2026
9317463
fix(mtp): fix graph and save hidden state
SamuelOliveirads Feb 8, 2026
d3465f1
mtp: refactor integration, context params and kv cache search
SamuelOliveirads Feb 9, 2026
2539f4f
mtp: fix hidden state extraction and speculative acceptance flow
SamuelOliveirads Feb 9, 2026
07e4936
server: fix MTP warmup for long prompts and reset token buffer
SamuelOliveirads Feb 12, 2026
d088faa
llama: refactor MTP operation state to context parameters
SamuelOliveirads Feb 13, 2026
97ec50e
server: fix n_past calculation in MTP acceptance
SamuelOliveirads Feb 13, 2026
573170e
llama: fix mtp enable flags
SamuelOliveirads Feb 13, 2026
5260bf2
Merge branch 'main' into feat-glm-mtp
SamuelOliveirads Feb 20, 2026
b4a2c88
speculative: refactor MTP to use common_speculative interface
SamuelOliveirads Feb 20, 2026
b8f27f3
context: remove unused signatures
SamuelOliveirads Feb 20, 2026
dd684fb
clip: fix deprecated enum-enum conversion warning
SamuelOliveirads Feb 20, 2026
0bcee4e
common: fix format string crash in help message
SamuelOliveirads Feb 20, 2026
1d5b287
context: fix mtp activation logic
SamuelOliveirads Feb 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1463,6 +1463,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.cuda_params = argv[i];
return true;
}
if (arg == "-mtp" || arg == "--multi-token-prediction") {
params.has_mtp = true;
return true;
}
if (arg == "-no-mtp" || arg == "--no-multi-token-prediction") {
params.has_mtp = false;
return true;
}
if (arg == "-draft" || arg == "--draft-params") {
CHECK_ARG
params.speculative.params = argv[i];
Expand Down Expand Up @@ -2475,6 +2483,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
options.push_back({ "*", "-mtp, --multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", params.has_mtp ? "true" : "false" });
options.push_back({ "*", "-no-mtp, --no-multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", !params.has_mtp ? "true" : "false" });
options.push_back({ "*", "--draft-max, --draft, --draft-n N",
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
Expand Down Expand Up @@ -3207,6 +3217,7 @@ struct llama_model_params common_model_params_to_llama(const gpt_params & params
mparams.validate_quants = params.validate_quants;
mparams.merge_qkv = params.merge_qkv;
mparams.merge_up_gate_exps = params.merge_up_gate_exps;
mparams.mtp = params.has_mtp;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
} else {
Expand Down Expand Up @@ -3329,6 +3340,8 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa
cparams.thresh_experts = params.thresh_experts;
cparams.only_active_experts = params.only_active_exps;
cparams.max_extra_alloc = params.max_extra_alloc_MiB;
cparams.mtp = params.has_mtp;
cparams.mtp_op_type = MTP_OP_NONE;

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
Expand Down
2 changes: 2 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ thinking_tokens thinking_tokens_from_string(const std::string& format);
enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_MTP, // MTP model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
Expand Down Expand Up @@ -356,6 +357,7 @@ struct gpt_params {
bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling
//bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops
bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph
bool has_mtp = false; // enable MTP if supported by the model

std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V
Expand Down
174 changes: 174 additions & 0 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
const std::vector<enum common_speculative_type> common_speculative_types = {
COMMON_SPECULATIVE_TYPE_NONE,
COMMON_SPECULATIVE_TYPE_DRAFT,
COMMON_SPECULATIVE_TYPE_MTP,
COMMON_SPECULATIVE_TYPE_EAGLE3,
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
Expand All @@ -31,6 +32,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
Expand Down Expand Up @@ -144,6 +146,58 @@ struct common_speculative_state {
virtual void accept(uint16_t n_accepted) = 0;
};

struct common_speculative_state_mtp : public common_speculative_state {
llama_context * ctx_tgt;
common_sampler * smpl;

common_speculative_state_mtp(
enum common_speculative_type type,
llama_context * ctx_tgt)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
{
struct common_params_sampling params;
params.samplers_sequence = {
llama_sampler_type::DIST,
};
smpl = common_sampler_init(llama_get_model(ctx_tgt), params);
}

~common_speculative_state_mtp() override {
common_sampler_free(smpl);
}

void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
}

void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {

int32_t n_past = (int32_t)prompt_tgt.size();

llama_seq_id seq_id = 0;

result = mtp_speculative_gen_draft(
smpl,
ctx_tgt,
params.n_max,
params.p_min,
id_last,
n_past,
seq_id
);
}

void accept(uint16_t n_accepted) override {
GGML_UNUSED(n_accepted);
}
};


struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft;
Expand Down Expand Up @@ -760,6 +814,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
switch (type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
Expand Down Expand Up @@ -828,6 +883,7 @@ common_speculative * common_speculative_init(
{
bool has_draft = !params.mparams_dft.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);

bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
Expand Down Expand Up @@ -867,6 +923,9 @@ common_speculative * common_speculative_init(
if (has_ngram_cache) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
}
if (has_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
}
if (has_draft) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
}
Expand All @@ -890,6 +949,12 @@ common_speculative * common_speculative_init(
));
break;
}
case COMMON_SPECULATIVE_TYPE_MTP: {
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type,
/* .ctx_tgt = */ ctx_tgt
));
break;
}
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
break;
Expand Down Expand Up @@ -1047,3 +1112,112 @@ void common_speculative_print_stats(const common_speculative * spec) {
str_perf.c_str());
}
}

// ----------------------------------------------------------------------------
// MTP
// ----------------------------------------------------------------------------
std::vector<llama_token> mtp_speculative_gen_draft(
struct common_sampler * smpl,
struct llama_context * ctx,
int n_draft,
float p_min,
llama_token id_last,
int32_t n_past,
llama_seq_id seq_id) {

llama_tokens drafts;
drafts.reserve(n_draft);

if (!smpl) return drafts;

common_sampler_reset(smpl);

llama_batch mtp_batch = llama_batch_init(1, 0, 1);
llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);

llama_token current_input_id = id_last;
int32_t current_n_past = n_past;

for (int i = 0; i < n_draft; ++i) {
mtp_batch.n_tokens = 0;
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);

if (llama_decode(ctx, mtp_batch) != 0) {
break;
}

common_sampler_sample(smpl, ctx, 0, true);

const auto * cur_p = common_sampler_get_candidates(smpl, true);

if (!cur_p || cur_p->size == 0) {
break;
}

const llama_token id_next = cur_p->data[0].id;
const float prob = cur_p->data[0].p;

common_sampler_accept(smpl, nullptr, id_next, true);

if (prob < p_min) {
break;
}

drafts.push_back(id_next);

current_input_id = id_next;
current_n_past++;
}
llama_batch_free(mtp_batch);
llama_set_mtp_op_type(ctx, MTP_OP_NONE);

// Purge the metadata for the draft tokens.
// This prevents cache state corruption where two cells map to the same logical position.
if (!drafts.empty()) {
llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past);
}

return drafts;
}


void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
if (batch.n_tokens == 0) {
return;
}

LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);

llama_batch mtp_batch = batch;
if (is_prompt_warmup) {
llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
} else {
llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
}

for (int i = 0; i < mtp_batch.n_tokens; ++i) {
mtp_batch.logits[i] = true;
}
llama_decode(ctx, mtp_batch);
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
}

void mtp_accept_tokens(
struct llama_context * ctx,
const std::vector<llama_token> & ids,
int32_t n_past_base,
llama_seq_id seq_id
) {
if (ids.empty()) {
return;
}

llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
for (size_t i = 0; i < ids.size(); ++i) {
common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
}

mtp_update_kv_cache(ctx, accepted_batch, false);

llama_batch_free(accepted_batch);
}
19 changes: 19 additions & 0 deletions common/speculative.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,22 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);

// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);

// Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture.
std::vector<llama_token> mtp_speculative_gen_draft(
struct common_sampler * smpl,
struct llama_context * ctx,
int n_draft,
float p_min,
llama_token id_last,
int32_t n_past,
llama_seq_id seq_id);

void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);

void mtp_accept_tokens(
struct llama_context * ctx,
const std::vector<llama_token> & ids,
int32_t n_past_base,
llama_seq_id seq_id
);
2 changes: 1 addition & 1 deletion examples/mtmd/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
#include <array>
#include <functional>

#define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)
#define DEFAULT_INTERPOLATION_MODE ((int)GGML_SCALE_MODE_BILINEAR | (int)GGML_SCALE_FLAG_ALIGN_CORNERS)

// TODO: allow to pass callback from user code
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
Expand Down
Loading