Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
402204e
initial implementation of adaptive-p sampler
dungquixote42 Dec 28, 2025
05897f6
explicitly mark candidates unsorted + cleanup qualifiers
dungquixote42 Dec 29, 2025
9264476
cosmetic update
dungquixote42 Dec 29, 2025
6d50cd0
reorg prototypes
dungquixote42 Dec 30, 2025
7d51f81
lockstep with mainline
dungquixote42 Dec 30, 2025
55ff01b
add _impl for _init + reorg
dungquixote42 Dec 30, 2025
31ae3c2
add LLAMA_API to prototypes
dungquixote42 Dec 30, 2025
3b16355
update sharpness to 10
dungquixote42 Dec 30, 2025
29e7f3d
lockstep: rng seed
dungquixote42 Dec 31, 2025
34ca871
delete llama_sampling member in llama_sampler_adaptive_p
dungquixote42 Dec 31, 2025
ada31a4
fix LLAMA_API return type
dungquixote42 Dec 31, 2025
d0f030b
lockstep: rng seed cont
dungquixote42 Dec 31, 2025
46f70f6
actually correct implementation
dungquixote42 Dec 31, 2025
61eb7f6
lockstep: sorting behavior
dungquixote42 Dec 31, 2025
51acff4
const -> constexpr for known constants
dungquixote42 Jan 1, 2026
4607d0f
add missing space
dungquixote42 Jan 1, 2026
8b0361c
fix softmax usage in adaptive p sampler
dungquixote42 Jan 1, 2026
0ab5089
cosmetic changes
dungquixote42 Jan 1, 2026
0483601
implement do-not-sort version of softmax
dungquixote42 Jan 4, 2026
f0c2533
simpify rng seed, add static to constexpr
dungquixote42 Jan 7, 2026
dd30141
refactor: remove iface + use shared rng + use actually original proba…
dungquixote42 Jan 8, 2026
02f92e4
adaptive-p: add dedicated rng back in
dungquixote42 Jan 8, 2026
438e41c
fix initial max_logit + add float vector to adaptive p sampler contex…
dungquixote42 Jan 9, 2026
a7afa9a
adaptive-p: fuse first softmax with transformation
dungquixote42 Jan 9, 2026
6cf24a6
adaptive-p: implement binary search selection
dungquixote42 Jan 10, 2026
dcdd8ab
adaptive-p: update comment
dungquixote42 Jan 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,16 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
return true;
}
if (arg == "--adaptive-target") {
CHECK_ARG
sparams.adaptive_target = std::stof(argv[i]);
return true;
}
if (arg == "--adaptive-decay") {
CHECK_ARG
sparams.adaptive_decay = std::stof(argv[i]);
return true;
}
if (arg == "--spec-replace") {
CHECK_ARG
std::string target = argv[i];
Expand Down Expand Up @@ -2201,6 +2211,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --xtc-probability p", "xtc probability (default: %.1f, 0.0 = disabled)", (double)sparams.xtc_probability });
options.push_back({ "*", " --xtc-threshold t", "xtc threshold (default: %.1f, >0.5 = disabled)", (double)sparams.xtc_threshold});
options.push_back({ "*", " --top-n-sigma t", "top-n-sigma parmeter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma});
options.push_back({ "*", " --adaptive-target", "adaptive-p sampling: (default: %.2f, <0.0 = disabled)", (double)sparams.adaptive_target});
options.push_back({ "*", " --adaptive-decay", "adaptive-p sampling: (default: %.2f)", (double)sparams.adaptive_decay});
options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
"i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
"or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
Expand Down Expand Up @@ -4174,6 +4186,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "adaptive_target: %f # default: -1.0\n", sparams.adaptive_target);
fprintf(stream, "adaptive_decay: %f # default: 0.9\n", sparams.adaptive_decay);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
}
38 changes: 31 additions & 7 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
result->n_valid = 0;
}
result->grammar = grmr;
// init DRY
llama_sampling_set_rng_seed(result, params.seed);
for (const auto& cnstr : params.samplers_sequence)
{
switch (cnstr)
Expand All @@ -116,11 +116,16 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo

break;
}
case llama_sampler_type::ADAPTIVE_P:
{
result->adapt_p_ctx=llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, result->rng());
break;
}
default:
break;
}
}
llama_sampling_set_rng_seed(result, params.seed);

return result;
}

Expand Down Expand Up @@ -247,11 +252,13 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n"
"\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f",
"\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f\n"
"\tadaptive_target = %.2f, adaptive_decay = %.2f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau,
params.xtc_probability, params.xtc_threshold, params.top_n_sigma);
params.xtc_probability, params.xtc_threshold, params.top_n_sigma,
params.adaptive_target, params.adaptive_decay);

