Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
402204e
initial implementation of adaptive-p sampler
dungquixote42 Dec 28, 2025
05897f6
explicitly mark candidates unsorted + cleanup qualifiers
dungquixote42 Dec 29, 2025
9264476
cosmetic update
dungquixote42 Dec 29, 2025
6d50cd0
reorg prototypes
dungquixote42 Dec 30, 2025
7d51f81
lockstep with mainline
dungquixote42 Dec 30, 2025
55ff01b
add _impl for _init + reorg
dungquixote42 Dec 30, 2025
31ae3c2
add LLAMA_API to prototypes
dungquixote42 Dec 30, 2025
3b16355
update sharpness to 10
dungquixote42 Dec 30, 2025
29e7f3d
lockstep: rng seed
dungquixote42 Dec 31, 2025
34ca871
delete llama_sampling member in llama_sampler_adaptive_p
dungquixote42 Dec 31, 2025
ada31a4
fix LLAMA_API return type
dungquixote42 Dec 31, 2025
d0f030b
lockstep: rng seed cont
dungquixote42 Dec 31, 2025
46f70f6
actually correct implementation
dungquixote42 Dec 31, 2025
61eb7f6
lockstep: sorting behavior
dungquixote42 Dec 31, 2025
51acff4
const -> constexpr for known constants
dungquixote42 Jan 1, 2026
4607d0f
add missing space
dungquixote42 Jan 1, 2026
8b0361c
fix softmax usage in adaptive p sampler
dungquixote42 Jan 1, 2026
0ab5089
cosmetic changes
dungquixote42 Jan 1, 2026
0483601
implement do-not-sort version of softmax
dungquixote42 Jan 4, 2026
f0c2533
simpify rng seed, add static to constexpr
dungquixote42 Jan 7, 2026
dd30141
refactor: remove iface + use shared rng + use actually original proba…
dungquixote42 Jan 8, 2026
02f92e4
adaptive-p: add dedicated rng back in
dungquixote42 Jan 8, 2026
438e41c
fix initial max_logit + add float vector to adaptive p sampler contex…
dungquixote42 Jan 9, 2026
a7afa9a
adaptive-p: fuse first softmax with transformation
dungquixote42 Jan 9, 2026
6cf24a6
adaptive-p: implement binary search selection
dungquixote42 Jan 10, 2026
dcdd8ab
adaptive-p: update comment
dungquixote42 Jan 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,16 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
return true;
}
if (arg == "--adaptive-target") {
CHECK_ARG
sparams.adaptive_target = std::stof(argv[i]);
return true;
}
if (arg == "--adaptive-decay") {
CHECK_ARG
sparams.adaptive_decay = std::stof(argv[i]);
return true;
}
if (arg == "--spec-replace") {
CHECK_ARG
std::string target = argv[i];
Expand Down Expand Up @@ -2201,6 +2211,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --xtc-probability p", "xtc probability (default: %.1f, 0.0 = disabled)", (double)sparams.xtc_probability });
options.push_back({ "*", " --xtc-threshold t", "xtc threshold (default: %.1f, >0.5 = disabled)", (double)sparams.xtc_threshold});
options.push_back({ "*", " --top-n-sigma t", "top-n-sigma parmeter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma});
options.push_back({ "*", " --adaptive-target", "adaptive-p sampling: (default: %.2f, <0.0 = disabled)", (double)sparams.adaptive_target});
options.push_back({ "*", " --adaptive-decay", "adaptive-p sampling: (default: %.2f)", (double)sparams.adaptive_decay});
options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
"i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
"or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
Expand Down Expand Up @@ -4174,6 +4186,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "adaptive_target: %f # default: -1.0\n", sparams.adaptive_target);
fprintf(stream, "adaptive_decay: %f # default: 0.9\n", sparams.adaptive_decay);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
}
36 changes: 30 additions & 6 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
result->n_valid = 0;
}
result->grammar = grmr;
// init DRY

for (const auto& cnstr : params.samplers_sequence)
{
switch (cnstr)
Expand All @@ -116,6 +116,11 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo

break;
}
case llama_sampler_type::ADAPTIVE_P:
{
result->adapt_p_ctx=llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay);
break;
}
default:
break;
}
Expand Down Expand Up @@ -247,11 +252,13 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n"
"\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f",
"\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f\n"
"\tadaptive_target = %.2f, adaptive_decay = %.2f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau,
params.xtc_probability, params.xtc_threshold, params.top_n_sigma);
params.xtc_probability, params.xtc_threshold, params.top_n_sigma,
params.adaptive_target, params.adaptive_decay);

