68 changes: 48 additions & 20 deletions tools/server/server.cpp
@@ -1669,6 +1669,7 @@ struct server_slot {
std::string stopping_word;

// state
int n_prompt_checkpoints_made = 0; // number of context checkpoints created during prompt processing
slot_state state = SLOT_STATE_IDLE;

server_prompt prompt;
@@ -1743,6 +1744,7 @@ struct server_slot {
generated_tokens.clear();
generated_token_probs.clear();
chat_msg = {};
n_prompt_checkpoints_made = 0;
last_prompt_checkpoint_token_count = 0;
json_schema = json();
generated_tool_call_ids.clear();

@@ -1755,6 +1757,8 @@

// clear alora start
alora_invocation_start = -1;
}

bool need_embd() const {
@@ -1963,6 +1967,7 @@ struct server_slot {

return res;
}
int last_prompt_checkpoint_token_count = 0; // number of prompt tokens that had been processed when the last checkpoint was created
};

struct server_metrics {
@@ -3413,7 +3418,6 @@ struct server_context {
size_t token_count = 0;
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
if (nread == 0) {
slot->prompt.tokens.clear(); // KV may already have been invalidated?
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
break;
}
@@ -3924,6 +3928,7 @@ struct server_context {
}

bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
// reserve half of the total checkpoint budget for prompt-processing checkpoints
const int n_prompt_checkpoints_target = params_base.n_ctx_checkpoints / 2;

// make checkpoints only for completion tasks
do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
@@ -3940,6 +3945,42 @@
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
);

// create context checkpoints during prompt processing
if (do_checkpoint && slot.state == SLOT_STATE_PROCESSING_PROMPT && n_prompt_checkpoints_target > 0) {
    const int prompt_len_to_process = slot.task->n_tokens() - slot.n_prompt_tokens_cache;
    if (prompt_len_to_process > 0) {
        const int checkpoint_interval = std::max(64, prompt_len_to_process / n_prompt_checkpoints_target);
        const int current_processed_tokens = slot.prompt.n_tokens() - slot.n_prompt_tokens_cache;

        // create a checkpoint only if enough new tokens have been processed and the prompt-checkpoint target has not been reached
        if (current_processed_tokens > 0 &&
            current_processed_tokens - slot.last_prompt_checkpoint_token_count >= checkpoint_interval &&
            slot.n_prompt_checkpoints_made < n_prompt_checkpoints_target) {
            const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
            const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);

            // no need for empty or small checkpoints
            bool can_create_new_checkpoint = (pos_min >= 0 && pos_max >= 64);

            // no need to create checkpoints that are too close together
            if (can_create_new_checkpoint && !slot.prompt.checkpoints.empty()) {
                can_create_new_checkpoint = (pos_max > slot.prompt.checkpoints.back().pos_max + 64);
            }

            if (can_create_new_checkpoint) {
                // make room for the new checkpoint, if needed
                while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
                    slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
                }

                const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

                auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{pos_min, pos_max, std::vector<uint8_t>(checkpoint_size)});

                llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

                slot.n_prompt_checkpoints_made++;
                slot.last_prompt_checkpoint_token_count = current_processed_tokens;

                SLT_WRN(slot, "created prompt processing context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                        slot.n_prompt_checkpoints_made, n_prompt_checkpoints_target, pos_min, pos_max, (float) checkpoint_size / 1024 / 1024);
            }
        }
    }
}

// add prompt tokens for processing in the current batch
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
// get next token to process
@@ -3965,11 +4006,6 @@
slot.prompt.tokens.push_back(cur_tok);

slot.n_prompt_tokens_processed++;

// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
break;
}
}

// SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.prompt.tokens.str().c_str());
Expand Down Expand Up @@ -4004,29 +4040,21 @@ struct server_context {
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);

// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
bool can_create_final_prompt_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);

// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
if (can_create_final_prompt_checkpoint && !slot.prompt.checkpoints.empty()) {
can_create_final_prompt_checkpoint = (pos_max > slot.prompt.checkpoints.back().pos_max + 64);
}

if (do_checkpoint) {
if (can_create_final_prompt_checkpoint) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();

SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);

slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}

const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{pos_min, pos_max, std::vector<uint8_t>(checkpoint_size)});

llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

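The scheduling logic in the diff is easiest to follow in isolation. Below is a minimal standalone sketch of the same policy under stated assumptions: the `scheduler` struct, the `maybe_checkpoint()` helper, and the simulated batch and prompt sizes are hypothetical simplifications for illustration, not the server's actual API; the real change operates on `server_slot` state and snapshots KV-cache data via `llama_state_seq_get_size_ext` / `llama_state_seq_get_data_ext`.

```cpp
// sketch_checkpoints.cpp - standalone sketch of the prompt-checkpoint
// scheduling policy (hypothetical simplified names, not the server API)
#include <algorithm>
#include <cstdio>
#include <deque>

struct checkpoint {
    int pos_min = 0;
    int pos_max = 0;
};

struct scheduler {
    int n_ctx_checkpoints = 0; // total checkpoint budget (--ctx-checkpoints)
    int target            = 0; // prompt-phase share: n_ctx_checkpoints / 2
    int made              = 0; // prompt checkpoints created so far
    int last_ckpt_tokens  = 0; // tokens processed when the last checkpoint was made

    std::deque<checkpoint> checkpoints;

    // decide whether to snapshot after `processed` of `total` new prompt tokens
    bool maybe_checkpoint(int processed, int total, int pos_min, int pos_max) {
        if (target <= 0 || total <= 0 || processed <= 0) {
            return false;
        }

        // spread the prompt-phase budget evenly, but at least 64 tokens apart
        const int interval = std::max(64, total / target);

        if (processed - last_ckpt_tokens < interval || made >= target) {
            return false; // not enough new tokens yet, or budget exhausted
        }
        if (pos_min < 0 || pos_max < 64) {
            return false; // empty or tiny checkpoints are not worth keeping
        }
        if (!checkpoints.empty() && pos_max <= checkpoints.back().pos_max + 64) {
            return false; // too close to the previous checkpoint
        }

        // evict the oldest checkpoint to stay within the total budget
        while ((int) checkpoints.size() >= n_ctx_checkpoints) {
            checkpoints.pop_front();
        }

        checkpoints.push_back({pos_min, pos_max});
        made++;
        last_ckpt_tokens = processed; // remember where this checkpoint was taken
        return true;
    }
};

int main() {
    scheduler s;
    s.n_ctx_checkpoints = 8;
    s.target            = s.n_ctx_checkpoints / 2;

    // simulate a 10000-token prompt processed in 512-token batches
    for (int processed = 512; processed <= 10000; processed += 512) {
        if (s.maybe_checkpoint(processed, 10000, 0, processed)) {
            std::printf("checkpoint %d after %d tokens\n", s.made, processed);
        }
    }
    return 0;
}
```

With these numbers the interval is `max(64, 10000 / 4) = 2500`, so the sketch prints checkpoints after 2560, 5120, and 7680 tokens, roughly `total / target` apart, which mirrors how the hunk spreads its half of `n_ctx_checkpoints` across prompt processing while the eviction loop keeps the overall budget bounded.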