DRY sampler performance optimization#6047
Conversation
It's really nice that you managed to reduce the time complexity of the algorithm, well done! Some quick comments:
Wow, that's awesome! A few initial comments: 👍 for the list conversion, as that's a simple change with huge benefits that I overlooked. Thank you for finding and resolving this! I'm less enthusiastic about the other change introducing the Z algorithm. Notably, I don't believe that performance on "adversarial" inputs such as 10k repeated tokens is a good justification for such a big algorithmic change. I spent literally hundreds of hours validating DRY in its current form, often stepping through the algorithm manually and quadruple-checking every single statement. I took a lot of care to produce an algorithm that is as easy to follow as possible, and rewrote everything multiple times in the process, because testing samplers is really, really hard and I wanted an implementation that can be easily ported to other loaders. IMO, the proposed algorithm, while faster in pathological cases, is significantly more difficult to understand and reason about. It's great that you commented everything and even wrote tests to verify identical behavior to the current implementation, but unless the performance difference can be demonstrated to matter with real-world inputs, I don't think the benefits outweigh the additional complexity burden here. But of course, that's ultimately not my decision to make.
Yes, I had anticipated that in theory and then forgotten about it. I believe it's a simple floating point overflow in the exponential penalty calculation. Note that this cannot occur unless the input already contains pathological repetitions, as DRY will prevent repetitive sequence buildups long before they can reach such lengths. Clamping the exponent in the penalty formula should be sufficient to fix this.
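A minimal sketch of the clamping fix described above. The formula shape (multiplier * base ** (match_length - allowed_length)) follows the DRY penalty as discussed in this thread, but `MAX_EXPONENT` and the default parameter values here are illustrative assumptions, not the exact constants from the PR:

```python
# Hedged sketch of clamping the exponent in the DRY penalty formula.
# MAX_EXPONENT and the defaults are illustrative assumptions.
MAX_EXPONENT = 1000

def dry_penalty(match_length, allowed_length=2, base=1.75, multiplier=0.8):
    exponent = match_length - allowed_length
    if exponent < 0:
        # Match shorter than allowed_length: no penalty at all
        return 0.0
    # Clamp the exponent so base ** exponent can never overflow,
    # even for pathologically long repeated sequences.
    return multiplier * base ** min(exponent, MAX_EXPONENT)
```

Without the clamp, a plain `1.75 ** 10000` on Python floats raises `OverflowError`, which matches the failure mode described here.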
Nope, it's exactly what it says on the can. Here's the relevant code: `text-generation-webui/modules/sampler_hijack.py`, lines 285 to 287 in 852c943. You're probably confused by the
Not quite. Only the last token of each string is used, after prefixing the string with "a". Token IDs have the massive downside that they depend on the tokenizer, and there are quite a few tokenizers around now, with model makers continuing to experiment with vocabularies. I fear that expecting users to provide different token IDs every time they switch to another model would make this crucial parameter, and thus DRY itself, useless in practice, and above all else I want DRY to be practical to use.
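A sketch of the mechanism described here: the sequence breaker string is prefixed with "a" before encoding, and only the last token of the result is kept. The `encode` interface and the character-level `StubTokenizer` below are illustrative assumptions for demonstration; a real tokenizer encodes punctuation context-sensitively, which is exactly why the "a" prefix is used.

```python
def breaker_token_id(tokenizer, breaker: str) -> int:
    """Encode the breaker with a leading "a" and keep only the final token,
    so punctuation gets its mid-text encoding rather than its standalone one."""
    token_ids = tokenizer.encode("a" + breaker)
    return token_ids[-1]

class StubTokenizer:
    """Toy character-level tokenizer (one token per character, id = ord).
    Stands in for a real tokenizer purely for illustration."""
    def encode(self, text):
        return [ord(c) for c in text]
```

With the stub, `breaker_token_id(StubTokenizer(), ":")` simply returns the id of ":"; with a real tokenizer, the point is that the returned id is the one produced in mid-text position.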
Thank you both for such a quick response!
@oobabooga Can you please give more detailed instructions on how to run this sanity check? It seems to me like the logits are not stable across generations. Even when I check out a commit of the dev branch before my changes, and I generate with the exact same parameters and the exact same prompt, the logits after sampling are slightly different every time. Here's an example of what I see when I try to print logits: ...and when I run the same thing again, expecting to see the same logits, they are similar but not exactly the same.

I'm trying to print the logits inside file

To me it seems like that would be the correct place to print the logits, but maybe you had some different place in mind? Or maybe this issue is related to ExLlamaV2 nondeterminism? I only have a 3060, so I'm unable to run any of the Llama models without quantization.

Even if the logits were stable across generations, there would still be the problem of how to compare them. I can't eyeball 128000 values, so I guess I would need to hash them and compare hashes or something? Maybe something like

Alternatively, I also tried to use the API endpoint to check logits:
I have now pushed a fix for the OverflowError as a separate commit.
Sorry, this one is my bad. As @p-e-w explained, the current behavior is correct and no change to
You're right. Let's forget my original suggestion regarding sequence breakers.
Thanks for clarifying! I think I now understand better how sequence breakers are intended to work. I still think the current behavior of sequence breakers is far from optimal and should be changed in some way. Consider the following example.
To recap, in this example we have added our character "jack black" into sequence breakers... but the sequence breaker list will only contain the token "ack". What kind of issues does this cause? Well, let's imagine that the LLM is generating tokens and it is currently in a state where the most recent tokens look like this:

How would we want the DRY sampler with sequence breakers to work in this situation?
We need to change the sequence breakers in some way to prevent weird behavior from names that tokenize in unlucky ways.
I like this idea. However, we still need to find some way to address ambiguous tokenizations. The current approach of prefixing sequence breaker strings with "a" only captures one of the possible tokenizations.

One solution to ambiguous tokenizations would be to simply compute all of them. For example, if a user inputs the sequence breaker string "jack black:", then we would find all tokenizations which end in that string. A few of them are:
...so sequence_breakers might look like this:
An entire sequence would need to match (not just an individual token from a sequence). Finding these alternative tokenizations wouldn't be computationally expensive (it only needs to be done once, and we only care about tokenizations which are possible in reality, rather than all possible combinations). However, it's possible in some cases to end up with so many alternative tokenizations that it might cause performance issues during sampling. Maybe this is still the best approach? Maybe we can give the user a warning if the number of sequence breaker token sequences is excessive? I'm not sure about this.
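A hedged sketch of the idea above: sequence breakers stored as multiple alternative token-ID sequences, where an entire sequence (not just an individual token) must match the tail of the context. All token IDs here are made up for illustration.

```python
# Illustrative alternative tokenizations of one breaker string;
# the IDs are invented for this example.
sequence_breakers = [
    [1234],        # e.g. "ack" as a single token
    [946, 1234],   # e.g. "j" + "ack"
    [730, 936],    # e.g. "ja" + "ck"
]

def ends_with_breaker(context_ids, breakers):
    """True if the context ends with any complete breaker token sequence."""
    return any(
        len(context_ids) >= len(seq) and context_ids[-len(seq):] == seq
        for seq in breakers
    )
```

Checking every alternative is cheap per step, but the number of alternatives is what could grow large, which is the performance concern raised above.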
I agree, @p-e-w. In my defense, all of the complexity is contained within the code that computes the z array. The actual sampler logic is still written in a very simple and understandable way. All that sequence breaker stuff, the penalty parameters, all of that is still just as simple as it was in your original implementation.

I realize now that it wasn't a good choice of me to put the z array computation in the same place as the sampler logic, as that makes the sampler look daunting to anyone who wants to tinker with it. I've now moved the z array computation into a utility function. The sampler logic just calls this utility function, and all the complexity is isolated in that function. If someone wants to tinker with the sampler logic, they can easily do that without diving into how that utility function works. It's enough if they understand the problem the function is solving (what it means to "count matches" in this context), what the input format needs to be (a list of token ids), and what the output format is (which is a little more complicated, but should be clear from the commented example).
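For reference, here is a minimal sketch of the standard textbook Z algorithm, the technique named in this thread, written for token-ID lists. The PR's actual utility function may differ in naming and output format; this shows only the core linear-time construction.

```python
def compute_z_array(tokens):
    """z[i] = length of the longest common prefix of tokens and tokens[i:].

    Runs in O(n): the (l, r) window tracks the rightmost match found so far,
    letting each position reuse comparisons already made.
    """
    n = len(tokens)
    z = [0] * n
    if n == 0:
        return z
    z[0] = n
    l = r = 0
    for i in range(1, n):
        if i < r:
            # Reuse information from the window [l, r)
            z[i] = min(r - i, z[i - l])
        # Extend the match by direct comparison
        while i + z[i] < n and tokens[z[i]] == tokens[i + z[i]]:
            z[i] += 1
        if i + z[i] > r:
            l, r = i, i + z[i]
    return z
```

For example, `compute_z_array([1, 1, 2, 3, 1, 1, 2])` yields `[7, 1, 0, 0, 3, 1, 0]`: position 4 starts a 3-token repeat of the prefix.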
If we only merge the data type optimization, and we don't merge the Z algorithm optimization, will we have practical use cases where the performance difference matters? Or does it matter only in adversarial cases? I don't know. I don't have long-context use cases personally, so we would need someone like @Hunterius8 to benchmark the results to know what kind of difference it makes. My gut feeling is the same as yours: the Z algorithm optimization may not matter too much for practical use cases.

Regarding adversarial cases, please keep in mind that TGW should support not only UI users but API users as well. If I have a web app serving many users, I don't want one malicious user to be able to easily take down the whole service by simply sending a string that repeats the same word 10000 times. (Of course we can have all kinds of restrictions between the web app frontend and the oobabooga API; I'm not saying this algorithm change is the only way to defend against issues like this, but generally when you build API services, you want them to be robust against adversarial inputs.)

Ultimately it's oobabooga's decision which parts of the PR they want to merge.
How about spinning off the list conversion and overflow fix into a separate PR? That should be quick to review and merge, and then we can discuss the other change here.
AFAIK, text-generation-webui is not designed as, and not suitable for use as, an "API service". It has only rudimentary authentication, no user management, no rate limits etc., and it can serve only a single concurrent request. It's a tool for individuals to work with LLMs locally. To be safe to run as a service for untrusted users, it would probably have to be rewritten from scratch. Not to mention that, as you indicate, computational DoS should be dealt with using computational limits. BTW, self-attention, which is core to how most LLMs work, is itself a quadratic algorithm in both time and memory. So without limits in place (context length in this case), runaway complexity is unavoidable anyway.
API services are not typically deployed in the manner you imply. For example, a company might have 20 different API services running. Does each of those 20 services also separately implement authentication, user management, rate limits, etc.? Of course not. A company might use, for example, AWS API Gateway to enforce rate limits: the same solution sits in front of all 20 API services, rather than rate limits being implemented separately in each of them. Likewise for user management and authentication. If you have an API service like TGW running, you would typically have just one backend service which connects to TGW, and TGW already has perfectly adequate ways to handle authentication for that. Actual end-user authentication, rate limits, etc. would be performed by a different application.

Using TGW to serve multiple users via the API is not a theoretical prospect. The API features have existed in TGW for a long time and are used for more than just single-user use cases. I am personally powering a web application for multiple users with TGW, and have been for quite a while. There is no need to "rewrite the whole thing from scratch". It already works great! In fact, TGW is currently the cheapest and most feature-complete API option for serving LLMs at low scale. Once you need to scale up, it becomes expensive to run, and something like vLLM will be a better option. But for low-scale use it is much cheaper than vLLM. And every application initially starts off as a low-scale application; very few of them eventually grow into high scale.
Different algorithms have different time complexities. Not every algorithm can be optimized to linear time. This one can be. It's a really odd thing to say that because some different algorithm has quadratic time complexity, this one should as well, when we know that it can run in linear time.
The TGW API doesn't even support concurrent inference, and AFAICT makes no effort to sanitize input in any way (beyond basic request parsing). I believe you when you say you are using it in a production webservice, but it's quite obviously not designed for that. Not to mention that it has a huge Python application attached to it that loads all kinds of stuff that the API doesn't need, and which could have any number of security implications that nobody ever audits for because 99% of users don't use it to power a production app.
That is not at all what I said. What I wrote is that quadratic complexity during LLM inference is unavoidable because inference itself is quadratic in the context length. I didn't say DRY "should" be quadratic. However, the following two things are true:
This is going way off topic, but... I'm aware of the security implications. My web backend takes untrusted user input, tokenizes it, and constructs a prompt at the token level. Before the web backend sends the prompt to TGW, the prompt tokens are decoded into a string. The reason I wrote it like that is that I need to stay below the context length limit, so I need to be aware of the token count while constructing the prompt. A side effect of this construction is that untrusted user input is effectively sanitized before being fed into TGW. The user can only affect the prompt (not other generation parameters), and they can only affect the prompt with text that has a valid tokenization.

Secondly, most web applications are flaming piles of garbage in terms of security. Yes, I'm sure you could audit TGW and find loads of vulnerabilities, just like you find loads of vulnerabilities in every other application. But when you're running TGW in a Docker container on a cloud host, and you're sanitizing user input in a separate web backend, the risk of something really bad happening is pretty small.

Third, when you say that TGW is "quite obviously not designed for that"... well, it's not obvious to me. As I said already, I'm not aware of any alternatives which would be better for a low-scale web app. Can you name any? Also, TGW has absolutely no warnings like "only for single-user purposes! do not use in production!". TGW provides an API and documents that API, and it just happens to be the best API available... but you are here telling me that TGW is "obviously not designed for that". Obvious how?
Okay. @p-e-w can you please review #6053? I also realized that as we apply the cap on maximum repetition length to prevent the overflow error, we can apply the cap during the looping phase, so we never loop more than 1000 tokens backwards. That gives us a very good performance boost on the adversarial "worst case" inputs.
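A sketch of the capping idea described above: when scanning backwards to measure a repeated sequence, stop at the cap so adversarial inputs can't force long scans. The cap value mirrors the 1000 mentioned here, but the function and variable names are illustrative, not the PR's actual code.

```python
# Assumed cap, mirroring the 1000-token limit mentioned in the thread.
MAX_MATCH_LEN = 1000

def backward_match_length(tokens, i, j):
    """Longest common suffix length of tokens[:i+1] and tokens[:j+1],
    capped at MAX_MATCH_LEN. Assumes j <= i."""
    length = 0
    while (length < MAX_MATCH_LEN
           and j - length >= 0
           and tokens[i - length] == tokens[j - length]):
        length += 1
    return length
```

On an input of thousands of identical tokens, the scan now does at most 1000 comparisons per match instead of walking the whole context.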
Tested it by just pasting part of a long article into the webui and generating more text. Context length was set to 21504. Massive improvements over the old version: went from about 10.5 t/s at the maximum length previously to 21.5 t/s now.

Great to see real-world numbers! Could you test the much simpler PR #6053 for comparison?

Closing in favor of #6053

@p-e-w Now that DRY has finally been merged to main, I feel like maybe I can bring this up again without risking further delays... Are you interested in sequence breaker improvements? Regarding #6047 (comment), I'm guessing that you wouldn't view the benefit/complexity tradeoff as favorable? I'm asking because I don't want to sink work into a black hole, as you put it. So if you're interested, I'll work on this, but if not, I'll work on something else.

Sorry for the delay, I missed the notification somehow. I do agree that improving the sequence breakers mechanism is worth exploring. That being said, I'm not completely convinced that there is a real problem here. Of course it is possible to construct examples where something suboptimal happens, as you have done in the linked comment. But sequence breakers, and DRY itself, are only heuristics. The system doesn't fall apart if a repetition isn't prevented in some very specific circumstances. Any change that increases complexity should provide demonstrably better results in non-constructed scenarios to be worth its weight. But I'm definitely open to discussing potential improvements.
Avoids quadratic runtime and potential DoS with adversarial inputs Ref oobabooga/text-generation-webui#6047
* Silence bogus Clippy warning (Clippy's suggestion cannot be implemented because of borrowing issues)
* Get rid of unnecessary type annotations (interesting that Clippy doesn't catch this)
* Store default sequence breakers in a slice (it's nicer when the length is not hardcoded)
* Make default sequence breakers private (no need to leak this as it's not used elsewhere)
* Limit match length (avoids quadratic runtime and potential DoS with adversarial inputs; ref oobabooga/text-generation-webui#6047)
* "Fix" sequence breaker tokenization: most tokenizers encode punctuation tokens differently depending on where they occur in the input and which tokens surround them. With the default sequence breakers, the appropriate encoding usually corresponds to the encoding produced when the token occurs after a word, rather than by itself. To emulate this, prefix the token with "a" before encoding, and extract the final token of the result. See LostRuins/koboldcpp#982 for a correct solution to this problem.
Background
The DRY sampler by @p-e-w was recently merged to dev. It seems to be a significant improvement in the fight against repetition, but there were concerns about the performance impact of the sampler. When @oobabooga merged the PR, they requested ideas to make it faster, so I decided to tackle this as my first contribution to text-generation-webui.
Performance impact of this PR
Root causes of the performance issues and how they were resolved
The original implementation loops over input ids in a torch tensor and accesses them like this: `input_ids_row[index].item()`. I discovered that torch adds a lot of overhead to this operation, making it extremely slow compared to the corresponding operation on a normal Python list. So we can trivially make the sampler a lot faster just by converting the ids to a list. I isolated this change in commit 7c03e4a, so you can check that commit separately if you want. This change improves the performance on 1000 repetitive tokens to almost 0 seconds, and it improves the performance on 10000 repetitive tokens to 17 seconds. That's not good enough, so we need one more optimization.

The original implementation finds matching strings using a simple O(n^2) algorithm. For example, if the input is n = 10000 repetitive tokens, then the string matching requires 10000^2 = 100 million steps in the worst case. My optimized implementation finds matching strings using the Z algorithm, which is a faster O(n) algorithm. Continuing with the example of 10000 repetitive tokens: O(n) means that instead of running 100 million steps, we only have to run 10000 steps; as the size of the input increases, the number of steps increases linearly with it. Since the Z algorithm introduces some complexity in the code, I made an effort to write it as readably as possible and to document the steps. I also did extensive testing to verify correctness, since it's no longer obvious just by reading the code whether it's correct or not.
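To make the quadratic behavior concrete, here is an illustrative reference implementation of the kind of naive matching described above (not the PR's exact code): for each earlier position, walk backwards token by token to measure how long the suffix ending there matches the suffix ending at the final token.

```python
def naive_match_lengths(tokens):
    """For each position j < n - 1, the length of the longest common suffix
    of tokens[:j+1] and the full sequence. Worst case O(n^2), which is the
    behavior the Z algorithm replaces with O(n)."""
    n = len(tokens)
    lengths = []
    for j in range(n - 1):
        k = 0
        # Compare the suffix ending at j against the suffix ending at n - 1
        while k <= j and tokens[j - k] == tokens[n - 1 - k]:
            k += 1
        lengths.append(k)
    return lengths
```

On 10000 identical tokens this performs on the order of 10000^2 / 2 comparisons, which is exactly the pathological case discussed above.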
How to review this PR
I understand that reviewing first-time contributors' PRs is often annoying, so I made an effort to make the review process as easy as possible by attaching a testing helper file (temp1.py). You can run this file as a script independently of oobabooga (e.g. `python temp1.py`). It will run test cases where each test case is run against both the original DRY implementation and the optimized DRY implementation. Running times are measured and results are compared; an error is thrown if the results differ. If you look at the end of the temp1 file, you can see how to add your own test cases. You can, for example, tokenize a long prompt and copy-paste the tokens into the temp1 file. At the end of the file you can also see a test generator, which generates random test data.

Note that I did almost all of my testing in the temp1 file, so it would be useful for someone to verify with fresh eyes that the temp1 file corresponds to the actual changes in
`sampler_hijack.py` and that it works properly there.

Related observations
The following observations are not directly related to this PR, but they are observations I made about the DRY sampler while working on it:
* I think the parameter `allowed_length` should either be renamed or its functionality should be changed, because the current behavior is "if the length of the matching sequence is less than allowed_length, no penalty is applied". For example, if allowed_length is 2, then it means the longest allowed length is 1 (not 2!). I expect this to confuse users if it's not changed.
* The sequence breaker strings, if I understand correctly, will each act as a sequence breaker (edit: only the last token of each string will be used). So if a really common token happens to end up in sequence breakers unbeknownst to the user, it will hamper the behavior of the DRY sampler, and the user might be dumbfounded as to why they are seeing so much repetition. By changing this to token ids, we make it explicit to users what the inputs must be.

Checklist: