Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 92 additions & 72 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ struct server_slot {
int32_t n_remaining = -1;
int32_t i_batch = -1;

std::vector<int32_t> i_batch_dft; // idx of draft tokens in the main batch

int32_t n_prompt_tokens_cache = 0;
int32_t n_prompt_tokens_processed = 0;

Expand Down Expand Up @@ -149,7 +151,8 @@ struct server_slot {

struct common_sampler * smpl = nullptr;

llama_token sampled;
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;

// stats
size_t n_sent_text = 0; // number of sent text character
Expand Down Expand Up @@ -179,6 +182,8 @@ struct server_slot {
stopping_word = "";
n_sent_text = 0;

drafted.clear();
i_batch_dft.clear();
generated_tokens.clear();
generated_token_probs.clear();
json_schema = json();
Expand Down Expand Up @@ -254,6 +259,31 @@ struct server_slot {
generated_token_probs.push_back(token);
}

// Compute how many draft tokens this slot may request for the next
// speculative step.
//
// Returns 0 when can_speculate() is false, or when the computed budget is
// smaller than the task's configured speculative.n_min. Otherwise returns the
// configured speculative.n_max capped by the free context space and, when
// n_remaining is positive, by n_remaining - 1.
int get_n_draft_max() const {
    if (!can_speculate()) {
        return 0;
    }

    const auto & spec_params = task->params.speculative;

    // begin with the configured ceiling and shrink it to what the slot can hold
    int budget = spec_params.n_max;

    // the most recently sampled token has not been appended to `prompt` yet,
    // and one further position is kept free so a context shift remains possible
    const int ctx_room = n_ctx - prompt.n_tokens() - 2;
    budget = std::min(budget, ctx_room);

    // a positive n_remaining bounds the draft as well
    // NOTE(review): the "- 1" presumably accounts for the token sampled alongside the draft - confirm
    if (n_remaining > 0) {
        const int remaining_room = n_remaining - 1;
        budget = std::min(budget, remaining_room);
    }

    SLT_DBG(*this, "max possible draft: %d\n", budget);

    // drafts shorter than the configured minimum are not attempted
    const bool too_small = budget < spec_params.n_min;
    if (too_small) {
        SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", budget, spec_params.n_min);
        return 0;
    }

    return budget;
}

void release() {
if (is_processing()) {
GGML_ASSERT(task);
Expand Down Expand Up @@ -1745,14 +1775,54 @@ struct server_context_impl {
continue;
}

slot.i_batch = batch.n_tokens;
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
int n_draft_max = slot.get_n_draft_max();
if (n_draft_max > 0) {
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}

common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
params_spec.p_min = slot.task->params.speculative.p_min;
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);

if (slot.task->params.speculative.n_min > (int) draft.size()) {
// ignore small drafts
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);

} else {
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);

// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();

// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(draft[i]);
}
slot.drafted = std::move(draft);
}
} else {
// no speculative decoding
slot.i_batch = batch.n_tokens;

common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);

slot.prompt.tokens.push_back(slot.sampled);
slot.prompt.tokens.push_back(slot.sampled);

SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
}
}

// process in chunks of params.n_batch
Expand Down Expand Up @@ -2341,6 +2411,10 @@ struct server_context_impl {
continue; // continue loop of slots
}

if (slot.i_batch_dft.size() > 0) {
continue; // sample using speculative decoding
}

const int tok_idx = slot.i_batch - i;

llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
Expand Down Expand Up @@ -2381,84 +2455,30 @@ struct server_context_impl {
}
}

// do speculative decoding
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
// speculative decoding - main model sample and accept
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.can_speculate()) {
if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
continue;
}

if (slot.state != SLOT_STATE_GENERATING) {
continue;
}

if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}

// determine the max draft that fits the current slot state
int n_draft_max = slot.task->params.speculative.n_max;

// note: slot.prompt is not yet expanded with the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);

if (slot.n_remaining > 0) {
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
}

SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);

if (n_draft_max < slot.task->params.speculative.n_min) {
SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);

continue;
}

llama_token id = slot.sampled;

struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
params_spec.p_min = slot.task->params.speculative.p_min;

const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);

// ignore small drafts
if (slot.task->params.speculative.n_min > (int) draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);

continue;
}

// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();

// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);

for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
}

SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);

llama_decode(ctx, slot.batch_spec);
size_t n_draft = slot.drafted.size();

// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
slot.i_batch_dft.clear();
slot.drafted.clear();

slot.n_decoded += ids.size();

// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;

slot.prompt.tokens.push_back(id);
// rollback to the state before sampling the draft tokens
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);

// add accepted tokens to the prompt
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
slot.sampled = ids.back(); // last accepted token

llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);

Expand All @@ -2481,7 +2501,7 @@ struct server_context_impl {
}
}

SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) slot.drafted.size(), slot.prompt.n_tokens());
}
}

Expand Down
Loading