return std::string(result);
}
Expand Down Expand Up @@ -283,6 +290,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
case llama_sampler_type::TEMPERATURE: return "temperature";
case llama_sampler_type::XTC : return "xtc";
case llama_sampler_type::TOP_N_SIGMA: return "top_n_sigma";
case llama_sampler_type::ADAPTIVE_P : return "adaptive_p";
default : return "";
}
}
Expand All @@ -297,7 +305,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
{"tfs_z", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
{"top_n_sigma", llama_sampler_type::TOP_N_SIGMA},
{"temperature", llama_sampler_type::TEMPERATURE}
{"temperature", llama_sampler_type::TEMPERATURE},
{"adaptive_p", llama_sampler_type::ADAPTIVE_P},
};

// since samplers names are written multiple ways
Expand All @@ -314,7 +323,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
{"tfs", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
{"top-n-sigma", llama_sampler_type::TOP_N_SIGMA},
{"temp", llama_sampler_type::TEMPERATURE}
{"temp", llama_sampler_type::TEMPERATURE},
{"adaptive-p", llama_sampler_type::ADAPTIVE_P},
};

std::vector<llama_sampler_type> sampler_types;
Expand Down Expand Up @@ -351,7 +361,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
{'f', llama_sampler_type::TFS_Z},
{'x', llama_sampler_type::XTC},
{'n', llama_sampler_type::TOP_N_SIGMA},
{'t', llama_sampler_type::TEMPERATURE}
{'t', llama_sampler_type::TEMPERATURE},
{'w', llama_sampler_type::ADAPTIVE_P},
};