return std::string(result);
}
Expand Down Expand Up @@ -283,6 +290,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
case llama_sampler_type::TEMPERATURE: return "temperature";
case llama_sampler_type::XTC : return "xtc";
case llama_sampler_type::TOP_N_SIGMA: return "top_n_sigma";
case llama_sampler_type::ADAPTIVE_P : return "adaptive_p";
default : return "";
}
}
Expand All @@ -297,7 +305,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
{"tfs_z", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
{"top_n_sigma", llama_sampler_type::TOP_N_SIGMA},
{"temperature", llama_sampler_type::TEMPERATURE}
{"temperature", llama_sampler_type::TEMPERATURE},
{"adaptive_p", llama_sampler_type::ADAPTIVE_P},
};

// since samplers names are written multiple ways
Expand All @@ -314,7 +323,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
{"tfs", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
{"top-n-sigma", llama_sampler_type::TOP_N_SIGMA},
{"temp", llama_sampler_type::TEMPERATURE}
{"temp", llama_sampler_type::TEMPERATURE},
{"adaptive-p", llama_sampler_type::ADAPTIVE_P},
};

std::vector<llama_sampler_type> sampler_types;
Expand Down Expand Up @@ -351,7 +361,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
{'f', llama_sampler_type::TFS_Z},
{'x', llama_sampler_type::XTC},
{'n', llama_sampler_type::TOP_N_SIGMA},
{'t', llama_sampler_type::TEMPERATURE}
{'t', llama_sampler_type::TEMPERATURE},
{'w', llama_sampler_type::ADAPTIVE_P},
};

std::vector<llama_sampler_type> sampler_types;
Expand Down Expand Up @@ -405,6 +416,7 @@ static void sampler_queue(
llama_sample_temp(ctx_main, &cur_p, temp);
}
break;
case llama_sampler_type::ADAPTIVE_P: llama_sample_adaptive_p(ctx_main, ctx_sampling->adapt_p_ctx, &cur_p); break;
default : break;
}
}
Expand All @@ -422,6 +434,7 @@ static llama_token llama_sampling_sample_impl(
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const float adaptive_target = params.adaptive_target;

std::vector<float> original_logits;
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
Expand All @@ -445,6 +458,17 @@ static llama_token llama_sampling_sample_impl(
} else if (mirostat == 2) {
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else if (adaptive_target >= 0.0f) {
// adaptive p sampling
static thread_local std::vector<float> orig_probs;
orig_probs.resize(cur_p.size);

// store original probabilities
for (size_t ii = 0; ii < cur_p.size; ++ii) {
orig_probs[ii] = cur_p.data[ii].p;
}
sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep));
id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx, orig_probs.data());
} else {
// temperature sampling
size_t min_keep = std::max(1, params.min_keep);
Expand Down
10 changes: 8 additions & 2 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ enum class llama_sampler_type : char {
XTC = 'x',
TOP_N_SIGMA = 'n',
TYPICAL_P = 'y',
TEMPERATURE = 't'
TEMPERATURE = 't',
ADAPTIVE_P = 'w',
};

enum common_grammar_trigger_type {
Expand Down Expand Up @@ -66,6 +67,8 @@ typedef struct llama_sampling_params {
float xtc_probability = 0.0f; // xtc probability
float xtc_threshold = 1.0f; // xtc threshold, disabled if > 0.5
float top_n_sigma = 0.0f; // top-n-sigma
float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
float adaptive_decay = 0.90f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context

Expand All @@ -80,7 +83,8 @@ typedef struct llama_sampling_params {
llama_sampler_type::MIN_P,
llama_sampler_type::XTC,
llama_sampler_type::TOP_N_SIGMA,
llama_sampler_type::TEMPERATURE
llama_sampler_type::TEMPERATURE,
llama_sampler_type::ADAPTIVE_P,
};


Expand Down Expand Up @@ -118,6 +122,8 @@ struct llama_sampling_context {
std::vector<llama_token_data> cur;
llama_sampler_dry* smpl;

llama_sampler_adaptive_p * adapt_p_ctx; // adaptive p sampler

size_t n_valid; // Number of correct top tokens with correct probabilities.

llama_token_data_array cur_p; // current candidates
Expand Down
4 changes: 4 additions & 0 deletions examples/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,8 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.adaptive_target = json_value(data, "adaptive_target", default_sparams.adaptive_target);
slot.sparams.adaptive_decay = json_value(data, "adaptive_decay", default_sparams.adaptive_decay);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
Expand Down Expand Up @@ -1405,6 +1407,8 @@ json server_context::get_formated_generation(const server_slot& slot) const {
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},
{"adaptive_target", slot.sparams.adaptive_target},
{"adaptive_decay", slot.sparams.adaptive_decay},
{"penalize_nl", slot.sparams.penalize_nl},
{"stop", slot.params.antiprompt},
{"max_tokens", slot.params.n_predict}, // User configured n_predict
Expand Down
22 changes: 22 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1380,6 +1380,21 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982


/// @details Adaptive p sampler initializer
/// @param target Select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
/// @param decay Decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
LLAMA_API struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p(
const float target,
const float decay,
const uint32_t seed);

/// @details Adaptive p sampler described in https://github.com/MrJackSpade/adaptive-p-docs/blob/main/README.md
void llama_sample_adaptive_p(
struct llama_context * ctx,
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates);


/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
Expand Down Expand Up @@ -1417,6 +1432,13 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
struct llama_context * ctx,
llama_token_data_array * candidates);

/// @details Randonly selects a token from the candidates following adaptive p sampler.
llama_token llama_sample_token_adaptive_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx,
float * orig_probs);

//
// Model split
//
Expand Down
105 changes: 105 additions & 0 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,111 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
}


// adaptive p

llama_token llama_sample_token_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx,
float * orig_probs)
{
GGML_ASSERT(candidates->size > 0);
const int64_t t_start_sample_us = ggml_time_us();

const size_t count = candidates->size;
adapt_p_ctx->probs.resize(count);

// cumulative distribution
const float max_logit = adapt_p_ctx->max_logit;
float cum_prob = 0.0f;
for (size_t i = 0; i < count; ++i) {
cum_prob += expf(candidates->data[i].logit - max_logit);
adapt_p_ctx->probs[i] = cum_prob;
}
adapt_p_ctx->probs.back() += 1.0f; // safety margin in case rng() ~= rng.max()

// find token with cum_prob > target_cum_prob
const float target_cum_prob = cum_prob * (float)adapt_p_ctx->rng() / (float)adapt_p_ctx->rng.max();
auto iter = std::upper_bound(adapt_p_ctx->probs.begin(), adapt_p_ctx->probs.end(), target_cum_prob);
GGML_ASSERT(iter != adapt_p_ctx->probs.end());
llama_token id = candidates->data[std::distance(adapt_p_ctx->probs.begin(), iter)].id;

smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->n_sample++;

// update history with original probability of selected token
adapt_p_ctx->weighted_sum = adapt_p_ctx->decay * adapt_p_ctx->weighted_sum + orig_probs[id];
adapt_p_ctx->total_weight = adapt_p_ctx->decay * adapt_p_ctx->total_weight + 1.0f;

return id;
}

