Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
}
case llama_sampler_type::ADAPTIVE_P:
{
result->adapt_p_ctx=llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, result->rng());
result->adapt_p_ctx = llama_init_adaptive_p(params.adaptive_target, params.adaptive_decay, result->rng());
break;
}
default:
Expand Down Expand Up @@ -423,7 +423,7 @@ static void sampler_queue(
}
if (use_adaptive_p) {
// adaptive p should be put to the last, so we ignore the order in the sampler
llama_sample_adaptive_p(ctx_main, ctx_sampling->adapt_p_ctx, &cur_p);
llama_sample_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
}
}

Expand Down Expand Up @@ -471,15 +471,9 @@ static llama_token llama_sampling_sample_impl(
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else if (adaptive_target >= 0.0f && ctx_sampling->adapt_p_ctx!=nullptr) {
// adaptive p sampling
static thread_local std::vector<float> orig_probs;
orig_probs.resize(cur_p.size);

// store original probabilities
for (size_t ii = 0; ii < cur_p.size; ++ii) {
orig_probs[ii] = cur_p.data[ii].p;
}
llama_prep_adaptive_p(&cur_p, ctx_sampling->adapt_p_ctx);
sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep));
id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx, orig_probs.data());
id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
} else {
// temperature sampling
size_t min_keep = std::max(1, params.min_keep);
Expand Down
15 changes: 9 additions & 6 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1383,16 +1383,20 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
/// @details Adaptive p sampler initializer
/// @param target Select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
/// @param decay Decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
LLAMA_API struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p(
LLAMA_API struct llama_sampler_adaptive_p * llama_init_adaptive_p(
const float target,
const float decay,
const uint32_t seed);

void llama_prep_adaptive_p(
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);

/// @details Adaptive p sampler described in https://github.com/MrJackSpade/adaptive-p-docs/blob/main/README.md
void llama_sample_adaptive_p(
struct llama_context * ctx,
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates);
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);


/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
Expand Down Expand Up @@ -1436,8 +1440,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
llama_token llama_sample_token_adaptive_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx,
float * orig_probs);
struct llama_sampler_adaptive_p * adapt_p_ctx);