std::vector<llama_sampler_type> sampler_types;
Expand Down Expand Up @@ -405,6 +416,7 @@ static void sampler_queue(
llama_sample_temp(ctx_main, &cur_p, temp);
}
break;
case llama_sampler_type::ADAPTIVE_P: llama_sample_adaptive_p(ctx_main, ctx_sampling->adapt_p_ctx, &cur_p); break;
default : break;
}
}
Expand All @@ -422,6 +434,7 @@ static llama_token llama_sampling_sample_impl(
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const float adaptive_target = params.adaptive_target;

std::vector<float> original_logits;
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
Expand All @@ -445,6 +458,17 @@ static llama_token llama_sampling_sample_impl(
} else if (mirostat == 2) {
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else if (adaptive_target >= 0.0f) {
// adaptive p sampling
static std::vector<float> orig_probs;
orig_probs.reserve(cur_p.size);

// store original probabilities
for (size_t ii = 0; ii < cur_p.size; ++ii) {
orig_probs.emplace_back(cur_p.data[ii].p);
}
sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep));
id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx, orig_probs.data());
} else {
// temperature sampling
size_t min_keep = std::max(1, params.min_keep);
Expand Down
10 changes: 8 additions & 2 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ enum class llama_sampler_type : char {
XTC = 'x',
TOP_N_SIGMA = 'n',
TYPICAL_P = 'y',
TEMPERATURE = 't'
TEMPERATURE = 't',
ADAPTIVE_P = 'w',
};

enum common_grammar_trigger_type {
Expand Down Expand Up @@ -66,6 +67,8 @@ typedef struct llama_sampling_params {
float xtc_probability = 0.0f; // xtc probability
float xtc_threshold = 1.0f; // xtc threshold, disabled if > 0.5
float top_n_sigma = 0.0f; // top-n-sigma
float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
float adaptive_decay = 0.90f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context

Expand All @@ -80,7 +83,8 @@ typedef struct llama_sampling_params {
llama_sampler_type::MIN_P,
llama_sampler_type::XTC,
llama_sampler_type::TOP_N_SIGMA,
llama_sampler_type::TEMPERATURE
llama_sampler_type::TEMPERATURE,
llama_sampler_type::ADAPTIVE_P,
};


Expand Down Expand Up @@ -118,6 +122,8 @@ struct llama_sampling_context {
std::vector<llama_token_data> cur;
llama_sampler_dry* smpl;

llama_sampler_adaptive_p * adapt_p_ctx; // adaptive p sampler

size_t n_valid; // Number of correct top tokens with correct probabilities.

llama_token_data_array cur_p; // current candidates
Expand Down
4 changes: 4 additions & 0 deletions examples/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,8 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.adaptive_target = json_value(data, "adaptive_target", default_sparams.adaptive_target);
slot.sparams.adaptive_decay = json_value(data, "adaptive_decay", default_sparams.adaptive_decay);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
Expand Down Expand Up @@ -1405,6 +1407,8 @@ json server_context::get_formated_generation(const server_slot& slot) const {
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},
{"adaptive_target", slot.sparams.adaptive_target},
{"adaptive_decay", slot.sparams.adaptive_decay},
{"penalize_nl", slot.sparams.penalize_nl},
{"stop", slot.params.antiprompt},
{"max_tokens", slot.params.n_predict}, // User configured n_predict
Expand Down
21 changes: 21 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1380,6 +1380,20 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982


/// @details Adaptive p sampler initializer
/// @param target Select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
/// @param decay Decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
LLAMA_API struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p(
const float target,
const float decay);

/// @details Adaptive p sampler described in https://github.com/MrJackSpade/adaptive-p-docs/blob/main/README.md
void llama_sample_adaptive_p(
struct llama_context * ctx,
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates);


/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
Expand Down Expand Up @@ -1417,6 +1431,13 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
struct llama_context * ctx,
llama_token_data_array * candidates);

/// @details Randonly selects a token from the candidates following adaptive p sampler.
llama_token llama_sample_token_adaptive_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx,
float * orig_probs);

//
// Model split
//
Expand Down
111 changes: 111 additions & 0 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,38 @@ void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
smpl->rng.seed(seed);
}

// Softmax over the candidate logits WITHOUT sorting the candidates.
// If max_logit is non-null it is used as the (precomputed) maximum logit;
// otherwise the maximum is found with a linear scan. Subtracting the maximum
// before expf() keeps the exponentials in a numerically safe range.
void llama_sample_softmax_nosort_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, const float * const max_logit)
{
    GGML_ASSERT(candidates->size > 0);
    const int64_t t_start_sample_us = ggml_time_us();

    // determine the maximum logit
    float max_l;
    if (max_logit != nullptr) {
        // caller already knows the maximum
        max_l = *max_logit;
    } else {
        // scan for it
        max_l = candidates->data[0].logit;
        for (size_t i = 1; i < candidates->size; ++i) {
            if (candidates->data[i].logit > max_l) {
                max_l = candidates->data[i].logit;
            }
        }
    }

    // exponentiate and accumulate the normalizer
    float norm = 0.0f;
    for (size_t i = 0; i < candidates->size; ++i) {
        const float p = expf(candidates->data[i].logit - max_l);
        candidates->data[i].p = p;
        norm += p;
    }

    // normalize to probabilities
    for (size_t i = 0; i < candidates->size; ++i) {
        candidates->data[i].p /= norm;
    }

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}

void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
GGML_ASSERT(candidates->size > 0);

Expand Down Expand Up @@ -1033,6 +1065,85 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
}


// adaptive p

// adaptive p: sample a token id
//
// Samples one token from the candidates proportionally to their current
// probabilities (recomputed here via softmax, reusing the maximum logit cached
// by the preceding llama_sampler_adaptive_p_apply call), then folds the
// *original* (pre-transformation) probability of the selected token into the
// sampler's exponential moving average.
//
// @param smpl         sampling context providing the RNG and timing counters (must not be null)
// @param candidates   candidate tokens whose logits were transformed by the apply step
// @param adapt_p_ctx  adaptive-p state (EMA accumulators + cached max logit)
// @param orig_probs   pre-transformation probabilities, indexed by token id.
//                     NOTE(review): indexing by token id assumes the candidates
//                     were still in token-id order when orig_probs was captured
//                     (i.e. before any sorting/filtering) — confirm against the caller.
llama_token llama_sample_token_adaptive_p_impl(
        struct llama_sampling * smpl,
        llama_token_data_array * candidates,
        struct llama_sampler_adaptive_p * adapt_p_ctx,
        float * orig_probs)
{
    GGML_ASSERT(smpl);
    const int64_t t_start_sample_us = ggml_time_us();

    // softmax with known maximum logit (cached by the apply step)
    llama_sample_softmax_nosort_impl(nullptr, candidates, &(adapt_p_ctx->max_logit));

    // draw a token proportionally to the transformed probabilities
    // NOTE(review): building a std::discrete_distribution per token is O(n) extra
    // work and an allocation; a cumulative distribution + std::upper_bound (as
    // discussed in review) would be cheaper, but changes the RNG consumption
    // pattern and therefore per-seed reproducibility — kept as-is here.
    std::vector<float> probs;
    probs.reserve(candidates->size);
    for (size_t i = 0; i < candidates->size; ++i) {
        probs.emplace_back(candidates->data[i].p);
    }
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    const llama_token id = candidates->data[dist(smpl->rng)].id;

    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    smpl->n_sample++;

    // update history with the original probability of the selected token
    adapt_p_ctx->weighted_sum = adapt_p_ctx->decay * adapt_p_ctx->weighted_sum + orig_probs[id];
    adapt_p_ctx->total_weight = adapt_p_ctx->decay * adapt_p_ctx->total_weight + 1.0f;

    return id;
}

