Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions Dockerfile.atlas
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Build stage: compile llama-server / llama-cli with CUDA support.
FROM docker.io/nvidia/cuda:12.8.0-devel-rockylinux9 AS builder
RUN dnf install -y cmake gcc-c++ && dnf clean all
ENV TMPDIR=/llama.cpp/tmp

# Copy local source with inline MTP changes
COPY . /llama.cpp
RUN cd /llama.cpp && \
    mkdir -p /llama.cpp/tmp && \
    cmake -B build -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=OFF -DCMAKE_CUDA_ARCHITECTURES=120 -DLLAMA_BUILD_TESTS=OFF && \
    cmake --build build --target llama-server llama-cli --config Release -j5

# Runtime stage: only the binaries and the entrypoint script.
FROM docker.io/nvidia/cuda:12.8.0-runtime-rockylinux9
COPY --from=builder /llama.cpp/build/bin/llama-server /usr/local/bin/
COPY --from=builder /llama.cpp/build/bin/llama-cli /usr/local/bin/
# BUGFIX: ENTRYPOINT referenced /entrypoint.sh but it was never copied into
# the runtime image, so the container could not start. Copy it from the build
# context (assumes entrypoint.sh exists at the repo root — TODO confirm).
COPY --from=builder /llama.cpp/entrypoint.sh /entrypoint.sh
RUN mkdir -p /models /templates && chmod +x /entrypoint.sh
EXPOSE 8000
ENTRYPOINT ["/entrypoint.sh"]
7 changes: 5 additions & 2 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3474,8 +3474,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
{"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n"
" mtp: use model's built-in Multi-Token Prediction head (requires MTP-capable model)\n",
common_speculative_type_to_str(params.speculative.type).c_str()),
[](common_params & params, const std::string & value) {
if (value == "none") {
Expand All @@ -3490,6 +3491,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
} else if (value == "ngram-mod") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
} else if (value == "mtp") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
} else {
throw std::invalid_argument("unknown speculative decoding type without draft model");
}
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction (uses model's built-in MTP head)
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
Expand Down
7 changes: 7 additions & 0 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,10 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample

result.push_back(id);

fprintf(stderr, "[MTP-VERIFY] pos=%d: sampled=%d, draft=%d, %s\n",
idxs[i], id, draft[i], (draft[i] == id) ? "ACCEPTED" : "REJECTED");
fflush(stderr);

if (draft[i] != id) {
break;
}
Expand All @@ -588,6 +592,9 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
common_sampler_accept(gsmpl, id, true);

result.push_back(id);

fprintf(stderr, "[MTP-VERIFY] bonus pos=%d: sampled=%d\n", idxs[i], id);
fflush(stderr);
}

return result;
Expand Down
107 changes: 104 additions & 3 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
#include "sampling.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <map>
#include <random>

#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
Expand All @@ -21,6 +23,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
COMMON_SPECULATIVE_TYPE_NONE,
COMMON_SPECULATIVE_TYPE_DRAFT,
COMMON_SPECULATIVE_TYPE_EAGLE3,
COMMON_SPECULATIVE_TYPE_MTP,
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
Expand All @@ -32,6 +35,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
Expand Down Expand Up @@ -462,6 +466,84 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
}
};

// Multi-Token Prediction (MTP) speculative decoding state
struct common_speculative_state_mtp : public common_speculative_state {
llama_context * ctx_tgt;
bool cooldown = false; // skip proposal after rejection to get fresh MTP logits
std::mt19937 rng{42}; // RNG for temperature sampling of MTP drafts

common_speculative_state_mtp(
enum common_speculative_type type,
llama_context * ctx_tgt)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
{
}

~common_speculative_state_mtp() override = default;

void begin(const llama_tokens & prompt) override {
cooldown = false;
GGML_UNUSED(prompt);
}

void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
GGML_UNUSED(prompt_tgt);

// After a draft rejection, MTP logits are from the DRAFT position
// (last in the [sampled, draft] batch), not from the sampled position.
// These logits predict what comes after the draft — which is wrong
// since the draft was rejected. Skip this proposal and let the next
// single-token decode produce fresh MTP logits.
if (cooldown) {
cooldown = false;
return; // empty result = no draft = normal single-token decode
}

const float * mtp_logits = llama_get_mtp_logits(ctx_tgt);
if (mtp_logits == nullptr) {
return;
}

// FastMTP: use reduced vocab size (e.g., 32K instead of 248K)
// Token IDs 0..mtp_n_vocab-1 map directly to full vocab IDs
const int64_t mtp_n_vocab = llama_get_mtp_n_vocab(ctx_tgt);
if (mtp_n_vocab <= 0) {
return;
}

// Argmax of MTP logits over reduced vocabulary
llama_token draft_token = 0;
float best_logit = mtp_logits[0];
for (int64_t i = 1; i < mtp_n_vocab; i++) {
if (mtp_logits[i] > best_logit) {
best_logit = mtp_logits[i];
draft_token = (llama_token)i;
}
}

const auto * vocab = llama_model_get_vocab(llama_get_model(ctx_tgt));
if (!llama_vocab_is_eog(vocab, draft_token)) {
result.push_back(draft_token);
}

GGML_UNUSED(id_last);
GGML_UNUSED(params);
}

void accept(uint16_t n_accepted) override {
// If no drafts were accepted, enter cooldown
// (next draft() call returns empty to force single-token decode)
if (n_accepted == 0) {
cooldown = true;
}
}
};

// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_state {
common_ngram_simple_config config;
Expand Down Expand Up @@ -781,6 +863,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
Expand Down Expand Up @@ -822,9 +905,19 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) {

// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
res = false;
goto done;
// Check if the model has MTP layers — for MTP-1, we can use
// checkpoint/restore instead of seq_rm for the 1-token rollback.
// Hybrid SSM models (DeltaNet) support checkpoint/restore via
// llama-memory-recurrent.cpp even though they don't support seq_rm.
const auto * model = llama_get_model(ctx_tgt);
if (model && llama_model_n_mtp_layers(model) > 0) {
LOG_INF("%s: seq_rm not supported, but MTP model detected — using checkpoint/restore for rollback\n", __func__);
// Restore the state we just modified
} else {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
res = false;
goto done;
}
}

done:
Expand Down Expand Up @@ -853,6 +946,7 @@ common_speculative * common_speculative_init(
{
bool has_draft = !params.mparams_dft.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);

bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
Expand Down Expand Up @@ -892,6 +986,9 @@ common_speculative * common_speculative_init(
if (has_ngram_cache) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
}
if (has_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
}
if (has_draft) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
}
Expand Down Expand Up @@ -919,6 +1016,10 @@ common_speculative * common_speculative_init(
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
break;
}
case COMMON_SPECULATIVE_TYPE_MTP: {
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type, ctx_tgt));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config);

Expand Down
49 changes: 49 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5046,6 +5046,55 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
class Qwen3_5TextModel(_LinearAttentionVReorderBase):
    model_arch = gguf.MODEL_ARCH.QWEN35

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If the model ships MTP (multi-token prediction) layers, they are
        # appended after the main transformer layers, so the block count and
        # the tensor name map must be extended to cover them.
        mtp_layers = self.hparams.get("mtp_num_hidden_layers", 0)
        if mtp_layers > 0:
            self.block_count = self.hparams["num_hidden_layers"] + mtp_layers
            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        # Record the number of NextN/MTP predict layers in the GGUF metadata.
        mtp_layers = self.hparams.get("mtp_num_hidden_layers", 0)
        if mtp_layers > 0:
            self.gguf_writer.add_nextn_predict_layers(mtp_layers)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Remap "mtp.*" tensors into the layer slots appended after the main
        stack; all other tensors pass through to the base implementation."""
        if not name.startswith("mtp."):
            yield from super().modify_tensors(data_torch, name, bid)
            return

        num_hidden = self.hparams["num_hidden_layers"]

        if "layers." in name:
            # MTP transformer block tensors are appended after the main layers:
            # mtp.layers.{k}.* -> model.layers.{k + num_hidden_layers}.*
            new_bid = (bid or 0) + num_hidden
            name = name.replace(f"mtp.layers.{bid}", f"model.layers.{new_bid}")
            yield from super().modify_tensors(data_torch, name, new_bid)
            return

        # Shared MTP weights map to per-layer nextn tensor slots, replicated
        # across every appended MTP layer.
        remapper = {
            "mtp.fc": "model.layers.{bid}.eh_proj",
            "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
            "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
            "mtp.norm": "model.layers.{bid}.shared_head.norm",
        }
        for prefix, template in remapper.items():
            if name.startswith(prefix):
                suffix = name[len(prefix):]  # e.g. ".weight"
                for b in range(num_hidden, self.block_count):
                    yield from super().modify_tensors(data_torch, template.format(bid=b) + suffix, b)
                return
        # Unknown MTP tensors (e.g. embed_tokens/lm_head when shared with the
        # base model) are intentionally skipped.


@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
class Qwen3_5MoeTextModel(_LinearAttentionVReorderBase):
Expand Down
9 changes: 8 additions & 1 deletion gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,7 +1898,14 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_ALPHA,
MODEL_TENSOR.SSM_OUT
MODEL_TENSOR.SSM_OUT,
# NextN/MTP tensors
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.QWEN35MOE: [
MODEL_TENSOR.TOKEN_EMBD,
Expand Down
8 changes: 8 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,9 @@ extern "C" {
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);

// Returns the number of Multi-Token Prediction layers (0 if MTP is not available)
LLAMA_API int32_t llama_model_n_mtp_layers(const struct llama_model * model);

// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

Expand Down Expand Up @@ -988,6 +991,11 @@ extern "C" {
// returns NULL for invalid ids.
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);

// Get MTP (Multi-Token Prediction) draft logits for the last output position.
// With FastMTP, returns mtp_n_vocab floats (reduced vocabulary). Use llama_get_mtp_n_vocab().
LLAMA_API float * llama_get_mtp_logits(struct llama_context * ctx);
LLAMA_API int64_t llama_get_mtp_n_vocab(struct llama_context * ctx);

// Get all output token embeddings.
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
// the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
Expand Down
22 changes: 14 additions & 8 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,13 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_SSM_ALPHA,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
// NextN/MTP tensors
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
LLM_TENSOR_NEXTN_HNORM,
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
};
case LLM_ARCH_QWEN35MOE:
return {
Expand Down Expand Up @@ -2753,14 +2760,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
// NextN/MTP tensors — per-layer (appended after main layers)
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// Nemotron 3 Super
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
Expand Down
5 changes: 3 additions & 2 deletions src/llama-batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,13 @@ bool llama_batch_allocr::init(
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

if (batch.token) {
if (p0 >= 0 && p0 >= seq_pos_min(s)) {
// Allow X == Y for speculative decoding where seq_rm + re-eval at same position is valid
if (p0 >= 0 && p0 > seq_pos_min(s)) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" for M-RoPE, it is required that the position satisfies: X < Y\n",
" for M-RoPE, it is required that the position satisfies: X <= Y\n",
__func__, s, s, p0, s, seq_pos_min(s));

return false;
Expand Down
Loading