//
// Model split
Expand Down
98 changes: 71 additions & 27 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1038,41 +1038,47 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
llama_token llama_sample_token_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx,
float * orig_probs)
struct llama_sampler_adaptive_p * adapt_p_ctx)
{
GGML_ASSERT(candidates->size > 0);
const int64_t t_start_sample_us = ggml_time_us();

const size_t count = candidates->size;
adapt_p_ctx->probs.resize(count);
struct llama_sampler_adaptive_p * ctx = adapt_p_ctx;
ctx->cum_probs.resize(candidates->size);

// cumulative distribution
const float max_logit = adapt_p_ctx->max_logit;
// compute cumulative probability distribution
const float max_logit = ctx->max_xform_logit;
float cum_prob = 0.0f;
for (size_t i = 0; i < count; ++i) {
for (size_t i = 0; i < candidates->size; ++i) {
cum_prob += expf(candidates->data[i].logit - max_logit);
adapt_p_ctx->probs[i] = cum_prob;
ctx->cum_probs[i] = cum_prob;
}
adapt_p_ctx->probs.back() += 1.0f; // safety margin in case rng() ~= rng.max()
ctx->cum_probs.back() += 1.0f; // safety margin in case rng() ~= rng.max()

// find token with cum_prob > target_cum_prob
const float target_cum_prob = cum_prob * (float)adapt_p_ctx->rng() / (float)adapt_p_ctx->rng.max();
auto iter = std::upper_bound(adapt_p_ctx->probs.begin(), adapt_p_ctx->probs.end(), target_cum_prob);
GGML_ASSERT(iter != adapt_p_ctx->probs.end());
llama_token id = candidates->data[std::distance(adapt_p_ctx->probs.begin(), iter)].id;
// select first token whose cum_prob > target_cum_prob
const float target_cum_prob = cum_prob * (float)ctx->rng() / (float)ctx->rng.max();
auto iter = std::upper_bound(ctx->cum_probs.begin(), ctx->cum_probs.end(), target_cum_prob);
GGML_ASSERT(iter != ctx->cum_probs.end());
const size_t idx = std::distance(ctx->cum_probs.begin(), iter);
llama_token id = candidates->data[idx].id;

smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->n_sample++;

float update_prob = candidates->data[idx].p; // not ideal
if (ctx->orig_prob_map.contains(id)) {
// selected token id is among tracked ids
update_prob = ctx->orig_prob_map[id] / ctx->cum_orig_prob;
}

// update history with original probability of selected token
adapt_p_ctx->weighted_sum = adapt_p_ctx->decay * adapt_p_ctx->weighted_sum + orig_probs[id];
adapt_p_ctx->total_weight = adapt_p_ctx->decay * adapt_p_ctx->total_weight + 1.0f;
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;

return id;
}

void llama_sampler_adaptive_p_apply(struct llama_sampler_adaptive_p * adapt_p_ctx, llama_token_data_array * candidates)
void llama_sample_adaptive_p_impl(llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx)
{
if (adapt_p_ctx->target < 0.0f) {
// sampler is disabled
Expand All @@ -1082,14 +1088,16 @@ void llama_sampler_adaptive_p_apply(struct llama_sampler_adaptive_p * adapt_p_ct

// incomplete softmax because final division can be fused
float max_l = candidates->data[0].logit;
for (size_t i = 1; i < candidates->size; ++i) {
max_l = std::max(max_l, candidates->data[i].logit);
if (!candidates->sorted) {
for (size_t i = 1; i < candidates->size; ++i) {
max_l = std::max(max_l, candidates->data[i].logit);
}
}
float cum_sum = 0.0f;
for (size_t i = 0; i < candidates->size; ++i) {
const float p = expf(candidates->data[i].logit - max_l);
candidates->data[i].p = p;
cum_sum += p;
const float prob = expf(candidates->data[i].logit - max_l);
candidates->data[i].p = prob;
cum_sum += prob;
}

// compute adapted target probability
Expand Down Expand Up @@ -1117,10 +1125,45 @@ void llama_sampler_adaptive_p_apply(struct llama_sampler_adaptive_p * adapt_p_ct
max_logit = std::max(max_logit, logit);
}
candidates->sorted = false;
adapt_p_ctx->max_logit = max_logit;
adapt_p_ctx->max_xform_logit = max_logit;
}

void llama_prep_adaptive_p_impl(
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx)
{
if (!candidates->sorted) {
std::sort(candidates->data, candidates->data + candidates->size,
[](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
candidates->sorted = true;
}
const float max_logit = candidates->data[0].logit;

// decide how many tokens to track based on logit delta
// i.e. do not track unlikely tokens
auto iter = std::lower_bound(
candidates->data,
candidates->data + candidates->size,
max_logit - 16.6f, // delta
[](const llama_token_data & data, const float delta) {
return data.logit > delta;
});
const size_t n_track = std::distance(candidates->data, iter);

// store orig_prob_map and cum_orig_prob to estimate original probability later
float cum_prob = 0.0f;
adapt_p_ctx->orig_prob_map.clear();
for (size_t i = 0; i < n_track; ++i) {
const float prob = expf(candidates->data[i].logit - max_logit);
cum_prob += prob;
adapt_p_ctx->orig_prob_map[candidates->data[i].id] = prob;
}
adapt_p_ctx->cum_orig_prob = cum_prob;
}

struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
const float target,
const float decay,
const uint32_t seed)
Expand All @@ -1132,12 +1175,13 @@ struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(
/* .rng = */ std::mt19937(seed),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
/* .max_logit = */ 0.0f,
/* .probs = */ {},
/* .orig_logit_map = */ {},
/* .cum_orig_prob = */ 0.0f,
/* .max_xform_logit = */ -INFINITY,
/* .cum_probs = */ {},
};
}


// grammar

struct llama_sampler_grammar {
Expand Down
30 changes: 23 additions & 7 deletions src/llama-sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ struct llama_sampler_dry * llama_sampler_init_dry_impl(

void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p);


// maintains an exponential moving average of the *ORIGINAL* probabilities of selected tokens
// used to compute an adapted target at each sampling step.
// see llama.h for a full description of the sampler
Expand All @@ -70,15 +71,30 @@ struct llama_sampler_adaptive_p {
std::mt19937 rng; // RNG
float weighted_sum; // sum(p_n * decay^N)
float total_weight; // sum(decay^i), converges to 1/(1-decay)
float max_logit; // maximum logit found during transform
std::vector<float> probs; // cumulative probabilities

// first referenced in prep
std::unordered_map<llama_token, float> orig_prob_map; // probabilities before sampler_queue
float cum_orig_prob; // for normalizing orig_prob in sample_token

// first referenced in sample
float max_xform_logit; // maximum logit found during transform

// first referenced in sample_token
std::vector<float> cum_probs; // cumulative probability distribution
};

void llama_sampler_adaptive_p_apply(
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates);
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
const float target,
const float decay,
const uint32_t seed);

void llama_prep_adaptive_p_impl(
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);

struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(const float target, const float decay, const uint32_t seed);
void llama_sample_adaptive_p_impl(
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);


void llama_sample_repetition_penalties_impl(
Expand All @@ -101,6 +117,6 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll
llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_adaptive_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx, float * orig_probs);
llama_token llama_sample_token_adaptive_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx);


22 changes: 14 additions & 8 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7677,11 +7677,18 @@ void llama_sample_dry([[maybe_unused]] struct llama_context* ctx, struct llama_s

void llama_sample_adaptive_p(
[[maybe_unused]] struct llama_context * ctx,
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates) {
llama_sampler_adaptive_p_apply(adapt_p_ctx, candidates);
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx)
{
llama_sample_adaptive_p_impl(candidates, adapt_p_ctx);
}

void llama_prep_adaptive_p(llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx)
{
llama_prep_adaptive_p_impl(candidates, adapt_p_ctx);
}


void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,
Expand Down Expand Up @@ -7724,10 +7731,9 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
llama_token llama_sample_token_adaptive_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx,
float * orig_probs)
struct llama_sampler_adaptive_p * adapt_p_ctx)
{
return llama_sample_token_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx, orig_probs);
return llama_sample_token_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
}

int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
Expand Down Expand Up @@ -7782,9 +7788,9 @@ void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token)
}


struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p(const float target, const float decay, const uint32_t seed)
struct llama_sampler_adaptive_p * llama_init_adaptive_p(const float target, const float decay, const uint32_t seed)
{
return llama_sampler_init_adaptive_p_impl(target, decay, seed);
return llama_init_adaptive_p_impl(target, decay, seed);
}


Expand Down