diff --git a/common/arg.cpp b/common/arg.cpp index aad70ec5464..d8604b707f4 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3525,6 +3525,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.ngram_min_hits = value; } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--spec-use-checkpoints"}, "[on|off|auto]", + string_format("use checkpoints to rewind token history in recurrent models ('on', 'off', or 'auto', default: %s)", + params.speculative.use_checkpoints ? "on" : "off"), + [](common_params & params, const std::string & value) { + if (is_truthy(value) || is_autoy(value)) { + params.speculative.use_checkpoints = true; + } else if (is_falsey(value)) { + params.speculative.use_checkpoints = false; + } else { + throw std::invalid_argument("invalid value for --spec-use-checkpoints"); + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-ctkd", "--cache-type-k-draft"}, "TYPE", string_format( diff --git a/common/common.h b/common/common.h index 62201ea1ad3..eb451748a63 100644 --- a/common/common.h +++ b/common/common.h @@ -324,6 +324,8 @@ struct common_params_speculative { uint16_t ngram_size_n = 12; // ngram size for lookup uint16_t ngram_size_m = 48; // mgram size for speculative tokens uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed + bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models + std::shared_ptr ngram_mod; diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp index ebf771a24a7..8e3978f7ed0 100644 --- a/common/ngram-map.cpp +++ b/common/ngram-map.cpp @@ -208,7 +208,7 @@ void common_ngram_map_begin( count_keys, count_keys_del, count_values_del, count_map_entries_upd); } - map.idx_last_check = (map.size_last_begin > 0) ? 
map.size_last_begin - 1 : 0; + map.idx_last_check = size_begin; map.size_last_begin = size_begin; } @@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map, GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len); } - if (map.idx_last_check > cur_len) { + if (map.idx_last_check > cur_len) { // Should not happen because of common_ngram_map_begin(). GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len); } @@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map, LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__, curr_key.key_idx, key_offset, curr_key.key_num, draft.size()); - map.last_draft_created = false; + map.last_draft_created = true; map.last_draft_key_idx = key_offset; map.last_draft_value_idx = 0; // value 0 is used for simple mode return; @@ -524,7 +524,7 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) { struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation. 
// update the value statistics - LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n", + LOG_DBG("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n", n_accepted, curr_value.n_accepted); curr_value.n_accepted = n_accepted; } diff --git a/common/speculative.cpp b/common/speculative.cpp index 3e68c38e49c..5a2348d7457 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -144,10 +144,28 @@ struct common_speculative_state { virtual void accept(uint16_t n_accepted) = 0; }; +struct common_speculative_checkpoint { + llama_pos pos_min = 0; + llama_pos pos_max = 0; + + int64_t n_tokens = 0; + + std::vector data; + + size_t size() const { + return data.size(); + } + + size_t ckpt_size = 0; +}; + struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt; // only used for retokenizing from ctx_dft llama_context * ctx_dft; + struct common_speculative_checkpoint ckpt; + bool use_checkpoint; + common_sampler * smpl; llama_batch batch; @@ -160,10 +178,12 @@ struct common_speculative_state_draft : public common_speculative_state { enum common_speculative_type type, llama_context * ctx_tgt, llama_context * ctx_dft, - const std::vector> & replacements) + const std::vector> & replacements, + bool use_checkpoint) : common_speculative_state(type) , ctx_tgt(ctx_tgt) , ctx_dft(ctx_dft) + , use_checkpoint(use_checkpoint) { batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); smpl = nullptr; @@ -218,7 +238,48 @@ struct common_speculative_state_draft : public common_speculative_state { } void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + if (use_checkpoint && ckpt.size() > 0) { + // delete checkpoint + LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%zu, size=%.3f MiB\n", + __func__, prompt.size(), + ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024); + ckpt.pos_min = 0; + ckpt.pos_max = 0; + 
ckpt.n_tokens = 0; + ckpt.ckpt_size = 0; + ckpt.data.clear(); + } + } + + size_t draft_init_checkpoint(int n_tokens_prompt, int n_tokens_batch) { + int slot_id = 0; + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id); + ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id); + ckpt.n_tokens = n_tokens_prompt - n_tokens_batch; + ckpt.data.resize(checkpoint_size); + + const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != checkpoint_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); + } + + LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__, + ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024); + return n; + } + + size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) { + int slot_id = 0; + LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max); + const size_t n = llama_state_seq_set_data_ext(ctx_dft, + ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != ckpt_size_part_expected) { + GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", + __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n); + } + return n; } void draft( @@ -236,8 +297,8 @@ struct common_speculative_state_draft : public common_speculative_state { auto * mem_dft = llama_get_memory(ctx_dft); - int reuse_i = 0; - int reuse_n = 0; + int reuse_i = 0; // index of part to be reused in prompt_dft + int reuse_n = 0; // length of part to be reused in prompt_dft const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; @@ -287,18 +348,26 @@ struct common_speculative_state_draft : public common_speculative_state { } 
} - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", + __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size()); + if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) { + LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n", + __func__, reuse_i, reuse_n); + reuse_i = 0; + reuse_n = 0; + } result.clear(); result.reserve(params.n_max); - if (reuse_n == 0) { + bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0; + if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) { llama_memory_clear(mem_dft, false); prompt_dft.clear(); } else { // this happens when a previous draft has been discarded (for example, due to being too small), but the // target model agreed with it. in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { + if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { result.push_back(prompt_dft[i]); @@ -310,19 +379,50 @@ struct common_speculative_state_draft : public common_speculative_state { return; } + bool do_restore = false; + if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) { + // This can happen after a partial acceptance (speculative decoding with checkpoints) + LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n", + __func__, prompt_dft.size(), prompt_cur.size()); + prompt_dft.resize(prompt_cur.size()); + do_restore = true; + } + if (reuse_i > 0) { - llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); + bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); + if (!is_removed) { + LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i); + } 
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); } - if (reuse_n < (int) prompt_dft.size()) { - llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + if (reuse_n < (int) prompt_dft.size() || do_restore) { + if (use_checkpoint) { + if (ckpt.n_tokens > (int64_t) prompt_dft.size()) { + LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%zu, reuse_n=%d, prompt_dft.size=%zu\n", + __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size()); + } + draft_restore_checkpoint(ckpt.ckpt_size); + reuse_n = ckpt.n_tokens; + prompt_dft.resize(reuse_n); + needs_ckpt = false; + } else { + bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); + if (!is_removed) { + LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", + __func__, reuse_n, prompt_dft.size()); + } + prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + } } } + if (needs_ckpt && use_checkpoint) { + ckpt.ckpt_size = draft_init_checkpoint(prompt_dft.size(), batch.n_tokens); + } + // prepare a batch to evaluate any new tokens in the prompt common_batch_clear(batch); @@ -337,7 +437,11 @@ struct common_speculative_state_draft : public common_speculative_state { if (batch.n_tokens > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - llama_decode(ctx_dft, batch); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0 && ret != 1) { + LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n", + __func__, ret, prompt_cur.size()); + } } const llama_pos n_past = prompt_dft.size(); @@ -351,7 +455,11 @@ struct common_speculative_state_draft : public common_speculative_state { LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - llama_decode(ctx_dft, batch); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0 && ret != 1) { + 
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", + __func__, ret, prompt_cur.size(), prompt_dft.size()); + } common_sampler_reset(smpl); @@ -387,7 +495,11 @@ struct common_speculative_state_draft : public common_speculative_state { common_batch_add(batch, id, n_past + i + 1, { 0 }, true); // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch); + ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", + __func__, i, ret, prompt_cur.size(), prompt_dft.size()); + } prompt_dft.push_back(id); } @@ -909,9 +1021,10 @@ common_speculative * common_speculative_init( break; case COMMON_SPECULATIVE_TYPE_DRAFT: { impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .replacements = */ params.replacements + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ ctx_dft, + /* .replacements = */ params.replacements, + /* .use_checkpoint= */ params.use_checkpoints )); break; } @@ -1072,3 +1185,285 @@ void common_speculative_print_stats(const common_speculative * spec) { str_perf.c_str()); } } + + +// server callbacks +// + +common_speculative_callback::~common_speculative_callback() = default; + +// server session +// +struct common_speculative_session::impl { + common_speculative_callback & callback; + common_params_speculative params_spec; + + llama_context * ctx_tgt = nullptr; + + common_speculative * spec = nullptr; + + // `i_batch_dft`, idx of draft tokens in the main batch are stored in the caller + + llama_tokens draft; + + // use of checkpoints in speculative mode + bool spec_has_ckpt = false; // true if a checkpoint for rollback after partial speculation has been created + uint16_t spec_ckpt_n_denials = 0; // number of drafts not accepted at the current position (0 or 1) + size_t spec_ckpt_size_part = 0; // size of partial checkpoint + + // Speculative decoding stats + int32_t 
n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted + + impl(common_speculative_callback & callback, + const common_params_speculative & params, + llama_context * ctx_tgt) + : callback(callback), params_spec(params), ctx_tgt(ctx_tgt) { + spec = common_speculative_init(params_spec, ctx_tgt); + } + + void begin(const llama_tokens & prompt_history) { + common_speculative_begin(spec, prompt_history); + } + + bool has_batch_dft() { + return !draft.empty(); + } + + void clear_draft() { + draft.clear(); + spec_ckpt_n_denials = 0; + } + + llama_tokens compute_draft( + const llama_tokens & cached_text_tokens, + llama_token id_last, + const int n_draft_max) { + if (spec == nullptr) { + // no implementation, nothing to do + clear_draft(); + return draft; + } + + if (n_draft_max == 0) { + clear_draft(); + return draft; + } + if (params_spec.use_checkpoints && spec_ckpt_n_denials > 1) { + // We shouldn't get two denials. + LOG_WRN("%s: #tokens=%zu, spec_ckpt_n_denials=%d, id_last=%d, #draft=%zu\n", __func__, + cached_text_tokens.size(), spec_ckpt_n_denials, id_last, draft.size()); + clear_draft(); + return draft; + } + + if (spec_ckpt_n_denials == 1) { + // there is a previous speculation which wasn't accepted in full length + if (draft.empty()) { + // switch to non-draft inference + LOG_DBG("%s: draft of length 0 after denied checkpoint\n", __func__); + clear_draft(); + return draft; + } + // we use the shortened draft of previous speculation + LOG_DBG("%s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n", __func__, + cached_text_tokens.size(), id_last, draft.size()); + } else if (spec_ckpt_n_denials > 1) { + GGML_ABORT("illegal state: spec_ckpt_n_denials = %d > 1", spec_ckpt_n_denials); + } else { + // call the speculative implementation to create a draft + draft = common_speculative_draft(spec, params_spec, cached_text_tokens, id_last); + LOG_DBG("draft: id_last=%d, #draft=%zu\n", id_last, 
draft.size()); + if (draft.empty()) { + clear_draft(); + return draft; + } + } + + if (draft.size() > (size_t) n_draft_max) { + LOG_WRN("draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max); + draft.resize(n_draft_max); + } + + bool do_checkpoint = !draft.empty() && params_spec.use_checkpoints; + if (do_checkpoint && cached_text_tokens.size() > 5 && draft.size() >= 3) { + LOG_DBG("%s: #tokens=%zu, draft.size=%zu, n_spec_denials=%d, do_checkpoint=%s, id_last=%d, tokens=[..., %d, %d, %d], draft=[%d, %d, %d, ...]\n", + __func__, + cached_text_tokens.size(), + draft.size(), spec_ckpt_n_denials, + do_checkpoint ? "yes" : "no", id_last, + cached_text_tokens[cached_text_tokens.size() - 3], + cached_text_tokens[cached_text_tokens.size() - 2], + cached_text_tokens[cached_text_tokens.size() - 1], + draft[0], draft[1], draft[2]); + } + + if (params_spec.n_min > (int) draft.size()) { + LOG_DBG("ignoring small draft: %d < %d\n", (int) draft.size(), params_spec.n_min); + clear_draft(); + return draft; + } + + if (do_checkpoint) { + const size_t n = callback.create_checkpoint(); + if (n == 0) { + LOG_WRN("%s: checkpoint creation failed (#tokens=%zu)\n", __func__, cached_text_tokens.size()); + clear_draft(); + return draft; + } + spec_ckpt_size_part = n; + spec_has_ckpt = true; + } + + // add last sampled token to the batch + callback.batch_add_token(id_last, true); + + // add all drafted tokens to the batch + for (size_t i = 0; i < draft.size(); i++) { + callback.batch_add_token(draft[i], true); + } + + return draft; + } + + common_speculative_accept_response sample_and_accept() { + const size_t n_draft = draft.size(); + + // the accepted tokens from the speculation + auto ids = callback.sampler_sample_and_accept_n(draft); + + LOG_DBG("%s: n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size()); + if (ids.size() < n_draft + 1) { + // the main model rejected some tokens + + // we shorten the draft + draft.resize(ids.size() - 1); + if (spec_has_ckpt) { 
+ // we need to rollback to the state before sampling the draft tokens + // (restore_checkpoint shortens context and slot.prompt.tokens) + const auto ckpt_res = callback.restore_checkpoint(spec_ckpt_size_part); + LOG_DBG("%s: partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes, pos_max=%d\n", __func__, + ids.size() - 1, n_draft, ckpt_res.bytes_restored, ckpt_res.pos_max); + + // clean orphaned attention KV entries beyond the checkpoint boundary; + // restore_checkpoint only restores recurrent/partial state (PARTIAL_ONLY flag) + { + static const bool skip_seqrm = (getenv("LLAMA_SKIP_SEQRM_AFTER_CKPT") != nullptr); + if (skip_seqrm) { + LOG_WRN("%s: SKIPPING memory_seq_rm after checkpoint restore (debug toggle)\n", __func__); + + // write to file for retrieval when logs are not visible + static FILE * skip_log = fopen("/tmp/seqrm-skips.log", "a"); + if (skip_log) { + fprintf(skip_log, "SKIPPED seq_rm at pos_max=%d\n", ckpt_res.pos_max); + fflush(skip_log); + } + } else { + callback.memory_seq_rm(ckpt_res.pos_max + 1, -1); + } + } + + // delete Checkpoint + callback.delete_checkpoint(); + spec_has_ckpt = false; + + spec_ckpt_n_denials++; + if (ids.size() > 1u + static_cast(params_spec.n_min) && spec_ckpt_n_denials == 1) { + // we will do the batch again but with the shortened draft + //return common_speculative_accept_response(std::move(ids), n_draft, true); + LOG_DBG("%s: partial draft disabled\n", __func__); + } + + LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size()); + draft.clear(); + + // use the sampled token only + ids.resize(1); + // drafted tokens in prompt have been deleted in restore_checkpoint(...). 
+
+                // skip acceptance, don't calculate a new draft
+                return common_speculative_accept_response{std::move(ids), 0, true};
+            }
+        }
+        const size_t draft_size_accepted = draft.size();
+        LOG_DBG("%s: draft.size=%zu, ids.size=%zu\n", __func__, draft_size_accepted, ids.size());
+        common_speculative_accept(spec, draft_size_accepted);
+        draft.clear();
+
+        return common_speculative_accept_response{std::move(ids), n_draft, false};
+    }
+
+    void rewind(const llama_pos p0) {
+        spec_ckpt_n_denials = 0;
+        if (spec_has_ckpt) {
+            // Delete Checkpoint
+            callback.delete_checkpoint();
+            spec_has_ckpt = false;
+        }
+        // remove attention KV entries from the bonus token and any
+        // unaccepted drafts beyond p0
+        callback.memory_seq_rm(p0, -1);
+    }
+
+    void print_stats() const {
+        if (spec == nullptr) {
+            return;
+        }
+
+        common_speculative_print_stats(spec);
+    }
+
+    void reset() {
+        if (spec == nullptr) {
+            return;
+        }
+
+        clear_draft();
+
+        spec_has_ckpt = false;
+        spec_ckpt_size_part = 0;
+    }
+};
+
+common_speculative_session::common_speculative_session(
+        common_speculative_callback & callback,
+        const common_params_speculative & params,
+        llama_context * ctx_tgt) : p_impl(new impl{callback, params, ctx_tgt}) {
+}
+
+common_speculative_session::~common_speculative_session() {
+    common_speculative_free(p_impl->spec);
+    delete p_impl;
+}
+
+void common_speculative_session::begin(const llama_tokens & prompt_history) {
+    p_impl->begin(prompt_history);
+}
+
+bool common_speculative_session::has_batch_dft() {
+    return p_impl->has_batch_dft();
+}
+
+llama_tokens common_speculative_session::compute_draft(
+        const llama_tokens & prompt,
+        llama_token id_last,
+        int n_draft_max_slot) {
+    return p_impl->compute_draft(prompt, id_last, n_draft_max_slot);
+}
+
+common_speculative_accept_response common_speculative_session::sample_and_accept() {
+    return p_impl->sample_and_accept();
+}
+
+void common_speculative_session::rewind(const llama_pos p0) {
+    p_impl->rewind(p0);
+}
+
+void 
common_speculative_session::print_stats() const { + p_impl->print_stats(); +} + +void common_speculative_session::reset() { + p_impl->reset(); +} + diff --git a/common/speculative.h b/common/speculative.h index 876cde3d180..8c1f177e7cd 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -3,6 +3,15 @@ #include "llama.h" #include "common.h" +// common/speculative.h has two interfaces: +// +// 1) struct common_speculative with init, begin, draft, accept and print_stats +// Simple interface, see examples/speculative/speculative.cpp +// +// 2) struct common_speculative_session with struct common_speculative_callback +// Complex interface which supports checkpoints, see tools/server/server-context.cpp +// + struct common_speculative; // comma separated list of all types @@ -39,3 +48,92 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); + + + +// Interactions with server +// + +struct restore_checkpoint_result { + size_t bytes_restored; + llama_pos pos_max; +}; + +// callback implemented by the server +struct common_speculative_callback { + virtual ~common_speculative_callback(); + + // Add a token to the draft sequence. + virtual void batch_add_token(const llama_token token, bool logits) = 0; + + // Sample and accept tokens from the main model. + virtual llama_tokens sampler_sample_and_accept_n(const llama_tokens & drafted) = 0; + + // Deletes a part of the context. + // Returns true if the memory was modified. + virtual bool memory_seq_rm(llama_pos p0, llama_pos p1) = 0; + + // Creates a checkpoint of the current state of the context. + // Returns the size of the checkpoint in bytes. + virtual size_t create_checkpoint() = 0; + + // Restore a checkpoint previously created by create_checkpoint(). 
+ virtual restore_checkpoint_result restore_checkpoint(size_t ckpt_size_part_expected) = 0; + + // Delete a checkpoint previously created by create_checkpoint(). + virtual void delete_checkpoint() = 0; +}; + +struct common_speculative_accept_response { + llama_tokens tokens; + size_t draft_size_initial; + bool skip_acceptance; + + common_speculative_accept_response(llama_tokens t, size_t draft_size_initial, bool skip) + : tokens(std::move(t)), draft_size_initial(draft_size_initial), skip_acceptance(skip) {} +}; + +// speculative decoding which may use checkpoints to rewind in tokens history +struct common_speculative_session { + + common_speculative_session( + common_speculative_callback & callback, + const common_params_speculative & params, + llama_context * ctx_tgt); + + ~common_speculative_session(); + + // don't copy + common_speculative_session(const common_speculative_session &) = delete; + common_speculative_session & operator=(const common_speculative_session &) = delete; + + + // call once at the beginning of a new generation + // some spec implementations use the prompt history to initialize lookup maps + void begin(const llama_tokens & prompt_history); + + bool has_batch_dft(); + + // do speculative decoding to compute a draft of tokens + llama_tokens compute_draft(const llama_tokens & prompt, + llama_token id_last, + int n_draft_max_slot); + + // check if and how far the current draft is accepted + common_speculative_accept_response sample_and_accept(); + + // rewind (because of a draft not fully accepted) + void rewind(const llama_pos p0); + + // print statistics + void print_stats() const; + + // reset and delete structures + void reset(); + + private: + struct impl; + impl * p_impl; + +}; + diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 01166fac9ce..03ec49d417a 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -971,6 +971,45 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & head 
= sinfo.idxs[s].back() + 1; } + + // debug: scan for duplicate (seq_id, pos) pairs across cells + { + static const bool do_check = (getenv("LLAMA_DEBUG_KV_DUPLICATES") != nullptr); + if (do_check) { + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const auto & cells = v_cells[sinfo.strm[s]]; + + // for each newly written cell, check if another cell has the same (seq_id, pos) + for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { + const auto new_idx = sinfo.idxs[s][ii]; + if (cells.is_empty(new_idx)) { + continue; + } + + const auto new_pos = cells.pos_get(new_idx); + const auto new_seq_id = cells.seq_get(new_idx); + + for (uint32_t j = 0; j < cells.size(); ++j) { + if (j == new_idx || cells.is_empty(j)) { + continue; + } + if (cells.pos_get(j) == new_pos && cells.seq_has(j, new_seq_id)) { + LLAMA_LOG_WARN("%s: DUPLICATE KV cell: cell %u and cell %u both have (seq_id=%d, pos=%d)\n", + __func__, j, new_idx, new_seq_id, new_pos); + + // also write to file for retrieval when logs are not visible + static FILE * dup_log = fopen("/tmp/kv-duplicates.log", "a"); + if (dup_log) { + fprintf(dup_log, "DUPLICATE: cell %u and cell %u both have (seq_id=%d, pos=%d)\n", j, + new_idx, new_seq_id, new_pos); + fflush(dup_log); + } + } + } + } + } + } + } } bool llama_kv_cache::get_can_shift() const { diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 9de554e9007..545254c4455 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1,3 +1,4 @@ + #include "server-context.h" #include "server-common.h" #include "server-http.h" @@ -56,7 +57,8 @@ struct server_slot { // multimodal mtmd_context * mctx = nullptr; - common_speculative * spec = nullptr; + std::unique_ptr spec_callback; + std::unique_ptr spec_session = nullptr; // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see 
https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 @@ -145,9 +147,9 @@ struct server_slot { json json_schema; common_sampler_ptr smpl; + common_sampler_ptr smpl_checkpoint; // saved sampler state for spec checkpoint llama_token sampled; // in speculative mode, this is the last accepted token - llama_tokens drafted; // stats size_t n_sent_text = 0; // number of sent text character @@ -164,6 +166,11 @@ struct server_slot { int32_t n_draft_total = 0; // Total draft tokens generated int32_t n_draft_accepted = 0; // Draft tokens actually accepted + // Diagnostic counters for speculation debugging + int32_t n_spec_cycles = 0; // Total compute_draft calls + int32_t n_spec_empty = 0; // compute_draft returned empty (no prediction) + int32_t n_spec_skip = 0; // sample_and_accept returned skip (full rejection) + void reset() { SLT_DBG(*this, "%s", "\n"); @@ -177,8 +184,11 @@ struct server_slot { stopping_word = ""; n_sent_text = 0; - drafted.clear(); + if (spec_session != nullptr) { + spec_session->reset(); + } i_batch_dft.clear(); + smpl_checkpoint.reset(); generated_tokens.clear(); generated_token_probs.clear(); json_schema = json(); @@ -186,6 +196,9 @@ struct server_slot { // clear speculative decoding stats n_draft_total = 0; n_draft_accepted = 0; + n_spec_cycles = 0; + n_spec_empty = 0; + n_spec_skip = 0; task_prev = std::move(task); task.reset(); @@ -259,7 +272,7 @@ struct server_slot { } bool can_speculate() const { - return !!spec; + return !!spec_session; } void add_token(const completion_token_output & token) { @@ -336,9 +349,12 @@ struct server_slot { timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; // Add speculative metrics - if (n_draft_total > 0) { + if (n_draft_total > 0 || n_spec_cycles > 0) { timings.draft_n = n_draft_total; timings.draft_n_accepted = n_draft_accepted; + timings.spec_cycles = n_spec_cycles; + timings.spec_empty = n_spec_empty; + timings.spec_skip = n_spec_skip; } return timings; @@ 
-391,15 +407,18 @@ struct server_slot { t_token_generation, n_decoded, t_gen, n_gen_second, t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); - if (n_draft_total > 0) { - const float draft_ratio = (float) n_draft_accepted / n_draft_total; + if (n_draft_total > 0 || n_spec_cycles > 0) { + const float draft_ratio = n_draft_total > 0 ? (float) n_draft_accepted / n_draft_total : 0.0f; SLT_CNT(*this, - "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", - draft_ratio, n_draft_accepted, n_draft_total - ); + "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n" + "spec cycles = %5d (empty = %5d, skip = %5d, accept = %5d)\n", + draft_ratio, n_draft_accepted, n_draft_total, n_spec_cycles, n_spec_empty, n_spec_skip, + n_spec_cycles - n_spec_empty - n_spec_skip); } - common_speculative_print_stats(spec); + if (spec_session) { + spec_session->print_stats(); + } } json to_json(bool only_metrics = false) const { @@ -598,8 +617,10 @@ struct server_context_impl { // Clear any sampling context for (server_slot & slot : slots) { - common_speculative_free(slot.spec); - slot.spec = nullptr; + if (slot.spec_session != nullptr) { + slot.spec_session->reset(); + slot.spec_session = nullptr; + } } llama_batch_free(batch); @@ -619,6 +640,94 @@ struct server_context_impl { sleeping = new_state; } + // + // callback for speculative decoding + // + struct server_speculative_callback : public common_speculative_callback { + int slot_id; // store slot.id instead of server_slot & slot + server_context_impl & ctx_impl; + + server_speculative_callback(int slot_id, server_context_impl & ctx_impl) + : slot_id(slot_id), ctx_impl(ctx_impl) {} + + server_slot * get_slot() { + server_slot * slot = ctx_impl.get_slot_by_id(slot_id); + if (slot == nullptr) { + GGML_ABORT("missing slot, slot.id=%d", slot_id); + } + return slot; + } + + void batch_add_token(const llama_token token, bool logits) override { + server_slot * slot = get_slot(); + 
slot->i_batch_dft.push_back(ctx_impl.batch.n_tokens); + common_batch_add(ctx_impl.batch, token, slot->prompt.tokens.pos_next(), { slot_id }, logits); + slot->prompt.tokens.push_back(token); + } + + std::vector sampler_sample_and_accept_n(const llama_tokens & drafted) override { + const server_slot * slot = get_slot(); + if (slot->i_batch_dft.size() != 1 + drafted.size()) { + GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu", + __func__, slot->i_batch_dft.size(), 1 + drafted.size()); + } + const auto ids = common_sampler_sample_and_accept_n(slot->smpl.get(), ctx_impl.ctx, slot->i_batch_dft, drafted); + + return ids; + } + + bool memory_seq_rm(llama_pos p0, llama_pos p1) override { + return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot_id, p0, p1); + } + + size_t create_checkpoint() override { + server_slot * slot = get_slot(); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot_id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot_id); + const auto n_tokens_cur = 0; // TODO was ctx_impl.batch.n_tokens; The draft model doesn't change the prompt? 
+ const auto & cur_with_size = ctx_impl.get_checkpoint(*slot, n_tokens_cur, pos_min, pos_max); + auto & cur = cur_with_size.checkpoint; + + // save sampler state alongside KV checkpoint + slot->smpl_checkpoint.reset(common_sampler_clone(slot->smpl.get())); + + SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, + cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + return cur_with_size.size; + } + + restore_checkpoint_result restore_checkpoint(size_t ckpt_size_part_expected) override { + server_slot * slot = get_slot(); + auto & ckpt = slot->prompt.checkpoints.back(); + + SLT_DBG(*slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max); + const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx, + ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != ckpt_size_part_expected) { + GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", + __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n); + } + + slot->prompt.tokens.keep_first(ckpt.pos_max + 1); + + // restore sampler state from checkpoint + if (slot->smpl_checkpoint) { + slot->smpl.reset(common_sampler_clone(slot->smpl_checkpoint.get())); + } + + return { n, ckpt.pos_max }; + } + + void delete_checkpoint() override { + server_slot * slot = get_slot(); + slot->prompt.checkpoints.pop_back(); + slot->smpl_checkpoint.reset(); + } + + }; + + // load the model and initialize llama_context // this may also be called to resume from sleeping state bool load_model(const common_params & params) { @@ -645,6 +754,7 @@ struct server_context_impl { add_bos_token = llama_vocab_get_add_bos(vocab); if (params_base.speculative.has_dft()) { + // TODO speculative: move to common/speculative.cpp? 
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str()); const auto & params_spec = params_base.speculative; @@ -755,12 +865,16 @@ struct server_context_impl { const bool can_spec = common_speculative_is_compat(ctx); if (!can_spec) { - SRV_WRN("%s", "speculative decoding not supported by this context\n"); + SRV_WRN("%s", "speculative decoding not supported by this context without checkpoints\n"); } // initialize slots for (int i = 0; i < params_base.n_parallel; i++) { - server_slot slot; + // Create a new slot in the vector. + slots.emplace_back(); + + // Get a reference of the new slot. + server_slot & slot = slots.back(); slot.id = i; slot.ctx = ctx; @@ -770,17 +884,48 @@ struct server_context_impl { slot.prompt.tokens.has_mtmd = mctx != nullptr; // try speculative decoding - if (can_spec) { - slot.spec = common_speculative_init(params_base.speculative, slot.ctx); - if (slot.spec) { - if (mctx) { - SRV_ERR("%s\n", "speculative decoding is not supported with multimodal"); - return false; + if (can_spec || params_base.speculative.use_checkpoints) { + if (mctx) { + SRV_ERR("%s\n", "speculative decoding is not supported with multimodal"); + return false; + } + + auto spec_params = params_base.speculative; + + // disable checkpoints for standard KV cache models — they gain nothing + // from checkpoint save/restore (PARTIAL_ONLY flag is ignored, full KV is + // saved) and the non-checkpoint path (memory_seq_rm + rewind) is correct. + // checkpoints are only useful for models with state that seq_rm cannot + // roll back: recurrent state or hybrid memory. 
+ if (spec_params.use_checkpoints) { + const bool needs_checkpoints = llama_model_is_recurrent(model) || llama_model_is_hybrid(model); + if (!needs_checkpoints) { + spec_params.use_checkpoints = false; + SLT_INF(slot, "%s", "disabled spec checkpoints for standard KV cache model\n"); + } + + // quantized V cache (not f16/bf16) with speculative checkpoints on + // hybrid models causes non-deterministic output. checkpoint restore + // discards attention KV entries beyond the checkpoint position, forcing + // re-computation in a different batch size. flash attention produces + // slightly different floating-point results depending on batch size, + // and coarse V cache quantization (e.g. q4_0, q8_0) amplifies these + // differences across quantization boundaries, causing cascading drift. + // use f16/bf16 V cache or disable checkpoints to avoid this. + if (needs_checkpoints && llama_model_is_hybrid(model) && + params_base.cache_type_v != GGML_TYPE_F16 && params_base.cache_type_v != GGML_TYPE_BF16) { + SLT_WRN(slot, + "quantized V cache (%s) with speculative checkpoints on " + "hybrid model may produce non-deterministic output — " + "consider using f16/bf16 V cache for correctness\n", + ggml_type_name(params_base.cache_type_v)); } - SLT_INF(slot, "%s", "speculative decoding context initialized\n"); - } else { - SLT_INF(slot, "%s", "speculative decoding context not initialized\n"); } + + slot.spec_callback = std::make_unique(slot.id, *this); + slot.spec_session = + std::make_unique(*slot.spec_callback, spec_params, slot.ctx); + SLT_INF(slot, "%s", "speculative decoding context initialized\n"); } SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); @@ -790,8 +935,6 @@ struct server_context_impl { }; slot.reset(); - - slots.push_back(std::move(slot)); } { @@ -1168,7 +1311,7 @@ struct server_context_impl { backend_sampling &= task.params.sampling.backend_sampling; // TODO: speculative decoding requires multiple samples per batch - not supported yet - backend_sampling &= 
!(slot.spec && task.params.speculative.n_max > 0); + backend_sampling &= !(slot.spec_session && task.params.speculative.n_max > 0); // TODO: getting post/pre sampling logits is not yet supported with backend sampling backend_sampling &= !need_logits; @@ -1674,6 +1817,43 @@ struct server_context_impl { return true; } + struct server_prompt_checkpoint_with_size { + server_prompt_checkpoint checkpoint; + size_t size; + }; + + // Creates a checkpoint. + // + // n_tokens_cur: the number of tokens added to the batch for the current slot + server_prompt_checkpoint_with_size get_checkpoint(server_slot & slot, const int64_t n_tokens_cur, + llama_pos pos_min, llama_pos pos_max) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.prompt.checkpoints.front(); + + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + } + + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.n_tokens = */ slot.prompt.n_tokens() - n_tokens_cur, + /*.data = */ std::vector<uint8_t>(checkpoint_size), + }); + + const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != checkpoint_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); + } + + return server_prompt_checkpoint_with_size{ cur, checkpoint_size }; + } + void process_single_task(server_task && task) { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: @@ -2068,56 +2248,34 @@ struct server_context_impl { // 
generate draft tokens in speculative decoding mode // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] // perform the speculative drafting for all sequences at the same time in a single batch - const int n_draft_max = slot.get_n_draft_max(); - if (n_draft_max > 0) { - if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded - GGML_ABORT("not supported by multimodal"); - } - + llama_tokens draft; + const int n_draft_max_slot = slot.get_n_draft_max(); + if (n_draft_max_slot > 0) { const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - - const auto & params_spec = slot.task->params.speculative; - - llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); - - if (draft.size() > (size_t) n_draft_max) { - SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max); - draft.resize(n_draft_max); + // compute draft and add draft to internal batch + draft = slot.spec_session->compute_draft(cached_text_tokens, slot.sampled, n_draft_max_slot); + slot.n_spec_cycles++; + if (draft.empty()) { + slot.n_spec_empty++; } - - // add the sampled token to the batch - slot.i_batch_dft.push_back(batch.n_tokens); - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - slot.prompt.tokens.push_back(slot.sampled); - - if (slot.task->params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); - // fallback to normal decoding - slot.i_batch = slot.i_batch_dft[0]; - slot.drafted.clear(); - slot.i_batch_dft.clear(); - } else { - // keep track of total number of drafted tokens tested - slot.n_draft_total += draft.size(); - - // add all drafted tokens to the batch - for (size_t i = 0; i < draft.size(); i++) { - slot.i_batch_dft.push_back(batch.n_tokens); - 
common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true); - slot.prompt.tokens.push_back(draft[i]); - } - slot.drafted = std::move(draft); + if (draft.size() > 0) { + SLT_DBG(slot, "compute_draft: id=%d, #cached_text_tokens=%zu, #tokens=%zu, #i_batch_dft=%zu\n", + slot.sampled, + cached_text_tokens.size(), draft.size(), slot.i_batch_dft.size()); } - } else { + } + + if (draft.empty()) { // no speculative decoding slot.i_batch = batch.n_tokens; + slot.i_batch_dft.clear(); common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); slot.prompt.tokens.push_back(slot.sampled); - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", + SLT_DBG(slot, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n", + slot.sampled, slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); } } @@ -2613,6 +2771,7 @@ struct server_context_impl { // do not checkpoint after mtmd chunks do_checkpoint = do_checkpoint && !has_mtmd; + SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max); // no need to create checkpoints that are too close together do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64); @@ -2620,31 +2779,8 @@ struct server_context_impl { // note: we create the checkpoint before calling llama_decode(), so the current batch is not // yet processed and therefore it is not part of the checkpoint. 
if (do_checkpoint) { - while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.prompt.checkpoints.front(); - - SLT_WRN(slot, - "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 - ", size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); - - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); - } - - const size_t checkpoint_size = - llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.n_tokens = */ slot.prompt.n_tokens() - n_tokens_cur, - /*.data = */ std::vector(checkpoint_size), - }); - - llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, - LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - + auto cur_with_size = get_checkpoint(slot, n_tokens_cur, pos_min, pos_max); + auto & cur = cur_with_size.checkpoint; SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", @@ -2710,6 +2846,18 @@ struct server_context_impl { metrics.on_decoded(slots); if (ret != 0) { + // clear speculative draft state for all slots - the draft tokens in + // this batch will not be decoded, so sampling them would crash + for (auto & slot : slots) { + if (!slot.i_batch_dft.empty()) { + SLT_WRN(slot, "clearing speculative draft state due to decode failure (ret = %d)\n", ret); + slot.i_batch_dft.clear(); + if (slot.spec_session) { + slot.spec_session->reset(); + } + } + } + { std::string err; @@ -2821,7 +2969,7 @@ struct server_context_impl { slot.state = SLOT_STATE_GENERATING; if (slot.can_speculate()) { - common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens()); + slot.spec_session->begin(slot.prompt.tokens.get_text_tokens()); } } else if 
(slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots @@ -2878,24 +3026,25 @@ struct server_context_impl { continue; } - const size_t n_draft = slot.drafted.size(); - - // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted); + auto accept_response = slot.spec_session->sample_and_accept(); slot.i_batch_dft.clear(); - slot.drafted.clear(); + const size_t n_draft = accept_response.draft_size_initial; + if (accept_response.skip_acceptance) { + slot.n_spec_skip++; + SLT_DBG(slot, "partial acceptance: n_tokens=%zu, n_draft=%zu\n", accept_response.tokens.size(), n_draft); + continue; + } + const auto ids = accept_response.tokens; + const int64_t t_current = ggml_time_us(); slot.n_decoded += ids.size(); - slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - - // inform the speculative decoding about the number of accepted tokens - common_speculative_accept(slot.spec, ids.size() - 1); + slot.n_draft_total += n_draft; // rollback to the state before sampling the draft tokens slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); @@ -2903,8 +3052,9 @@ struct server_context_impl { // add accepted tokens to the prompt slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); slot.sampled = ids.back(); // last accepted token + SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); + slot.spec_session->rewind(slot.prompt.n_tokens()); for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 7d543b9292b..d11f262b42f 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp 
@@ -624,9 +624,12 @@ json result_timings::to_json() const { {"predicted_per_second", predicted_per_second}, }; - if (draft_n > 0) { + if (draft_n > 0 || spec_cycles > 0) { base["draft_n"] = draft_n; base["draft_n_accepted"] = draft_n_accepted; + base["spec_cycles"] = spec_cycles; + base["spec_empty"] = spec_empty; + base["spec_skip"] = spec_skip; } return base; diff --git a/tools/server/server-task.h b/tools/server/server-task.h index a49ddb594b9..5f07e971811 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -274,6 +274,11 @@ struct result_timings { int32_t draft_n = 0; int32_t draft_n_accepted = 0; + // Diagnostic spec counters + int32_t spec_cycles = 0; + int32_t spec_empty = 0; + int32_t spec_skip = 0; + json to_json() const; }; diff --git a/tools/server/tests/unit/test_spec_checkpoint_ci.py b/tools/server/tests/unit/test_spec_checkpoint_ci.py new file mode 100644 index 00000000000..8b6d072ce28 --- /dev/null +++ b/tools/server/tests/unit/test_spec_checkpoint_ci.py @@ -0,0 +1,56 @@ +""" +CI-compatible speculative checkpoint tests using stories15M models. + +These tests exercise the checkpoint code path on a pure transformer model. +They validate infrastructure correctness but cannot trigger hybrid-model-specific +bugs (KV leak, sampler state) because stories15M is not a hybrid model and +the standard KV guard disables checkpoints for it. + +Note: With the standard KV guard (Fix 4), checkpoints are silently disabled +for stories15M. These tests verify that the non-checkpoint fallback path +(memory_seq_rm + rewind) remains correct with checkpoint CLI flags present. +""" + +import pytest +from utils import ServerPreset + + +@pytest.fixture(scope="module", autouse=True) +def do_something(): + """Override conftest's load_all — we download models on demand.""" + pass + + +def test_spec_decoding_with_checkpoint_flags(): + """Spec decoding produces deterministic output with checkpoint flags on stories15M. 
+ + The standard KV guard disables checkpoints for this pure transformer model, + so this tests the fallback path (memory_seq_rm + rewind) with checkpoint + CLI flags present. 10 identical requests must produce identical output. + """ + server = ServerPreset.stories15m_moe() + server.offline = False + server.draft_min = 4 + server.draft_max = 8 + server.extra_args = [ + "--spec-type", "ngram-mod", + "--spec-use-checkpoints", "on", + "--ctx-checkpoints", "4", + ] + server.start() + + outputs = [] + for i in range(10): + res = server.make_request("POST", "/completion", data={ + "prompt": "Once upon a time there was a little girl", + "temperature": 0.0, + "top_k": 1, + "n_predict": 128, + }) + assert res.status_code == 200, f"Request {i+1} failed: {res.status_code}" + outputs.append(res.body["content"]) + + assert len(set(outputs)) == 1, ( + f"Output divergence: {len(set(outputs))} unique outputs across " + f"10 identical requests (stories15M, checkpoint flags)." + ) diff --git a/tools/server/tests/unit/test_spec_checkpoint_crash.py b/tools/server/tests/unit/test_spec_checkpoint_crash.py new file mode 100644 index 00000000000..128164a0c3c --- /dev/null +++ b/tools/server/tests/unit/test_spec_checkpoint_crash.py @@ -0,0 +1,180 @@ +""" +Regression test for server crash on KV cache exhaustion during speculative decode. + +Bug: When decode fails during speculative drafting (e.g. KV cache exhaustion +with concurrent requests), the server does not clear draft state before calling +sample_and_accept(). This leads to GGML_ASSERT(logits != nullptr) because the +failed batch has no logits for the drafted token positions. + +Requires Qwen3.5-9B + 0.8B models. 
+""" + +import os +import time +import pytest +import requests as req_lib +from concurrent.futures import ThreadPoolExecutor, as_completed +from utils import ServerProcess + +QWEN35_9B = os.environ.get( + "QWEN35_9B_MODEL", + os.path.expanduser("~/Models/Qwen3.5-9B-Q4_K_M.gguf"), +) +QWEN35_08B = os.environ.get( + "QWEN35_08B_MODEL", + os.path.expanduser("~/Models/Qwen3.5-0.8B-BF16.gguf"), +) + +requires_qwen35 = pytest.mark.skipif( + not os.path.exists(QWEN35_9B) or not os.path.exists(QWEN35_08B), + reason="Requires local Qwen3.5-9B and 0.8B models", +) + + +@pytest.fixture(scope="module", autouse=True) +def do_something(): + """Override conftest's load_all — we use local models, not HF presets.""" + pass + + +def _make_chat_request(base_url, messages, max_tokens=400): + """Make a chat request, returning (status, body) or (None, error_str).""" + try: + resp = req_lib.post( + f"{base_url}/v1/chat/completions", + json={ + "model": "test", + "messages": messages, + "max_tokens": max_tokens, + "temperature": 0, + }, + timeout=300, + ) + return resp.status_code, resp.json() + except Exception as e: + return None, str(e) + + +@requires_qwen35 +def test_server_survives_kv_exhaustion_with_ngram(): + """Server must not crash when KV cache fills during ngram speculation. + + Sends concurrent multi-turn requests to fill KV cache, triggering decode + failure during speculative drafting. Without the fix, this crashes with + GGML_ASSERT(logits != nullptr) in sampling.cpp. 
+ """ + server = ServerProcess() + server.model_hf_repo = None + server.model_hf_file = None + server.model_file = QWEN35_9B + server.n_ctx = 4096 + server.n_slots = 4 + server.n_gpu_layer = 99 + server.seed = 3407 + server.temperature = 0.0 + server.jinja = True + server.draft_max = 48 + server.extra_args = [ + "--spec-type", "ngram-mod", + "--spec-use-checkpoints", "on", + "--ctx-checkpoints", "4", + "--no-warmup", + "--chat-template-kwargs", '{"enable_thinking": false}', + ] + server.fa = None + server.start(timeout_seconds=120) + base_url = f"http://{server.server_host}:{server.server_port}" + + # phase 1: build ngram data + for i in range(3): + _make_chat_request( + base_url, + [{"role": "user", "content": f"Write quicksort in Python variant {i}. Full code."}], + max_tokens=500, + ) + if server.process.poll() is not None: + break + + # phase 2: parallel multi-turn to fill KV cache + if server.process.poll() is None: + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [] + for i in range(4): + futures.append(executor.submit( + _make_chat_request, + base_url, + [ + {"role": "user", "content": f"Write quicksort variant {i}"}, + {"role": "assistant", "content": "def quicksort(a,l=0,h=None):\n if h is None: h=len(a)-1\n if l None: server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds]) if self.webui_mcp_proxy: server_args.append("--webui-mcp-proxy") + if self.extra_args: + server_args.extend(self.extra_args) args = [str(arg) for arg in [server_path, *server_args]] print(f"tests: starting server with: {' '.join(args)}")