diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 3750c8fdb60..b548910fa8b 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1669,6 +1669,8 @@ struct server_slot {
     std::string stopping_word;
 
     // state
+    int n_prompt_checkpoints_made          = 0; // number of checkpoints created during prompt processing of the current task
+    int last_prompt_checkpoint_token_count = 0; // number of processed prompt tokens at the last prompt checkpoint
     slot_state state = SLOT_STATE_IDLE;
 
     server_prompt prompt;
@@ -1743,6 +1744,8 @@ struct server_slot {
         generated_tokens.clear();
         generated_token_probs.clear();
         chat_msg = {};
+        n_prompt_checkpoints_made          = 0;
+        last_prompt_checkpoint_token_count = 0;
         json_schema = json();
         generated_tool_call_ids.clear();
 
@@ -3413,7 +3417,6 @@ struct server_context {
                     size_t token_count = 0;
                     size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
                     if (nread == 0) {
-                        slot->prompt.tokens.clear(); // KV may already been invalidated?
                         send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
                         break;
                     }
@@ -3924,6 +3927,7 @@ struct server_context {
                     }
 
                     bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+                    int n_prompt_checkpoints_target = params_base.n_ctx_checkpoints / 2; // reserve half of the budget for prompt-processing checkpoints
 
                     // make checkpoints only for completion tasks
                     do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
@@ -3940,6 +3944,47 @@ struct server_context {
                         (llama_model_n_swa(model) > 0 && !params_base.swa_full)
                     );
 
+                    // create checkpoints during prompt processing as well, spaced evenly across the non-cached part of the prompt
+                    if (do_checkpoint && slot.state == SLOT_STATE_PROCESSING_PROMPT && n_prompt_checkpoints_target > 0) {
+                        const int prompt_len_to_process = slot.task->n_tokens() - slot.n_prompt_tokens_cache;
+                        if (prompt_len_to_process > 0) {
+                            const int checkpoint_interval = std::max(64, prompt_len_to_process / n_prompt_checkpoints_target);
+                            const int current_processed_tokens = slot.prompt.n_tokens() - slot.n_prompt_tokens_cache;
+
+                            // create a checkpoint once enough new tokens have been processed since the last one, within the prompt checkpoint budget
+                            if (current_processed_tokens > 0 &&
+                                current_processed_tokens - slot.last_prompt_checkpoint_token_count >= checkpoint_interval &&
+                                slot.n_prompt_checkpoints_made < n_prompt_checkpoints_target) {
+                                const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+                                const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+                                // no need for empty or small checkpoints, or checkpoints that are too close to the previous one
+                                bool can_create_new_checkpoint = (pos_min >= 0 && pos_max >= 64);
+                                if (can_create_new_checkpoint && !slot.prompt.checkpoints.empty()) {
+                                    can_create_new_checkpoint = (pos_max > slot.prompt.checkpoints.back().pos_max + 64);
+                                }
+
+                                if (can_create_new_checkpoint) {
+                                    // make room for the new checkpoint, if needed
+                                    while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                                        slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
+                                    }
+
+                                    const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                                    auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{pos_min, pos_max, std::vector<uint8_t>(checkpoint_size)});
+                                    llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                                    slot.n_prompt_checkpoints_made++;
+                                    slot.last_prompt_checkpoint_token_count = current_processed_tokens;
+
+                                    SLT_WRN(slot, "created prompt processing context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                        (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, pos_min, pos_max, (float) checkpoint_size / 1024 / 1024);
+                                }
+                            }
+                        }
+                    }
+
                     // add prompt tokens for processing in the current batch
                     while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
                         // get next token to process
@@ -3965,11 +4010,6 @@ struct server_context {
                         slot.prompt.tokens.push_back(cur_tok);
 
                         slot.n_prompt_tokens_processed++;
-
-                        // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
-                        if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
-                            break;
-                        }
                     }
 
                     // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
@@ -4004,29 +4044,22 @@ struct server_context {
                     const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
 
                     // no need for empty or small checkpoints
-                    do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+                    bool can_create_final_prompt_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
 
                     // no need to create checkpoints that are too close together
-                    do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
+                    if (can_create_final_prompt_checkpoint && !slot.prompt.checkpoints.empty()) {
+                        can_create_final_prompt_checkpoint = (pos_max > slot.prompt.checkpoints.back().pos_max + 64);
+                    }
 
-                    if (do_checkpoint) {
+                    if (can_create_final_prompt_checkpoint) {
                         while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
-                            // make room for the new checkpoint, if needed
-                            const auto & cur = slot.prompt.checkpoints.front();
-
-                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                             slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
                         }
 
                         const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                        auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
-                            /*.pos_min = */ pos_min,
-                            /*.pos_max = */ pos_max,
-                            /*.data    = */ std::vector<uint8_t>(checkpoint_size),
-                        });
+                        auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{pos_min, pos_max, std::vector<uint8_t>(checkpoint_size)});
                         llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
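
// The following standalone sketch (not part of the patch) illustrates how the
// prompt-checkpoint spacing rule above behaves: half of the checkpoint budget
// is reserved for prompt processing, and a checkpoint fires on the first batch
// boundary after each interval. The prompt length, batch size and budget below
// are hypothetical values; in the server they come from slot/task state and
// params_base.n_ctx_checkpoints.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx_checkpoints = 8;     // total checkpoint budget for the slot (assumed value)
    const int prompt_len        = 10000; // non-cached prompt tokens to process (assumed value)
    const int n_batch           = 512;   // tokens submitted per update_slots() iteration (assumed value)

    // half of the budget is reserved for checkpoints made during prompt processing
    const int n_target = n_ctx_checkpoints / 2;
    const int interval = std::max(64, prompt_len / n_target);

    int n_made = 0;
    int last   = 0; // tokens processed when the last checkpoint was created

    // the checkpoint condition is evaluated once per batch, so checkpoints
    // land on the first batch boundary past each interval
    for (int processed = n_batch; processed < prompt_len; processed += n_batch) {
        if (processed - last >= interval && n_made < n_target) {
            n_made++;
            last = processed;
            printf("checkpoint %d of %d after %d/%d prompt tokens\n", n_made, n_target, processed, prompt_len);
        }
    }

    return 0;
}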