void llama_sampler_adaptive_p_apply(struct llama_sampler_adaptive_p * adapt_p_ctx, llama_token_data_array * candidates)
{
llama_sample_softmax_nosort_impl(nullptr, candidates, nullptr);
if (adapt_p_ctx->target < 0.0f) {
// sampler is disabled
return;
}

// compute adapted target probability
const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f);
const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f
? target
: 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight),
0.0f, 1.0f);

// transformation constants
static constexpr float peak_logit_value = 5.0f;
static constexpr float inv_width = 1.0f / 0.3f;
static constexpr float sharpness = 10.0f;

// quadratic near target for finite differentiation, transitioning to linear decay in tails
// unbounded negative logits suppress far-from-target tokens after softmax
float max_logit = std::numeric_limits<float>::min();
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the intent here to use the minimum float value that is not zero (the value of std::numeric_limits<float>::min() = 1.17549e-38 ) or perhaps more something like -INFINITY ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not know -INFINITY was a thing. Heh. Google showed mestd::numeric_limits<float>::min(), and I said lgtm. Follow-up commits will have -INFINITY.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I meant to use -INFINITY. C dev here. C++ noob.

for (size_t i = 0; i < candidates->size; ++i) {
const float dist = std::abs((candidates->data[i].p - adapted_target) * inv_width);
const float logit = peak_logit_value - sharpness * dist * dist / (1.0f + dist);
candidates->data[i].logit = logit;
max_logit = std::max(max_logit, logit);
}
candidates->sorted = false;
adapt_p_ctx->max_logit = max_logit;
}

// adaptive p: allocate and initialize sampler state.
// The EMA accumulators are primed as if an infinitely long history of
// selections at exactly `target` had already been observed, i.e.
// weighted_sum / total_weight == target from the very first step.
struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(const float target, const float decay)
{
    // keep decay strictly below 1 so the 1/(1-decay) priming terms stay finite
    const float safe_decay = std::clamp(decay, 0.0f, 0.99f);

    return new llama_sampler_adaptive_p {
        /* .target       = */ target,
        /* .decay        = */ safe_decay,
        /* .weighted_sum = */ target / (1.0f - safe_decay),
        /* .total_weight = */ 1.0f / (1.0f - safe_decay),
        /* .max_logit    = */ 0.0f,
    };
}


// grammar

struct llama_sampler_grammar {
Expand Down
20 changes: 19 additions & 1 deletion src/llama-sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ struct llama_sampling {

void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);

void llama_sample_softmax_nosort_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, const float * const max_logit);

void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
Expand Down Expand Up @@ -61,6 +63,22 @@ struct llama_sampler_dry * llama_sampler_init_dry_impl(

void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p);

// maintains an exponential moving average of the *ORIGINAL* probabilities of selected tokens,
// used to compute an adapted target at each sampling step.
// see llama.h for a full description of the sampler
// NOTE(review): the const members make this struct non-copy-assignable;
// fine for the current new/delete lifecycle, but confirm no caller assigns it.
struct llama_sampler_adaptive_p {
const float target; // target probability (0.0 - 1.0; negative = disabled)
const float decay; // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
float weighted_sum; // EMA numerator: sum over history of p_n * decay^n
float total_weight; // EMA denominator: sum(decay^i), converges to 1/(1-decay)
float max_logit; // maximum transformed logit cached by the apply step for the next softmax
};

void llama_sampler_adaptive_p_apply(
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates);

struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(const float target, const float decay);


void llama_sample_repetition_penalties_impl(
Expand All @@ -83,6 +101,6 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll
llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);

llama_token llama_sample_token_adaptive_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx, float * orig_probs);


Loading