-
Notifications
You must be signed in to change notification settings - Fork 16.2k
implement adaptive-p sampler #17927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement adaptive-p sampler #17927
Changes from 15 commits
774cf23
5ab4ff7
66e2d17
88fb0f3
374bfd4
ffe1639
4959878
f3457a8
9316959
b3aea57
cd7de7c
534cb4f
dcada03
2d62bbe
5c78b79
53380c1
94cb883
0a19a3f
824bb3a
1879fc6
67a7336
a96ddd7
b8a9626
965bcc9
d1e5c60
9613c48
2a3f579
ec54fe5
667b70f
36b526d
6934780
f5d0872
493bf30
6854325
b5ed673
4e28eb2
1c58e9a
4e04bd1
6e66095
9c50b57
0344068
1c2d2e9
85b6e52
fcb5129
58aa1c6
27dda80
7752998
6023572
dedbe36
f4703d4
89ebdf0
55ad4a8
6bad4ae
295d1d8
ed2890e
51070e0
90f3bfb
b95b088
f0d3f13
e7a8920
05d7dc9
2d67b1c
c6a6f63
0807499
eb854e7
55757dc
660a3b2
7173e84
c27df51
5fdc530
0400611
684c5ff
7ffd3a8
f48413c
bef75d9
8b1292a
e99a4a6
af0596c
5f04265
7f40928
3aa23f3
1eff502
d21c87e
4b92e3a
33c635e
4b06e08
42af39d
81af54c
40fd48f
b6041b1
f222e17
d7e3b86
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2313,6 +2313,144 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa | |||||||||||||||
| return result; | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| // power-law | ||||||||||||||||
| // | ||||||||||||||||
| // this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID | ||||||||||||||||
| // rather than just transforming logits. therefore it must always be the last sampler in the | ||||||||||||||||
| // sampler chain. | ||||||||||||||||
| // | ||||||||||||||||
| // it is recommended to only perform minimal truncation before this sampler. | ||||||||||||||||
| // | ||||||||||||||||
| // ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl, documentation) | ||||||||||||||||
| // ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR) | ||||||||||||||||
|
|
||||||||||||||||
| struct llama_sampler_power_law { | ||||||||||||||||
| const float target; | ||||||||||||||||
| const int32_t window_size; | ||||||||||||||||
|
|
||||||||||||||||
| const uint32_t seed; | ||||||||||||||||
| std::mt19937 rng; | ||||||||||||||||
| ring_buffer<float> window; | ||||||||||||||||
| }; | ||||||||||||||||
|
|
||||||||||||||||
| static const char * llama_sampler_power_law_name(const struct llama_sampler * /*smpl*/) { | ||||||||||||||||
| return "power-law"; | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { | ||||||||||||||||
| auto * ctx = (llama_sampler_power_law *) smpl->ctx; | ||||||||||||||||
|
|
||||||||||||||||
| if (ctx->target < 0.0f) { | ||||||||||||||||
| // no-op: just sample from the distribution as-is | ||||||||||||||||
| llama_sampler_softmax_impl(cur_p, false); | ||||||||||||||||
| const int idx = llama_sample_dist(cur_p, ctx->rng); | ||||||||||||||||
| cur_p->selected = idx; | ||||||||||||||||
| return; | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| // fixed power law transform parameters (from original implementation) | ||||||||||||||||
| const float distribution_width = 0.2f; | ||||||||||||||||
| const float peak_logit_value = 3.0f; | ||||||||||||||||
| const float tail_heaviness = 3.0f; | ||||||||||||||||
|
|
||||||||||||||||
| // compute probabilities to get the "original" values | ||||||||||||||||
| llama_sampler_softmax_impl(cur_p, false); | ||||||||||||||||
|
|
||||||||||||||||
| // store original probabilities (used for future target adaptation) | ||||||||||||||||
| std::vector<float> original_probs; | ||||||||||||||||
| original_probs.reserve(cur_p->size); | ||||||||||||||||
| for (size_t i = 0; i < cur_p->size; ++i) { | ||||||||||||||||
| original_probs.push_back(cur_p->data[i].p); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| // | ||||||||||||||||
| // calculate adaptive target | ||||||||||||||||
| // | ||||||||||||||||
|
|
||||||||||||||||
| const float min_target = 0.0f; | ||||||||||||||||
| const float max_target = 1.0f; | ||||||||||||||||
|
|
||||||||||||||||
| float computed_target = ctx->target; | ||||||||||||||||
| if (ctx->window.size() > 0) { | ||||||||||||||||
| float sum_excluding_oldest = 0.0f; | ||||||||||||||||
| size_t sz = ctx->window.size(); | ||||||||||||||||
|
|
||||||||||||||||
| // sum all except the oldest element | ||||||||||||||||
| for (size_t i = 0; i < sz - 1; ++i) { | ||||||||||||||||
| sum_excluding_oldest += ctx->window.rat(i); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| float next_value = (ctx->target * ctx->window_size) - sum_excluding_oldest; | ||||||||||||||||
| computed_target = std::max(min_target, std::min(next_value, max_target)); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| // | ||||||||||||||||
| // power law transform | ||||||||||||||||
| // | ||||||||||||||||
|
|
||||||||||||||||
| for (size_t i = 0; i < cur_p->size; ++i) { | ||||||||||||||||
| float p = cur_p->data[i].p; | ||||||||||||||||
| float normalized_distance = std::abs(p - computed_target) / distribution_width; | ||||||||||||||||
| cur_p->data[i].logit = peak_logit_value / (1.0f + std::pow(normalized_distance, tail_heaviness)); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| llama_sampler_softmax_impl(cur_p, false); | ||||||||||||||||
|
|
||||||||||||||||
| // sample from the transformed distribution | ||||||||||||||||
| const int idx = llama_sample_dist(cur_p, ctx->rng); | ||||||||||||||||
| cur_p->selected = idx; | ||||||||||||||||
|
|
||||||||||||||||
| // add the ORIGINAL probability to the rolling window | ||||||||||||||||
| ctx->window.push_back(original_probs[idx]); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| static void llama_sampler_power_law_reset(struct llama_sampler * smpl) { | ||||||||||||||||
| auto * ctx = (llama_sampler_power_law *) smpl->ctx; | ||||||||||||||||
| ctx->window = ring_buffer<float>(ctx->window_size); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| static struct llama_sampler * llama_sampler_power_law_clone(const struct llama_sampler * smpl) { | ||||||||||||||||
| const auto * ctx = (const llama_sampler_power_law *) smpl->ctx; | ||||||||||||||||
| auto * result = llama_sampler_init_power_law(ctx->target, ctx->window_size, ctx->seed); | ||||||||||||||||
| auto * result_ctx = (llama_sampler_power_law *) result->ctx; | ||||||||||||||||
|
|
||||||||||||||||
|
Comment on lines
+3384
to
+3397
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should apply the same logic for the RNG seeds here as in the llama.cpp/src/llama-sampling.cpp Lines 1111 to 1117 in 516a4ca
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops. I addressed this in |
||||||||||||||||
| result_ctx->rng = ctx->rng; | ||||||||||||||||
| result_ctx->window = ctx->window; | ||||||||||||||||
|
|
||||||||||||||||
| return result; | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| static void llama_sampler_power_law_free(struct llama_sampler * smpl) { | ||||||||||||||||
| delete (llama_sampler_power_law *) smpl->ctx; | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| static struct llama_sampler_i llama_sampler_power_law_i = { | ||||||||||||||||
| /* .name = */ llama_sampler_power_law_name, | ||||||||||||||||
| /* .accept = */ nullptr, | ||||||||||||||||
| /* .apply = */ llama_sampler_power_law_apply, | ||||||||||||||||
| /* .reset = */ llama_sampler_power_law_reset, | ||||||||||||||||
| /* .clone = */ llama_sampler_power_law_clone, | ||||||||||||||||
| /* .free = */ llama_sampler_power_law_free, | ||||||||||||||||
| }; | ||||||||||||||||
|
|
||||||||||||||||
| struct llama_sampler * llama_sampler_init_power_law( | ||||||||||||||||
| float target, | ||||||||||||||||
| int32_t window_size, | ||||||||||||||||
| uint32_t seed | ||||||||||||||||
| ) { | ||||||||||||||||
| auto seed_cur = get_rng_seed(seed); | ||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This See how the |
||||||||||||||||
| return llama_sampler_init( | ||||||||||||||||
| /* .iface = */ &llama_sampler_power_law_i, | ||||||||||||||||
| /* .ctx = */ new llama_sampler_power_law { | ||||||||||||||||
| /* .target = */ target, | ||||||||||||||||
| /* .window_size = */ window_size, | ||||||||||||||||
| /* .seed = */ seed_cur, | ||||||||||||||||
| /* .rng = */ std::mt19937(seed_cur), | ||||||||||||||||
| /* .window = */ ring_buffer<float>(window_size), | ||||||||||||||||
| } | ||||||||||||||||
| ); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| // logit-bias | ||||||||||||||||
|
|
||||||||||||||||
| struct llama_sampler_logit_bias { | ||||||||||||||||
|
|
||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should these parameters be configurable like in the original implementation? There is probably a tradeoff with feature creep, having too many options for users to control, but some of these seem potentially important (especially `distribution_width`). Also, I noticed `peak_logit_value` is outside the range suggested in the original implementation; is that intentional?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Myself and the original author are discussing the parameters over the next few days, I agree that the current implementation is probably not ideal, which is why I marked it back as draft.
I will post a comment in the main thread with an update once we've got it more figured out. Thank you!