-
Notifications
You must be signed in to change notification settings - Fork 251
Implement Adaptive-P Sampler #1100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
402204e
05897f6
9264476
6d50cd0
7d51f81
55ff01b
31ae3c2
3b16355
29e7f3d
34ca871
ada31a4
d0f030b
46f70f6
61eb7f6
51acff4
4607d0f
8b0361c
0ab5089
0483601
f0c2533
dd30141
02f92e4
438e41c
a7afa9a
6cf24a6
dcdd8ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,38 @@ void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) { | |
| smpl->rng.seed(seed); | ||
| } | ||
|
|
||
| void llama_sample_softmax_nosort_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, const float * const max_logit) | ||
| { | ||
| GGML_ASSERT(candidates->size > 0); | ||
| const int64_t t_start_sample_us = ggml_time_us(); | ||
|
|
||
| float max_l = candidates->data[0].logit; | ||
| if (max_logit == nullptr) { | ||
| // maximum logit is not known | ||
| for (size_t i = 1; i < candidates->size; ++i) { | ||
| max_l = std::max(max_l, candidates->data[i].logit); | ||
| } | ||
| } else { | ||
| // maximum logit is known | ||
| max_l = *max_logit; | ||
| } | ||
|
|
||
| float cum_sum = 0.0f; | ||
| for (size_t i = 0; i < candidates->size; ++i) { | ||
| float p = expf(candidates->data[i].logit - max_l); | ||
| candidates->data[i].p = p; | ||
| cum_sum += p; | ||
| } | ||
|
|
||
| for (size_t i = 0; i < candidates->size; ++i) { | ||
| candidates->data[i].p /= cum_sum; | ||
| } | ||
|
|
||
| if (smpl) { | ||
| smpl->t_sample_us += ggml_time_us() - t_start_sample_us; | ||
| } | ||
| } | ||
|
|
||
| void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { | ||
| GGML_ASSERT(candidates->size > 0); | ||
|
|
||
|
|
@@ -1033,6 +1065,85 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab& | |
| } | ||
|
|
||
|
|
||
| // adaptive p | ||
|
|
||
| llama_token llama_sample_token_adaptive_p_impl( | ||
| struct llama_sampling * smpl, | ||
| llama_token_data_array * candidates, | ||
| struct llama_sampler_adaptive_p * adapt_p_ctx, | ||
| float * orig_probs) | ||
| { | ||
| GGML_ASSERT(smpl); | ||
| const int64_t t_start_sample_us = ggml_time_us(); | ||
|
|
||
| // softmax with known maximum logit | ||
| llama_sample_softmax_nosort_impl(nullptr, candidates, &(adapt_p_ctx->max_logit)); | ||
|
|
||
| // sample | ||
| std::vector<float> probs; | ||
|
||
| probs.reserve(candidates->size); | ||
| for (size_t i = 0; i < candidates->size; ++i) { | ||
| probs.emplace_back(candidates->data[i].p); | ||
| } | ||
| std::discrete_distribution<> dist(probs.begin(), probs.end()); | ||
| llama_token id = candidates->data[dist(smpl->rng)].id; | ||
|
||
|
|
||
| smpl->t_sample_us += ggml_time_us() - t_start_sample_us; | ||
| smpl->n_sample++; | ||
|
|
||
| // update history with original probability of selected token | ||
| adapt_p_ctx->weighted_sum = adapt_p_ctx->decay * adapt_p_ctx->weighted_sum + orig_probs[id]; | ||
| adapt_p_ctx->total_weight = adapt_p_ctx->decay * adapt_p_ctx->total_weight + 1.0f; | ||
|
|
||
| return id; | ||
| } | ||
|
|
||
| void llama_sampler_adaptive_p_apply(struct llama_sampler_adaptive_p * adapt_p_ctx, llama_token_data_array * candidates) | ||
| { | ||
| llama_sample_softmax_nosort_impl(nullptr, candidates, nullptr); | ||
| if (adapt_p_ctx->target < 0.0f) { | ||
| // sampler is disabled | ||
| return; | ||
| } | ||
|
|
||
| // compute adapted target probability | ||
| const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f); | ||
| const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f | ||
| ? target | ||
| : 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight), | ||
| 0.0f, 1.0f); | ||
|
|
||
| // transformation constants | ||
| static constexpr float peak_logit_value = 5.0f; | ||
| static constexpr float inv_width = 1.0f / 0.3f; | ||
| static constexpr float sharpness = 10.0f; | ||
|
|
||
| // quadratic near target for finite differentiation, transitioning to linear decay in tails | ||
| // unbounded negative logits suppress far-from-target tokens after softmax | ||
| float max_logit = std::numeric_limits<float>::min(); | ||
|
||
| for (size_t i = 0; i < candidates->size; ++i) { | ||
| const float dist = std::abs((candidates->data[i].p - adapted_target) * inv_width); | ||
| const float logit = peak_logit_value - sharpness * dist * dist / (1.0f + dist); | ||
| candidates->data[i].logit = logit; | ||
| max_logit = std::max(max_logit, logit); | ||
| } | ||
| candidates->sorted = false; | ||
| adapt_p_ctx->max_logit = max_logit; | ||
| } | ||
|
|
||
| struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(const float target, const float decay) | ||
| { | ||
| const float clamped_decay = std::clamp(decay, 0.0f, 0.99f); | ||
| return new llama_sampler_adaptive_p { | ||
| /* .target = */ target, | ||
| /* .decay = */ clamped_decay, | ||
| /* .weighted_sum = */ target / (1.0f - clamped_decay), | ||
| /* .total_weight = */ 1.0f / (1.0f - clamped_decay), | ||
| /* .max_logit = */ 0.0f, | ||
| }; | ||
| } | ||
|
|
||
|
|
||
| // grammar | ||
|
|
||
| struct llama_sampler_grammar { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where/when is `adapt_p_ctx->max_logit` initialized to a meaningful value? Never mind — I saw it below: `llama_sampler_adaptive_p_apply` sets it before sampling.