void llama_sampler_adaptive_p_apply(struct llama_sampler_adaptive_p * adapt_p_ctx, llama_token_data_array * candidates)
{
if (adapt_p_ctx->target < 0.0f) {
// sampler is disabled
llama_sample_softmax_impl(nullptr, candidates);
return;
}

// incomplete softmax because final division can be fused
float max_l = candidates->data[0].logit;
for (size_t i = 1; i < candidates->size; ++i) {
max_l = std::max(max_l, candidates->data[i].logit);
}
float cum_sum = 0.0f;
for (size_t i = 0; i < candidates->size; ++i) {
const float p = expf(candidates->data[i].logit - max_l);
candidates->data[i].p = p;
cum_sum += p;
}

// compute adapted target probability
const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f);
const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f
? target
: 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight),
0.0f, 1.0f);

// transformation constants
static constexpr float peak_logit_value = 5.0f;
static constexpr float inv_width = 1.0f / 0.3f;
static constexpr float sharpness = 10.0f;

const float fused_target = adapted_target * inv_width;
const float fused_width = inv_width / cum_sum;

// quadratic near target for finite differentiation, transitioning to linear decay in tails
// unbounded negative logits suppress far-from-target tokens after softmax
float max_logit = -INFINITY;
for (size_t i = 0; i < candidates->size; ++i) {
const float dist = std::abs(candidates->data[i].p * fused_width - fused_target);
const float logit = peak_logit_value - sharpness * dist * dist / (1.0f + dist);
candidates->data[i].logit = logit;
max_logit = std::max(max_logit, logit);
}
candidates->sorted = false;
adapt_p_ctx->max_logit = max_logit;
}

struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(
const float target,
const float decay,
const uint32_t seed)
{
const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
return new llama_sampler_adaptive_p {
/* .target = */ target,
/* .decay = */ clamped_decay,
/* .rng = */ std::mt19937(seed),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
/* .max_logit = */ 0.0f,
/* .probs = */ {},
};
}


// grammar

struct llama_sampler_grammar {
Expand Down
20 changes: 19 additions & 1 deletion src/llama-sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,24 @@ struct llama_sampler_dry * llama_sampler_init_dry_impl(

void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p);

// maintains an exponential moving average of the *ORIGINAL* probabilities of selected tokens
// used to compute an adapted target at each sampling step.
// see llama.h for a full description of the sampler
struct llama_sampler_adaptive_p {
const float target; // target probability (0.0 - 1.0; negative = disabled)
const float decay; // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
std::mt19937 rng; // RNG
float weighted_sum; // sum(p_n * decay^N)
float total_weight; // sum(decay^i), converges to 1/(1-decay)
float max_logit; // maximum logit found during transform
std::vector<float> probs; // cumulative probabilities
};

void llama_sampler_adaptive_p_apply(
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates);

struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(const float target, const float decay, const uint32_t seed);


void llama_sample_repetition_penalties_impl(
Expand All @@ -83,6 +101,6 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll
llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);

llama_token llama_sample_token_adaptive_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx, float * orig_probs);


Loading