Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 69 additions & 8 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2444,12 +2444,64 @@ struct server_context_impl {
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);

if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
SLT_WRN(slot, "failed to truncate tokens with position >= %d, searching for a checkpoint to restore\n", p0);

// for hybrid/recurrent models, seq_rm fails because the recurrent
// memory has no cell-level checkpoint at position p0.
// try to restore the nearest server-level checkpoint before p0.
bool restored = false;

if (!slot.prompt.checkpoints.empty()) {
const auto it = std::find_if(
slot.prompt.checkpoints.rbegin(),
slot.prompt.checkpoints.rend(),
[&](const auto & cur) {
return cur.pos_max < p0;
}
);

if (it != slot.prompt.checkpoints.rend()) {
const size_t checkpoint_size = it->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

if (n == checkpoint_size) {
const llama_pos pos_restored = std::max(it->pos_min + 1, it->pos_max);
size_t n_past_new = std::min(slot.prompt.tokens.size_up_to_pos(pos_restored), (size_t) it->n_tokens);

// [TAG_PROMPT_LOGITS] guarantee at least 1 token for evaluation
if (n_past_new == (size_t) slot.task->n_tokens() && n_past_new > 0) {
n_past_new--;
}

SLT_WRN(slot, "restored checkpoint at seq_rm failure (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB), n_past: %d -> %zu\n",
it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024,
slot.prompt.n_tokens(), n_past_new);

slot.prompt.tokens.keep_first(n_past_new);
slot.n_prompt_tokens_cache = n_past_new;
restored = true;

// erase checkpoints at or past the truncation point (pos_max >= p0); earlier checkpoints are kept
for (auto cit = slot.prompt.checkpoints.begin(); cit != slot.prompt.checkpoints.end();) {
if (cit->pos_max >= p0) {
cit = slot.prompt.checkpoints.erase(cit);
} else {
++cit;
}
}
} else {
SLT_ERR(slot, "failed to restore checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
}
}
}

slot.prompt_clear(true);
if (!restored) {
SLT_WRN(slot, "no suitable checkpoint found for p0 = %d, clearing the memory\n", p0);

// there is no common part left
slot.n_prompt_tokens_cache = 0;
slot.prompt_clear(true);
slot.n_prompt_tokens_cache = 0;
}
}

bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
Expand Down Expand Up @@ -2599,11 +2651,20 @@ struct server_context_impl {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);

// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
{
// for hybrid/recurrent models, even short prompts need checkpoints
// to enable multi-turn cache reuse (recurrent state can't be rolled
// back without one). SWA-only models can use a higher threshold.
const bool is_hybrid_or_recurrent = llama_model_is_hybrid(model) || llama_model_is_recurrent(model);
const int min_pos_max = is_hybrid_or_recurrent ? 0 : 64;
const int min_spacing = is_hybrid_or_recurrent ? 0 : 64;

// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= min_pos_max);

// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + min_spacing);
}

// note: we create the checkpoint before calling llama_decode(), so the current batch is not
// yet processed and therefore it is not part of the checkpoint.
Expand Down