88 changes: 51 additions & 37 deletions common/sampling.cpp
@@ -104,10 +104,9 @@ struct ring_buffer {
struct common_sampler {
common_params_sampling params;

struct llama_sampler * grmr;
struct llama_sampler * chain;

bool grammar;

ring_buffer<llama_token> prev;

std::vector<llama_token_data> cur;
@@ -167,15 +166,14 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co

lparams.no_perf = params.no_perf;

llama_sampler * grmr = nullptr;
llama_sampler * chain = llama_sampler_chain_init(lparams);

bool grammar = false;
std::vector<llama_sampler *> samplers;

if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
grammar = true;
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
Expand Down Expand Up @@ -224,15 +222,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co

if (!params.grammar.empty()) {
if (params.grammar_lazy) {
samplers.push_back(
llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size()));
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size());
} else {
samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}

grammar = true;
}
}
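
Condensed, the initialization now reduces to picking at most one grammar constructor. A sketch, assuming the surrounding variables (`vocab`, `params`, `trigger_patterns_c`, `trigger_tokens`) are prepared as earlier in this file:

```cpp
// sketch of the grammar setup after this change; grmr == nullptr
// from here on means "no grammar configured"
llama_sampler * grmr = nullptr;

if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
    grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
    GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif
} else if (!params.grammar.empty()) {
    grmr = params.grammar_lazy
        ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
              trigger_patterns_c.data(), trigger_patterns_c.size(),
              trigger_tokens.data(),    trigger_tokens.size())
        : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}
```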

@@ -303,8 +298,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co

auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ chain,
/* .grammar = */ grammar,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
@@ -315,6 +310,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co

void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);

delete gsmpl;
@@ -324,25 +320,12 @@ void common_sampler_free(struct common_sampler * gsmpl) {
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
const auto tm = gsmpl->tm();

if (gsmpl->grammar) {
const int n_smpl = llama_sampler_chain_n(gsmpl->chain);

for (int i = 0; i < n_smpl; i++) {
auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);

// the grammar sampler is always the first one
if (i == 0) {
if (accept_grammar) {
llama_sampler_accept(smpl, token);
}
} else {
llama_sampler_accept(smpl, token);
}
}
} else {
llama_sampler_accept(gsmpl->chain, token);
if (gsmpl->grmr && accept_grammar) {
llama_sampler_accept(gsmpl->grmr, token);
}

llama_sampler_accept(gsmpl->chain, token);

gsmpl->prev.push_back(token);
}
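
The accept path no longer walks the chain looking for the grammar in slot 0: the grammar advances only when `accept_grammar` is set, and the chain always advances. A usage sketch (`id` is the token just chosen):

```cpp
// typical call after sampling: advance both grammar and chain state
common_sampler_accept(gsmpl, id, /* accept_grammar = */ true);

// a token can also be recorded for repetition penalties without advancing
// the grammar, e.g. when it may still be rolled back
common_sampler_accept(gsmpl, id, /* accept_grammar = */ false);
```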

@@ -353,8 +336,8 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .grammar = */ gsmpl->grammar,
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
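
Cloning copies `grmr` and `chain` independently, which is what per-sequence speculative drafting needs: each branch gets its own sampler state. A usage sketch:

```cpp
// give a draft branch its own, independently evolving sampler state
common_sampler * smpl_branch = common_sampler_clone(smpl);
// ... sample/accept along the branch ...
common_sampler_free(smpl_branch);  // frees both the cloned grmr and chain
```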
@@ -410,19 +393,50 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
return gsmpl->chain;
}

llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_synchronize(ctx);

// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
const auto tm = gsmpl->tm();

llama_token id = LLAMA_TOKEN_NULL;

auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits

gsmpl->set_logits(ctx, idx);

if (grammar_first) {
llama_sampler_apply(grmr, &cur_p);
}

llama_sampler_apply(chain, &cur_p);

id = cur_p.data[cur_p.selected].id;

if (grammar_first) {
return id;
}

// check if the sampled token fits the grammar (grammar-based rejection sampling)
{
llama_token_data single_token_data = { id, 1.0f, 0.0f };
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

llama_sampler_apply(grmr, &single_token_data_array);

const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
if (is_valid) {
return id;
}
}

// resampling:
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
gsmpl->set_logits(ctx, idx);

llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p);

GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
@@ -432,15 +446,15 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return id;
}
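
The default path is grammar-based rejection sampling: sample unconstrained, probe the winner against the grammar, and only on rejection redo the pass with the grammar applied first. A self-contained sketch of that control flow against the plain `llama.h` sampler API (the `refill` callback stands in for `gsmpl->set_logits`, which re-initializes the candidates from the raw logits):

```cpp
#include <cmath>     // INFINITY
#include "llama.h"

// sketch of grammar-based rejection sampling; grmr may be nullptr (no grammar),
// in which case llama_sampler_apply is a no-op and the fast path always wins
template <typename Refill>
llama_token sample_with_grammar(llama_sampler * grmr, llama_sampler * chain,
                                llama_token_data_array & cur_p, Refill refill) {
    refill(cur_p);                                   // candidates from raw logits
    llama_sampler_apply(chain, &cur_p);              // fast path: no grammar applied
    const llama_token id = cur_p.data[cur_p.selected].id;

    // cheap probe: does the single sampled token survive the grammar?
    llama_token_data       single     = { id, 1.0f, 0.0f };
    llama_token_data_array single_arr = { &single, 1, -1, false };
    llama_sampler_apply(grmr, &single_arr);
    if (single_arr.data[0].logit != -INFINITY) {
        return id;                                   // grammar-valid, done
    }

    // slow path: start over, grammar first, then the sampling chain
    refill(cur_p);
    llama_sampler_apply(grmr,  &cur_p);
    llama_sampler_apply(chain, &cur_p);
    return cur_p.data[cur_p.selected].id;
}
```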

std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

std::vector<llama_token> result;
result.reserve(idxs.size());

size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

common_sampler_accept(gsmpl, id, true);

@@ -452,7 +466,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
}

if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

common_sampler_accept(gsmpl, id, true);

@@ -462,13 +476,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
return result;
}

std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}

return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}
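
A usage sketch for draft verification (token values hypothetical): the target model has evaluated the draft tokens in one batch, so logits exist at each index, and the function returns the accepted prefix plus exactly one target-sampled token at the first disagreement:

```cpp
// hypothetical draft of 3 tokens produced by the draft model
llama_tokens draft = { 101, 102, 103 };

// output indices into the target batch, one per draft token plus one extra
std::vector<int> idxs = { 0, 1, 2, 3 };  // idxs.size() == draft.size() + 1

// result: the longest accepted prefix of draft, plus one token sampled by
// the target where the draft stopped matching (always at least 1 token)
std::vector<llama_token> result =
    common_sampler_sample_and_accept_n(gsmpl, ctx_tgt, idxs, draft);
```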

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
9 changes: 6 additions & 3 deletions common/sampling.h
@@ -57,7 +57,10 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
// if grammar_first is true, the grammar is applied before the samplers (slower)
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
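
The two modes side by side (usage sketch; `gsmpl`, `ctx`, and `idx` as in the declaration above):

```cpp
// default: rejection sampling - the grammar is only consulted to validate the winner
llama_token id_fast = common_sampler_sample(gsmpl, ctx, idx);

// grammar_first: every candidate is filtered through the grammar before the
// chain runs - slower, but the whole candidate list ends up grammar-valid
llama_token id_strict = common_sampler_sample(gsmpl, ctx, idx, /* grammar_first = */ true);
```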

// generalized version of common_sampler_sample
//
@@ -75,10 +78,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);

// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

2 changes: 1 addition & 1 deletion common/speculative.cpp
@@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft(
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);

common_sampler_sample(smpl, ctx_dft, 0);
common_sampler_sample(smpl, ctx_dft, 0, true);

const auto * cur_p = common_sampler_get_candidates(smpl, true);
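
Drafting reads several top candidates, not just the selected token, so each of them has to be grammar-valid up front; that is why `grammar_first` is enabled here. A sketch of the pattern (the candidate count is illustrative):

```cpp
// sample with the grammar applied before the chain, then inspect candidates
common_sampler_sample(smpl, ctx_dft, 0, /* grammar_first = */ true);

const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (size_t k = 0; k < cur_p->size && k < 8; ++k) {
    // every cur_p->data[k] already satisfies the grammar here
    const llama_token_data & cand = cur_p->data[k];
    // ... decide whether cand.id is worth keeping in the draft ...
}
```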

4 changes: 2 additions & 2 deletions examples/speculative/speculative.cpp
@@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
bool accept = false;
if (params.sampling.temp > 0) {
// stochastic verification
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);

auto & dist_tgt = *common_sampler_get_candidates(smpl, true);

@@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
continue;
}

common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);

const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);

16 changes: 16 additions & 0 deletions src/llama-sampling.cpp
@@ -362,23 +362,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
}

void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
if (!smpl) {
return;
}

if (smpl->iface->accept) {
smpl->iface->accept(smpl, token);
}
}

void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
if (!smpl) {
return;
}

GGML_ASSERT(smpl->iface->apply);
smpl->iface->apply(smpl, cur_p);
}

void llama_sampler_reset(struct llama_sampler * smpl) {
if (!smpl) {
return;
}

if (smpl->iface->reset) {
smpl->iface->reset(smpl);
}
}

struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
if (!smpl) {
return nullptr;
}

if (smpl->iface->clone) {
return smpl->iface->clone(smpl);
}
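
These guards make `nullptr` a valid "no sampler" value, which the common layer now relies on: when no grammar is configured, `gsmpl->grmr` stays `nullptr` and the unconditional `llama_sampler_apply(grmr, ...)` calls above degrade to no-ops. A sketch of the resulting behavior (`token` and `cur_p` assumed in scope):

```cpp
llama_sampler * grmr = nullptr;  // no grammar configured

llama_sampler_accept(grmr, token);        // no-op
llama_sampler_apply (grmr, &cur_p);       // no-op: candidates untouched, so the
                                          // rejection probe always passes
llama_sampler_reset (grmr);               // no-op
llama_sampler * copy = llama_sampler_clone(grmr);  // returns nullptr
```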