2020const std::vector<enum common_speculative_type> common_speculative_types = {
2121 COMMON_SPECULATIVE_TYPE_NONE,
2222 COMMON_SPECULATIVE_TYPE_DRAFT,
23- COMMON_SPECULATIVE_TYPE_MTP,
2423 COMMON_SPECULATIVE_TYPE_EAGLE3,
2524 COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
2625 COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
@@ -32,7 +31,6 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
3231const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
3332 {" none" , COMMON_SPECULATIVE_TYPE_NONE},
3433 {" draft" , COMMON_SPECULATIVE_TYPE_DRAFT},
35- {" mtp" , COMMON_SPECULATIVE_TYPE_MTP},
3634 {" eagle3" , COMMON_SPECULATIVE_TYPE_EAGLE3},
3735 {" ngram_simple" , COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
3836 {" ngram_map_k" , COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
@@ -146,58 +144,6 @@ struct common_speculative_state {
146144 virtual void accept (uint16_t n_accepted) = 0;
147145};
148146
149- struct common_speculative_state_mtp : public common_speculative_state {
150- llama_context * ctx_tgt;
151- common_sampler * smpl;
152-
153- common_speculative_state_mtp (
154- enum common_speculative_type type,
155- llama_context * ctx_tgt)
156- : common_speculative_state(type)
157- , ctx_tgt(ctx_tgt)
158- {
159- struct common_params_sampling params;
160- params.samplers_sequence = {
161- llama_sampler_type::DIST,
162- };
163- smpl = common_sampler_init (llama_get_model (ctx_tgt), params);
164- }
165-
166- ~common_speculative_state_mtp () override {
167- common_sampler_free (smpl);
168- }
169-
170- void begin (const llama_tokens & prompt) override {
171- GGML_UNUSED (prompt);
172- }
173-
174- void draft (
175- const common_params_speculative & params,
176- const llama_tokens & prompt_tgt,
177- llama_token id_last,
178- llama_tokens & result) override {
179-
180- int32_t n_past = (int32_t )prompt_tgt.size ();
181-
182- llama_seq_id seq_id = 0 ;
183-
184- result = mtp_speculative_gen_draft (
185- smpl,
186- ctx_tgt,
187- params.n_max ,
188- params.p_min ,
189- id_last,
190- n_past,
191- seq_id
192- );
193- }
194-
195- void accept (uint16_t n_accepted) override {
196- GGML_UNUSED (n_accepted);
197- }
198- };
199-
200-
201147struct common_speculative_state_draft : public common_speculative_state {
202148 llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
203149 llama_context * ctx_dft;
@@ -814,7 +760,6 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
814760 switch (type) {
815761 case COMMON_SPECULATIVE_TYPE_NONE: return " none" ;
816762 case COMMON_SPECULATIVE_TYPE_DRAFT: return " draft" ;
817- case COMMON_SPECULATIVE_TYPE_MTP: return " mtp" ;
818763 case COMMON_SPECULATIVE_TYPE_EAGLE3: return " eagle3" ;
819764 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return " ngram_simple" ;
820765 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return " ngram_map_k" ;
@@ -883,7 +828,6 @@ common_speculative * common_speculative_init(
883828 {
884829 bool has_draft = !params.mparams_dft .path .empty ();
885830 bool has_draft_eagle3 = false ; // TODO PR-18039: if params.speculative.eagle3
886- bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);
887831
888832 bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
889833 bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -923,9 +867,6 @@ common_speculative * common_speculative_init(
923867 if (has_ngram_cache) {
924868 configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
925869 }
926- if (has_mtp) {
927- configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_MTP, params));
928- }
929870 if (has_draft) {
930871 configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_DRAFT, params));
931872 }
@@ -949,12 +890,6 @@ common_speculative * common_speculative_init(
949890 ));
950891 break ;
951892 }
952- case COMMON_SPECULATIVE_TYPE_MTP: {
953- impls.push_back (std::make_unique<common_speculative_state_mtp>(config.type ,
954- /* .ctx_tgt = */ ctx_tgt
955- ));
956- break ;
957- }
958893 case COMMON_SPECULATIVE_TYPE_EAGLE3: {
959894 impls.push_back (std::make_unique<common_speculative_state_eagle3>(config.type ));
960895 break ;
@@ -1112,112 +1047,3 @@ void common_speculative_print_stats(const common_speculative * spec) {
11121047 str_perf.c_str ());
11131048 }
11141049}
1115-
1116- // ----------------------------------------------------------------------------
1117- // MTP
1118- // ----------------------------------------------------------------------------
1119- std::vector<llama_token> mtp_speculative_gen_draft (
1120- struct common_sampler * smpl,
1121- struct llama_context * ctx,
1122- int n_draft,
1123- float p_min,
1124- llama_token id_last,
1125- int32_t n_past,
1126- llama_seq_id seq_id) {
1127-
1128- llama_tokens drafts;
1129- drafts.reserve (n_draft);
1130-
1131- if (!smpl) return drafts;
1132-
1133- common_sampler_reset (smpl);
1134-
1135- llama_batch mtp_batch = llama_batch_init (1 , 0 , 1 );
1136- llama_set_mtp_op_type (ctx, MTP_OP_DRAFT_GEN);
1137-
1138- llama_token current_input_id = id_last;
1139- int32_t current_n_past = n_past;
1140-
1141- for (int i = 0 ; i < n_draft; ++i) {
1142- mtp_batch.n_tokens = 0 ;
1143- common_batch_add (mtp_batch, current_input_id, current_n_past, {seq_id}, true );
1144-
1145- if (llama_decode (ctx, mtp_batch) != 0 ) {
1146- break ;
1147- }
1148-
1149- common_sampler_sample (smpl, ctx, 0 , true );
1150-
1151- const auto * cur_p = common_sampler_get_candidates (smpl, true );
1152-
1153- if (!cur_p || cur_p->size == 0 ) {
1154- break ;
1155- }
1156-
1157- const llama_token id_next = cur_p->data [0 ].id ;
1158- const float prob = cur_p->data [0 ].p ;
1159-
1160- common_sampler_accept (smpl, nullptr , id_next, true );
1161-
1162- if (prob < p_min) {
1163- break ;
1164- }
1165-
1166- drafts.push_back (id_next);
1167-
1168- current_input_id = id_next;
1169- current_n_past++;
1170- }
1171- llama_batch_free (mtp_batch);
1172- llama_set_mtp_op_type (ctx, MTP_OP_NONE);
1173-
1174- // Purge the metadata for the draft tokens.
1175- // This prevents cache state corruption where two cells map to the same logical position.
1176- if (!drafts.empty ()) {
1177- llama_kv_cache_seq_rm (ctx, seq_id, n_past, current_n_past);
1178- }
1179-
1180- return drafts;
1181- }
1182-
1183-
1184- void mtp_update_kv_cache (struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
1185- if (batch.n_tokens == 0 ) {
1186- return ;
1187- }
1188-
1189- LOG_DBG (" [MTP-UPDATE|%s] Updating %d tokens...\n " , is_prompt_warmup ? " PROMPT_WARMUP" : " GEN_ACCEPTED" , batch.n_tokens );
1190-
1191- llama_batch mtp_batch = batch;
1192- if (is_prompt_warmup) {
1193- llama_set_mtp_op_type (ctx, MTP_OP_WARMUP);
1194- } else {
1195- llama_set_mtp_op_type (ctx, MTP_OP_UPDATE_ACCEPTED);
1196- }
1197-
1198- for (int i = 0 ; i < mtp_batch.n_tokens ; ++i) {
1199- mtp_batch.logits [i] = true ;
1200- }
1201- llama_decode (ctx, mtp_batch);
1202- llama_set_mtp_op_type (ctx, MTP_OP_NONE);
1203- }
1204-
1205- void mtp_accept_tokens (
1206- struct llama_context * ctx,
1207- const std::vector<llama_token> & ids,
1208- int32_t n_past_base,
1209- llama_seq_id seq_id
1210- ) {
1211- if (ids.empty ()) {
1212- return ;
1213- }
1214-
1215- llama_batch accepted_batch = llama_batch_init (ids.size (), 0 , 1 );
1216- for (size_t i = 0 ; i < ids.size (); ++i) {
1217- common_batch_add (accepted_batch, ids[i], n_past_base + i, { seq_id }, true );
1218- }
1219-
1220- mtp_update_kv_cache (ctx, accepted_batch, false );
1221-
1222- llama_batch_free (accepted_batch);
1223- }
0 commit comments