Skip to content

Commit 6424016

Browse files
committed
Revert "Add MTP decoding support for GLM-4.x MoE (ikawrakow#1270)"
1 parent 73f40a3 commit 6424016

16 files changed

Lines changed: 200 additions & 811 deletions

common/common.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1491,14 +1491,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
14911491
params.cuda_params = argv[i];
14921492
return true;
14931493
}
1494-
if (arg == "-mtp" || arg == "--multi-token-prediction") {
1495-
params.has_mtp = true;
1496-
return true;
1497-
}
1498-
if (arg == "-no-mtp" || arg == "--no-multi-token-prediction") {
1499-
params.has_mtp = false;
1500-
return true;
1501-
}
15021494
if (arg == "-draft" || arg == "--draft-params") {
15031495
CHECK_ARG
15041496
params.speculative.params = argv[i];
@@ -2718,8 +2710,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
27182710
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
27192711
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
27202712
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
2721-
options.push_back({ "*", "-mtp, --multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", params.has_mtp ? "true" : "false" });
2722-
options.push_back({ "*", "-no-mtp, --no-multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", !params.has_mtp ? "true" : "false" });
27232713
options.push_back({ "*", "--draft-max, --draft, --draft-n N",
27242714
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
27252715
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
@@ -3458,7 +3448,6 @@ struct llama_model_params common_model_params_to_llama(const gpt_params & params
34583448
mparams.validate_quants = params.validate_quants;
34593449
mparams.merge_qkv = params.merge_qkv;
34603450
mparams.merge_up_gate_exps = params.merge_up_gate_exps;
3461-
mparams.mtp = params.has_mtp;
34623451
mparams.split_output_tensor = params.split_output_tensor;
34633452
if (params.kv_overrides.empty()) {
34643453
mparams.kv_overrides = NULL;
@@ -3582,8 +3571,6 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa
35823571
cparams.thresh_experts = params.thresh_experts;
35833572
cparams.only_active_experts = params.only_active_exps;
35843573
cparams.max_extra_alloc = params.max_extra_alloc_MiB;
3585-
cparams.mtp = params.has_mtp;
3586-
cparams.mtp_op_type = MTP_OP_NONE;
35873574

35883575
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
35893576
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);

common/common.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ thinking_tokens thinking_tokens_from_string(const std::string& format);
139139
enum common_speculative_type {
140140
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
141141
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
142-
COMMON_SPECULATIVE_TYPE_MTP, // MTP model
143142
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
144143
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
145144
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
@@ -359,7 +358,6 @@ struct gpt_params {
359358
//bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops
360359
bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph
361360
int fused_delta_net = 0; // use fused delta-net if number of tokens in the batch is less than this value
362-
bool has_mtp = false; // enable MTP if supported by the model
363361

364362
std::string cache_type_k = "f16"; // KV cache data type for the K
365363
std::string cache_type_v = "f16"; // KV cache data type for the V

common/speculative.cpp

Lines changed: 0 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
const std::vector<enum common_speculative_type> common_speculative_types = {
2121
COMMON_SPECULATIVE_TYPE_NONE,
2222
COMMON_SPECULATIVE_TYPE_DRAFT,
23-
COMMON_SPECULATIVE_TYPE_MTP,
2423
COMMON_SPECULATIVE_TYPE_EAGLE3,
2524
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
2625
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
@@ -32,7 +31,6 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
3231
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
3332
{"none", COMMON_SPECULATIVE_TYPE_NONE},
3433
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
35-
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
3634
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
3735
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
3836
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
@@ -146,58 +144,6 @@ struct common_speculative_state {
146144
virtual void accept(uint16_t n_accepted) = 0;
147145
};
148146

149-
struct common_speculative_state_mtp : public common_speculative_state {
150-
llama_context * ctx_tgt;
151-
common_sampler * smpl;
152-
153-
common_speculative_state_mtp(
154-
enum common_speculative_type type,
155-
llama_context * ctx_tgt)
156-
: common_speculative_state(type)
157-
, ctx_tgt(ctx_tgt)
158-
{
159-
struct common_params_sampling params;
160-
params.samplers_sequence = {
161-
llama_sampler_type::DIST,
162-
};
163-
smpl = common_sampler_init(llama_get_model(ctx_tgt), params);
164-
}
165-
166-
~common_speculative_state_mtp() override {
167-
common_sampler_free(smpl);
168-
}
169-
170-
void begin(const llama_tokens & prompt) override {
171-
GGML_UNUSED(prompt);
172-
}
173-
174-
void draft(
175-
const common_params_speculative & params,
176-
const llama_tokens & prompt_tgt,
177-
llama_token id_last,
178-
llama_tokens & result) override {
179-
180-
int32_t n_past = (int32_t)prompt_tgt.size();
181-
182-
llama_seq_id seq_id = 0;
183-
184-
result = mtp_speculative_gen_draft(
185-
smpl,
186-
ctx_tgt,
187-
params.n_max,
188-
params.p_min,
189-
id_last,
190-
n_past,
191-
seq_id
192-
);
193-
}
194-
195-
void accept(uint16_t n_accepted) override {
196-
GGML_UNUSED(n_accepted);
197-
}
198-
};
199-
200-
201147
struct common_speculative_state_draft : public common_speculative_state {
202148
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
203149
llama_context * ctx_dft;
@@ -814,7 +760,6 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
814760
switch (type) {
815761
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
816762
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
817-
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
818763
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
819764
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
820765
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
@@ -883,7 +828,6 @@ common_speculative * common_speculative_init(
883828
{
884829
bool has_draft = !params.mparams_dft.path.empty();
885830
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
886-
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);
887831

888832
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
889833
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -923,9 +867,6 @@ common_speculative * common_speculative_init(
923867
if (has_ngram_cache) {
924868
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
925869
}
926-
if (has_mtp) {
927-
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
928-
}
929870
if (has_draft) {
930871
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
931872
}
@@ -949,12 +890,6 @@ common_speculative * common_speculative_init(
949890
));
950891
break;
951892
}
952-
case COMMON_SPECULATIVE_TYPE_MTP: {
953-
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type,
954-
/* .ctx_tgt = */ ctx_tgt
955-
));
956-
break;
957-
}
958893
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
959894
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
960895
break;
@@ -1112,112 +1047,3 @@ void common_speculative_print_stats(const common_speculative * spec) {
11121047
str_perf.c_str());
11131048
}
11141049
}
1115-
1116-
// ----------------------------------------------------------------------------
1117-
// MTP
1118-
// ----------------------------------------------------------------------------
1119-
std::vector<llama_token> mtp_speculative_gen_draft(
1120-
struct common_sampler * smpl,
1121-
struct llama_context * ctx,
1122-
int n_draft,
1123-
float p_min,
1124-
llama_token id_last,
1125-
int32_t n_past,
1126-
llama_seq_id seq_id) {
1127-
1128-
llama_tokens drafts;
1129-
drafts.reserve(n_draft);
1130-
1131-
if (!smpl) return drafts;
1132-
1133-
common_sampler_reset(smpl);
1134-
1135-
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
1136-
llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);
1137-
1138-
llama_token current_input_id = id_last;
1139-
int32_t current_n_past = n_past;
1140-
1141-
for (int i = 0; i < n_draft; ++i) {
1142-
mtp_batch.n_tokens = 0;
1143-
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
1144-
1145-
if (llama_decode(ctx, mtp_batch) != 0) {
1146-
break;
1147-
}
1148-
1149-
common_sampler_sample(smpl, ctx, 0, true);
1150-
1151-
const auto * cur_p = common_sampler_get_candidates(smpl, true);
1152-
1153-
if (!cur_p || cur_p->size == 0) {
1154-
break;
1155-
}
1156-
1157-
const llama_token id_next = cur_p->data[0].id;
1158-
const float prob = cur_p->data[0].p;
1159-
1160-
common_sampler_accept(smpl, nullptr, id_next, true);
1161-
1162-
if (prob < p_min) {
1163-
break;
1164-
}
1165-
1166-
drafts.push_back(id_next);
1167-
1168-
current_input_id = id_next;
1169-
current_n_past++;
1170-
}
1171-
llama_batch_free(mtp_batch);
1172-
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
1173-
1174-
// Purge the metadata for the draft tokens.
1175-
// This prevents cache state corruption where two cells map to the same logical position.
1176-
if (!drafts.empty()) {
1177-
llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past);
1178-
}
1179-
1180-
return drafts;
1181-
}
1182-
1183-
1184-
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
1185-
if (batch.n_tokens == 0) {
1186-
return;
1187-
}
1188-
1189-
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
1190-
1191-
llama_batch mtp_batch = batch;
1192-
if (is_prompt_warmup) {
1193-
llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
1194-
} else {
1195-
llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
1196-
}
1197-
1198-
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
1199-
mtp_batch.logits[i] = true;
1200-
}
1201-
llama_decode(ctx, mtp_batch);
1202-
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
1203-
}
1204-
1205-
void mtp_accept_tokens(
1206-
struct llama_context * ctx,
1207-
const std::vector<llama_token> & ids,
1208-
int32_t n_past_base,
1209-
llama_seq_id seq_id
1210-
) {
1211-
if (ids.empty()) {
1212-
return;
1213-
}
1214-
1215-
llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
1216-
for (size_t i = 0; i < ids.size(); ++i) {
1217-
common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
1218-
}
1219-
1220-
mtp_update_kv_cache(ctx, accepted_batch, false);
1221-
1222-
llama_batch_free(accepted_batch);
1223-
}

common/speculative.h

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,3 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
3939

4040
// print statistics about the speculative decoding
4141
void common_speculative_print_stats(const common_speculative * spec);
42-
43-
// Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture.
44-
std::vector<llama_token> mtp_speculative_gen_draft(
45-
struct common_sampler * smpl,
46-
struct llama_context * ctx,
47-
int n_draft,
48-
float p_min,
49-
llama_token id_last,
50-
int32_t n_past,
51-
llama_seq_id seq_id);
52-
53-
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
54-
55-
void mtp_accept_tokens(
56-
struct llama_context * ctx,
57-
const std::vector<llama_token> & ids,
58-
int32_t n_past_base,
59-
llama_seq_id seq_id
60-
);

examples/mtmd/clip.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
#include <array>
3636
#include <functional>
3737

38-
#define DEFAULT_INTERPOLATION_MODE ((int)GGML_SCALE_MODE_BILINEAR | (int)GGML_SCALE_FLAG_ALIGN_CORNERS)
38+
#define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)
3939

4040
// TODO: allow to pass callback from user code
4141
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};

0 commit comments

Comments
 (0)