Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo
break;
}
}

return result;
}

Expand Down Expand Up @@ -419,7 +419,7 @@ static void sampler_queue(
case llama_sampler_type::ADAPTIVE_P: use_adaptive_p = true; break;
default : break;
}

}
if (use_adaptive_p) {
// adaptive p should be put to the last, so we ignore the order in the sampler
Expand Down Expand Up @@ -451,7 +451,7 @@ static llama_token llama_sampling_sample_impl(
if (ctx_sampling->grammar != NULL && is_resampling) {
float* logits = llama_get_logits_ith(ctx_main, idx);
// Apply grammar constraints to all candidates
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
}

if (temp < 0.0) {
Expand All @@ -471,7 +471,7 @@ static llama_token llama_sampling_sample_impl(
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else if (adaptive_target >= 0.0f && ctx_sampling->adapt_p_ctx!=nullptr) {
// adaptive p sampling
llama_prep_adaptive_p(&cur_p, ctx_sampling->adapt_p_ctx);
llama_prep_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep));
id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
} else {
Expand Down
2 changes: 1 addition & 1 deletion include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1389,7 +1389,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
const float decay,
const uint32_t seed);

void llama_prep_adaptive_p(
void llama_prep_adaptive_p(struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);

Expand Down
121 changes: 83 additions & 38 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1038,8 +1038,7 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
llama_token llama_sample_token_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx)
{
struct llama_sampler_adaptive_p * adapt_p_ctx) {
GGML_ASSERT(candidates->size > 0);
const int64_t t_start_sample_us = ggml_time_us();

Expand All @@ -1062,30 +1061,38 @@ llama_token llama_sample_token_adaptive_p_impl(
const size_t idx = std::distance(ctx->cum_probs.begin(), iter);
llama_token id = candidates->data[idx].id;

if (auto it = ctx->orig_prob_map.find(id); it != ctx->orig_prob_map.end()) {
float update_prob = it->second / ctx->cum_orig_prob;
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
}

smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->n_sample++;

float update_prob = candidates->data[idx].p; // not ideal
if (ctx->orig_prob_map.contains(id)) {
// selected token id is among tracked ids
update_prob = ctx->orig_prob_map[id] / ctx->cum_orig_prob;
}
//float update_prob = candidates->data[idx].p; // not ideal
//if (ctx->orig_prob_map.contains(id)) {
// // selected token id is among tracked ids
// update_prob = ctx->orig_prob_map[id] / ctx->cum_orig_prob;
//}

// update history with original probability of selected token
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
//// update history with original probability of selected token
//ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
//ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;

return id;
}

void llama_sample_adaptive_p_impl(llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx)
{
void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx) {
if (adapt_p_ctx->target < 0.0f) {
// sampler is disabled
llama_sample_softmax_impl(nullptr, candidates);
return;
}

auto t_start = ggml_time_us();

// incomplete softmax because final division can be fused
float max_l = candidates->data[0].logit;
if (!candidates->sorted) {
Expand Down Expand Up @@ -1126,48 +1133,86 @@ void llama_sample_adaptive_p_impl(llama_token_data_array * candidates, struct ll
}
candidates->sorted = false;
adapt_p_ctx->max_xform_logit = max_logit;

ctx->t_sample_us += ggml_time_us() - t_start;
}

void llama_prep_adaptive_p_impl(
void llama_prep_adaptive_p_impl(struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx)
{
struct llama_sampler_adaptive_p * adapt_p_ctx) {
constexpr float kDelta = 16.6f;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see good results with kDelta=11.5. High hit rate, minimal performance drop.

auto t_start = ggml_time_us();
if (!candidates->sorted) {
std::sort(candidates->data, candidates->data + candidates->size,
[](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
candidates->sorted = true;
float max_logit = candidates->data[0].logit;
for (int j = 1; j < int(candidates->size); ++j) {
max_logit = std::max(max_logit, candidates->data[j].logit);
}
float min_logit = max_logit - kDelta;
float cum_prob = 0.0f;
adapt_p_ctx->orig_prob_map.clear();
for (int j = 0; j < int(candidates->size); ++j) {
if (candidates->data[j].logit > min_logit) {
float prob = expf(candidates->data[j].logit - max_logit);
cum_prob += prob;
adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
}
}
adapt_p_ctx->cum_orig_prob = cum_prob;
if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
return;
}
const float max_logit = candidates->data[0].logit;

// decide how many tokens to track based on logit delta
// i.e. do not track unlikely tokens
auto iter = std::lower_bound(
candidates->data,
candidates->data + candidates->size,
max_logit - 16.6f, // delta
[](const llama_token_data & data, const float delta) {
return data.logit > delta;
});
const size_t n_track = std::distance(candidates->data, iter);

// store orig_prob_map and cum_orig_prob to estimate original probability later
float max_logit = candidates->data[0].logit;
float min_logit = max_logit - kDelta;
float cum_prob = 0.0f;
adapt_p_ctx->orig_prob_map.clear();
for (size_t i = 0; i < n_track; ++i) {
const float prob = expf(candidates->data[i].logit - max_logit);
for (int j = 0; j < int(candidates->size); ++j) {
auto logit = candidates->data[j].logit;
if (logit <= min_logit) {
break;
}
float prob = expf(logit - max_logit);
cum_prob += prob;
adapt_p_ctx->orig_prob_map[candidates->data[i].id] = prob;
adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
}
adapt_p_ctx->cum_orig_prob = cum_prob;
if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;

//if (!candidates->sorted) {
// std::sort(candidates->data, candidates->data + candidates->size,
// [](const llama_token_data & a, const llama_token_data & b) {
// return a.logit > b.logit;
// });
// candidates->sorted = true;
//}
//const float max_logit = candidates->data[0].logit;

//// decide how many tokens to track based on logit delta
//// i.e. do not track unlikely tokens
//auto iter = std::lower_bound(
// candidates->data,
// candidates->data + candidates->size,
// max_logit - kDelta, // delta
// [](const llama_token_data & data, const float delta) {
// return data.logit > delta;
// });
//const size_t n_track = std::distance(candidates->data, iter);

//// store orig_prob_map and cum_orig_prob to estimate original probability later
//float cum_prob = 0.0f;
//adapt_p_ctx->orig_prob_map.clear();
//for (size_t i = 0; i < n_track; ++i) {
// const float prob = expf(candidates->data[i].logit - max_logit);
// cum_prob += prob;
// adapt_p_ctx->orig_prob_map[candidates->data[i].id] = prob;
//}
//adapt_p_ctx->cum_orig_prob = cum_prob;
}

struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
const float target,
const float decay,
const uint32_t seed)
{
const uint32_t seed) {
const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
return new llama_sampler_adaptive_p {
/* .target = */ target,
Expand Down
2 changes: 2 additions & 0 deletions src/llama-sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
const uint32_t seed);

void llama_prep_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);

void llama_sample_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);

Expand Down
16 changes: 6 additions & 10 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7690,14 +7690,12 @@ void llama_sample_dry([[maybe_unused]] struct llama_context* ctx, struct llama_s
void llama_sample_adaptive_p(
[[maybe_unused]] struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx)
{
llama_sample_adaptive_p_impl(candidates, adapt_p_ctx);
struct llama_sampler_adaptive_p * adapt_p_ctx) {
llama_sample_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
}

void llama_prep_adaptive_p(llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx)
{
llama_prep_adaptive_p_impl(candidates, adapt_p_ctx);
void llama_prep_adaptive_p(struct llama_context * ctx, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx) {
llama_prep_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
}


Expand Down Expand Up @@ -7743,8 +7741,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
llama_token llama_sample_token_adaptive_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx)
{
struct llama_sampler_adaptive_p * adapt_p_ctx) {
return llama_sample_token_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
}

Expand Down Expand Up @@ -7800,8 +7797,7 @@ void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token)
}


struct llama_sampler_adaptive_p * llama_init_adaptive_p(const float target, const float decay, const uint32_t seed)
{
struct llama_sampler_adaptive_p * llama_init_adaptive_p(const float target, const float decay, const uint32_t seed) {
return llama_init_adaptive_p_impl(target, decay, seed);
}

Expand Down