
DRY sampler performance optimization #6047

Closed
belladoreai wants to merge 4 commits into oobabooga:dev from belladoreai:dev-dry-optimization

Conversation

@belladoreai
Contributor

@belladoreai belladoreai commented May 23, 2024

Background

The DRY sampler by @p-e-w was recently merged to dev. It seems to be a significant improvement in the fight against repetition, but there were concerns about the performance impact of the sampler. When @oobabooga merged the PR, they requested ideas to make it faster, so I decided to tackle this as my first contribution to text-generation-webui.

Performance impact of this PR

  • The "worst case" scenario for the original DRY implementation is an input which contains nothing else except the same token repeated many times. A repeating input of 1024 tokens takes 38 seconds to run on my machine with the original implementation, compared to ~0.00 seconds with my optimized implementation. If we increase the token count from 1024 tokens to 10000 tokens, the original implementation will basically run forever without finishing.
  • For most practical scenarios the performance impact is negligible. I tested with a book excerpt of 1800 tokens and both the original implementation and optimized implementation took ~0.00 seconds to run.
  • Some practical scenarios - particularly those with long context and repetitive sequences - should benefit significantly from this PR. @Hunterius8 noted that with 19k context the original DRY implementation caused a roughly 50% slowdown in t/s. This optimized implementation should completely eliminate that slowdown.

Root causes of the performance issues and how they were resolved

  1. The original implementation loops over input ids in a torch tensor and accesses them like this: input_ids_row[index].item(). I discovered that torch adds a lot of overhead to this operation, making it extremely slow compared to the same operation on a plain Python list. So we can trivially make the sampler much faster just by converting the ids to a list. I isolated this change to commit 7c03e4a so you can check that commit separately if you want. This change brings the time on 1000 repetitive tokens down to almost 0 seconds, and the time on 10000 repetitive tokens down to 17 seconds. That's not good enough, so we need one more optimization:

  2. The original implementation finds matching strings using a simple O(n^2) algorithm. For example, if the input is n=10000 repetitive tokens, then the string matching requires up to 10000 ^ 2 = 100 million steps in the worst case. My optimized implementation finds matching strings using the Z algorithm, which runs in O(n) time. Continuing with the example of 10000 repetitive tokens: O(n) means that instead of running 100 million steps, we only have to run 10000 steps - as the input grows, the number of steps grows linearly with it. Since the Z algorithm introduces some complexity into the code, I made an effort to write it as readably as possible and to document the steps. I also did extensive testing to verify correctness, since it's no longer obvious just from reading the code whether it's correct.
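To illustrate the idea (this is a self-contained sketch, not the exact code from the PR): the Z algorithm is run on the reversed token list, so that each Z value becomes the length of the match ending at that position. Using the same example as the code comments in the PR, [7,6,7,5,6,7]:

```python
def z_array(s):
    # Z algorithm: z[i] = length of the longest common prefix of s and s[i:]
    n = len(s)
    z = [0] * n
    z[0] = n
    l = r = 0
    for i in range(1, n):
        if i < r:
            z[i] = min(r - i, z[i - l])  # reuse previously computed matches
        while i + z[i] < n and s[z[i]] == s[i + z[i]]:
            z[i] += 1
        if i + z[i] > r:
            l, r = i, i + z[i]
    return z

def count_matches(input_ids):
    # For each position, how many tokens ending there (inclusive) repeat the
    # end of the input: run Z on the reversed list, then reverse the result
    return z_array(input_ids[::-1])[::-1]

# The final entry is the trivial whole-string self-match, which the sampler
# excludes from consideration
print(count_matches([7, 6, 7, 5, 6, 7]))  # [1, 0, 2, 0, 0, 6]
```

The single pass over the reversed list is what replaces the per-position backward scan of the original implementation.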

How to review this PR

I understand that reviewing first-time contributors' PRs is often annoying, so I made an effort to make the review process as easy as possible by attaching a testing helper file (temp1.py). You can run this file as a script independently of oobabooga (e.g. python temp1.py). It runs test cases where each test case is run against both the original DRY implementation and the optimized DRY implementation. Running times are measured and results are compared; an error is thrown if the results differ. If you look at the end of the temp1 file, you can see how to add your own test cases there. You can, for example, tokenize a long prompt and copy-paste the tokens into the temp1 file. At the end of the file there is also a test generator which generates random test data.
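In spirit, the comparison harness works like this (a minimal sketch with hypothetical function names; temp1.py is the authoritative version, and the two lambdas below are stand-ins for the original and optimized samplers):

```python
import random
import time

def compare_implementations(original, optimized, test_cases):
    # Run each case through both sampler variants, time them, and fail loudly
    # if the results ever differ
    for ids in test_cases:
        t0 = time.perf_counter()
        expected = original(ids)
        t1 = time.perf_counter()
        actual = optimized(ids)
        t2 = time.perf_counter()
        if expected != actual:
            raise AssertionError(f"Mismatch on input of length {len(ids)}")
        print(f"n={len(ids)}: original {t1 - t0:.4f}s, optimized {t2 - t1:.4f}s")

# Random test generator in the spirit of the one at the end of temp1.py
cases = [[random.randrange(8) for _ in range(n)] for n in (100, 1000)]
compare_implementations(lambda ids: list(ids), lambda ids: ids[:], cases)
```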

Note that I did almost all of my testing in the temp1 file, so it would be useful for someone with fresh eyes to verify that the temp1 file corresponds to the actual changes in sampler_hijack.py and that they work properly there.

Related observations

The following observations are not directly related to this PR, but they are observations I made about the DRY sampler while working on it:

  • When running really large repetitive inputs, the sampler crashes due to OverflowError. My optimized implementation does not change the behavior of the sampler, so it crashes on the same inputs as the original implementation. This is trivial to fix; just let me know in the comments and I'll add it (I didn't add a fix yet because I wanted to be able to say "this PR improves performance without changing behavior").
  • I think parameter allowed_length should be either renamed or its functionality should be changed, because the current behavior is "If the length of the matching sequence is less than allowed_length, no penalty is applied". For example, if allowed_length is 2, then it means the longest allowed length is 1 (not 2!). I expect this to confuse users if it's not changed.
  • I think the sequence breakers parameter should be changed so that it takes token ids as input instead of strings. Otherwise users will expect any string to act as a sequence breaker. The actual behavior is that the strings are tokenized, and a string might tokenize into multiple tokens, each of which (if I understand correctly) will act as a sequence breaker (edit: only the last token of each string is used). So if a really common token happens to end up in sequence breakers unbeknownst to the user, it will hamper the DRY sampler, and the user might be dumbfounded as to why they are seeing so much repetition. By changing this to token ids we make explicit to users what the inputs must be.


@oobabooga
Owner

My optimized implementation finds matching strings using Z algorithm, which is a faster O(n) algorithm.

It's really nice that you managed to reduce the time complexity of the algorithm, well done!

Some quick comments:

  1. A simple sanity check is to paste some text in the Default or Notebook tabs and get the logits after sampling parameters in the dev branch and this PR branch, just to see if they are identical.
  2. A fix to the OverflowError can be added as a new commit inside this PR. It would be good to run the sanity check above after this change.
  3. allowed_length could perhaps be renamed to minimum_penalized_length. I'll tag the original author @p-e-w for his opinion.
  4. Is it perhaps possible to use sequences of tokens as a sequence breaker rather than individual tokens? So that the problem you mentioned doesn't occur and the parameter becomes more literal.

@p-e-w
Contributor

p-e-w commented May 24, 2024

Wow, that's awesome! A few initial comments:

👍 for the list conversion, as that's a simple change with huge benefits that I overlooked. Thank you for finding and resolving this!

I'm less enthusiastic about the other change introducing the Z algorithm. Notably, I don't believe that performance on "adversarial" inputs such as 10k repeated tokens is a good justification for such a big algorithmic change. I spent literally hundreds of hours validating DRY in its current form, often stepping through the algorithm manually and quadruple-checking every single statement. I took a lot of care to produce an algorithm that is as easy to follow as possible, and rewrote everything multiple times in the process, because testing samplers is really, really hard and I wanted an implementation that can be easily ported to other loaders.

IMO, the proposed algorithm, while faster in pathological cases, is significantly more difficult to understand and reason about. It's great that you commented everything and even wrote tests to verify identical behavior to the current implementation, but unless the performance difference can be demonstrated to matter with real-world inputs, I don't think the benefits outweigh the additional complexity burden here. But of course, that's ultimately not my decision to make.

When running really large repetitive inputs, the sampler crashes due to OverflowError.

Yes, I had anticipated that in theory and then forgotten about it. I believe it's a simple floating point overflow in the exponential penalty calculation. Note that this cannot occur unless the input already contains pathological repetitions, as DRY will prevent repetitive sequence buildups long before they can reach such lengths.

Clamping the exponent in the penalty formula should be sufficient to fix this.
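For instance (a sketch, with a hypothetical cap value rather than whatever the actual fix uses):

```python
def clamped_penalty(match_length, allowed_length, multiplier, base, max_exponent=100):
    # Cap the exponent so that base ** exponent cannot overflow a Python float;
    # any cap large enough to be unreachable in normal generation works
    exponent = min(match_length - allowed_length, max_exponent)
    return multiplier * base ** exponent

# Without the clamp, 1.75 ** 999_998 would raise OverflowError
print(clamped_penalty(1_000_000, 2, 0.8, 1.75))
```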

I think parameter allowed_length should be either renamed or its functionality should be changed, because the current behavior is "If the length of the matching sequence is less than allowed_length, no penalty is applied".

Nope, it's exactly what it says on the can. Here's the relevant code:

if match_length >= self.allowed_length:
    penalty = self.multiplier * self.base ** (match_length - self.allowed_length)
    scores_row[token] -= penalty

You're probably confused by the >= operator, which looks like the penalty would apply already to sequences that are equal in length to allowed_length. But note that penalties are always applied to the token about to be sampled. That is, we already have a repeated sequence of length allowed_length, so we penalize the next token in the probability distribution, because otherwise that token would make the repeated sequence longer than allowed_length, which is not allowed.

allowed_length is the maximum length that a repeated sequence can reach before penalties are applied to tokens in that sequence. As far as I'm concerned, this is what I would expect from the name. If allowed_length is 3, DRY will allow the input one two three one two to be continued without penalizing three, because one two three has length 3.
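A toy version of the penalty rule makes this concrete (the multiplier and base values here are illustrative, not necessarily the webui defaults):

```python
def dry_penalty(match_length, allowed_length=3, multiplier=0.8, base=1.75):
    # No penalty until the repeated sequence would exceed allowed_length;
    # beyond that, the penalty grows exponentially with the match length
    if match_length < allowed_length:
        return 0.0
    return multiplier * base ** (match_length - allowed_length)

# In "one two three one two", the match ending at the last token is
# "one two" (length 2), so the candidate token "three" is not penalized
print(dry_penalty(2))  # 0.0
print(dry_penalty(3))  # 0.8
print(dry_penalty(6))  # 0.8 * 1.75**3, approximately 4.2875
```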

I think sequence breakers parameter should be changed so that instead of strings, it takes token ids as input. Otherwise users will expect any string to act as a sequence breaker. The actual behavior is that strings are tokenized and some string might tokenize into multiple tokens, which if I understand correctly, will each act as a sequence breaker.

Not quite. Only the last token of each string is used, after prefixing the string with "a" in order to avoid problems with the ambiguous tokenizations that exist particularly for punctuation.

Token IDs have the massive downside that they depend on the tokenizer, and there are quite a few tokenizers around now, with model makers continuing to experiment with vocabularies. I fear that expecting users to provide different token IDs every time they switch to another model would make this crucial parameter, and thus DRY itself, useless in practice, and above all else I want DRY to be practical to use.
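The scheme described above can be sketched as follows; the toy encoder stands in for a real tokenizer, and its token ids are made up:

```python
def sequence_breaker_token(encode, s):
    # encode: str -> list[int]. Prefix with "a" so that punctuation tokenizes
    # the way it would after a word, then keep only the final token id
    return encode("a" + s)[-1]

# Toy encoder standing in for a real tokenizer (ids are purely illustrative)
toy_vocab = {"a\n": [97, 10], "a:": [97, 58]}
print(sequence_breaker_token(toy_vocab.__getitem__, "\n"))  # 10
print(sequence_breaker_token(toy_vocab.__getitem__, ":"))   # 58
```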

@belladoreai
Contributor Author

Thank you both for such a quick response!

  1. A simple sanity check is to paste some text in the Default or Notebook tabs and get the logits after sampling parameters in the dev branch and this PR branch, just to see if they are identical.

@oobabooga Can you please give more detailed instructions on how to run this sanity check?

It seems to me like the logits are not stable across generations. Even when I checkout a commit of the dev branch before my changes, and I generate with the exact same parameters and the exact same prompt, the logits after sampling are slightly different every time.

Here's an example of what I see when I try to print logits:

tensor([[[ 9.2031, 10.8047,  4.8984,  ..., -3.6758, -3.6758, -3.6777]]],

...and when I run the same thing again expecting to see the same logits, they are similar but not exactly the same:

tensor([[[ 9.2188, 10.7969,  4.9062,  ..., -3.6973, -3.6973, -3.6992]]]

I'm trying to print the logits inside the file exllamav2_hf.py, at the end of __call__, before this line is called:

return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)

To me it seems like that would be the correct place to print the logits, but maybe you had some different place in mind?

Or maybe this issue is related to ExllamaV2 nondeterminism? I only have a 3060 so I'm unable to run any of the llama models without quantization.

Even if the logits were stable across generations, there would still be the problem of how to compare them. I can't eyeball through 128000 values, so I guess I would need to hash them and compare hashes or something? Maybe something like hash(frozenset(logits[0][0].tolist()))?
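For what it's worth, a tolerance-based comparison may be more robust than hashing (hashing a frozenset also discards order and duplicates, so identical hashes wouldn't prove identical logits). A minimal pure-Python sketch using the two snippets above:

```python
a = [9.2031, 10.8047, 4.8984, -3.6758, -3.6758, -3.6777]
b = [9.2188, 10.7969, 4.9062, -3.6973, -3.6973, -3.6992]

def max_abs_diff(xs, ys):
    # Largest element-wise deviation between two logit vectors
    return max(abs(x - y) for x, y in zip(xs, ys))

print(max_abs_diff(a, b))  # roughly 0.0215
```

On full tensors, torch.allclose(logits_a, logits_b, atol=...) does the same job without leaving torch.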

Alternatively, I also tried using the API endpoint to check logits: http://127.0.0.1:5000/v1/internal/logits. Something seems to be wrong with that endpoint: calling it hangs forever and also causes the entire API to become unresponsive to all requests. Here's an example of one of the API calls I tried which triggers this issue:

{
    "prompt": "Who is best, Asuka or Rei? Answer:",
    "use_samplers": true,
    "top_k": 3,
    "max_tokens": 1,
    "max_new_tokens": 1
}

@belladoreai
Contributor Author

  1. A fix to the OverflowError can be added as a new commit inside this PR. It would be good to run the sanity check above after this change.

I have now pushed a fix to the OverflowError as a separate commit.

@belladoreai
Contributor Author

  1. allowed_length could perhaps be renamed to minimum_penalized_length

Sorry, this one is my bad. As @p-e-w explained, the current behavior is correct and no change to allowed_length is needed.

@belladoreai
Contributor Author

belladoreai commented May 24, 2024

(quoting p-e-w) Token IDs have the massive downside that they depend on the tokenizer, and there are quite a few tokenizers around now, with model makers continuing to experiment with vocabularies. I fear that expecting users to provide different token IDs every time they switch to another model would make this crucial parameter, and thus DRY itself, useless in practice, and above all else I want DRY to be practical to use.

You're right. Let's forget my original suggestion regarding sequence breakers.

(quoting p-e-w) Only the last token of each string is used, after prefixing the string with "a" in order to avoid problems with ambiguous tokenizations which exist particularly for punctuation.

Thanks for clarifying! I think I now understand better how sequence breakers are intended to work.

I still think the current behavior of sequence breakers is far from optimal and the behavior should be changed in some way.

Consider the following example.

  1. User is chatting with a character named "jack black"
  2. User follows the instruction to configure names in sequence breakers as separate strings: ["jack", "black"]
  3. Each string is prefixed with "a": ["ajack", "ablack"]
  4. Each prefixed string is tokenized using llama 3 tokenization into the following tokens: [["aj", "ack"], ["abl", "ack"]]
  5. Only the last token of each tokenized string is used: ["ack", "ack"]

To recap, in this example we have added our character "jack black" into sequence breakers... but it will only contain the token "ack". What kind of issues does this cause? Well, let's imagine that the LLM is generating tokens and it is currently in a state where the most recent tokens look like this:

You: what you got on your tooth?
jack black: there's plack on my tooth and there's plack on my tooth and there's plack on my tooth and there's plack on my tooth and

How would we want the DRY sampler with sequence breakers to work in this situation?

  1. We would want the sampler to recognize the repetitive sequence "and there's plack on my tooth" and to stop repeating it. Does this happen? No, because the string " plack" tokenizes in llama 3 tokenization into tokens [" pl", "ack"], and the token "ack" is in sequence breakers. So we mistakenly apply a sequence breaker to a commonly occurring token, thus hampering the DRY sampler and preventing it from doing its job of preventing repetition.
  2. Another thing we would want the sampler to recognize is that the sequence "jack black:" should not be considered as repetition. Does this happen? No, because it tokenizes into ["jack", " black", ":"]. This tokenization does not include our sequence breaker token "ack".

We need to change the sequence breakers in some way to prevent weird behavior from names that tokenize in unlucky ways.

(quoting oobabooga) Is it perhaps possible to use sequences of tokens as a sequence breaker rather than individual tokens? So that the problem you mentioned doesn't occur and the parameter becomes more literal.

I like this idea.

However, we still need to find some way to address ambiguous tokenizations. The current approach of prefixing sequence breaker strings with "a" is a hacky solution which will sometimes cause good outcomes and sometimes cause bad outcomes (as demonstrated with the "jack black" example). We should do something else.

One solution to ambiguous tokenizations would be to simply compute all of them. For example, if a user inputs sequence breaker string "jack black:", then we would find all tokenizations which end in that string. A few of them are:

  • ["jack", " black", ":"]
  • [" jack", " black", ":"]
  • ["aj", "ack", " black", ":"]

...so sequence_breakers might look like this:

[["jack", " black", ":"], [" jack", " black", ":"], ["aj", "ack", " black", ":"]]

An entire sequence would need to match (not just individual token from a sequence).
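That matching rule could look something like this (a sketch under the assumption that sequence_breakers becomes a list of token sequences; names are hypothetical):

```python
def ends_with_breaker(tokens, breaker_sequences):
    # True only if the token stream ends with a complete breaker sequence;
    # an individual token from a sequence is not enough on its own
    return any(
        len(seq) <= len(tokens) and tokens[-len(seq):] == seq
        for seq in breaker_sequences
    )

breakers = [["jack", " black", ":"],
            [" jack", " black", ":"],
            ["aj", "ack", " black", ":"]]
print(ends_with_breaker(["aj", "ack", " black", ":"], breakers))  # True
print(ends_with_breaker([" pl", "ack"], breakers))                # False
```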

Finding these alternative tokenizations wouldn't be computationally expensive (it only needs to be done once, and we only care about tokenizations that can actually occur, rather than all possible combinations). However, in some cases we could end up with so many alternative tokenizations that they cause performance issues during sampling. Maybe this is still the best approach? Maybe we can warn the user if the number of sequence breaker token sequences is excessive?

I'm not sure about this.

@belladoreai
Contributor Author

belladoreai commented May 24, 2024

IMO, the proposed algorithm, while faster in pathological cases, is significantly more difficult to understand and reason about.

I agree @p-e-w

In my defense, all of the complexity is contained within the code that computes the z array. The actual sampler logic is still written in a very simple and understandable way. All that sequence breaker stuff, penalty parameters, all of that, is still just as simple as it was in your original implementation.

I realize now that it wasn't a good choice on my part to put the z array computation in the same place as the sampler logic, as that makes the sampler look daunting to anyone who wants to tinker with it. I've now moved the z array computation into a utility function. The sampler logic just calls this utility function, and all the complexity is isolated inside it:

# Find where the repetitions are: for each position in the input, count how many tokens to the left of that position (inclusive) are repeating the end of the input.
# example input:  [7,6,7,5,6,7]
# example output: [1,0,2,0,0,-] (exclude last token from consideration)
match_lengths = self.count_matches(input_ids_list)

If someone wants to tinker with the sampler logic, they can easily do that without diving into how that utility function works. It's enough if they understand the problem that the function is solving (what it means to "count matches" in this context), what the input format needs to be (a list of token ids) and what the output format is (which is a little bit more complicated but should be clear from the commented example).

I don't believe that performance on "adversarial" inputs such as 10k repeated tokens is a good justification for such a big algorithmic change

If we only merge the data type optimization, and not the Z algorithm optimization, will there be practical use cases where the performance difference matters? Or does it matter only in adversarial cases? I don't know. I don't have long-context use cases personally, so we would need someone like @Hunterius8 to benchmark the results to know what kind of difference it makes. My gut feeling is the same as yours: the Z algorithm optimization may not matter much for practical use cases.

Regarding adversarial cases, please keep in mind that TGW should support not only UI users but API users as well. If I have a web app serving many users, I don't want one malicious user to be able to easily take down the whole service by simply sending a string that repeats the same word 10000 times. (Of course we can have all kinds of restrictions between the web app frontend and oobabooga API, I'm not saying this algorithm change is the only way to defend from issues like this, but generally when you build API services, you want them to be robust against adversarial inputs.)

Ultimately it's oobabooga's decision which parts of the PR they want to merge.

@p-e-w
Contributor

p-e-w commented May 25, 2024

How about spinning off the list conversion and overflow fix into a separate PR? That should be quick to review and merge, and then we can discuss the other change here.

Regarding adversarial cases, please keep in mind that TGW should support not only UI users but API users as well. If I have a web app serving many users, I don't want one malicious user to be able to easily take down the whole service by simply sending a string that repeats the same word 10000 times. (Of course we can have all kinds of restrictions between the web app frontend and oobabooga API, I'm not saying this algorithm change is the only way to defend from issues like this, but generally when you build API services, you want them to be robust against adversarial inputs.)

AFAIK, text-generation-webui is not designed as, and not suitable for use as, an "API service". It has only rudimentary authentication, no user management, no rate limits etc., and it can serve only a single concurrent request. It's a tool for individuals to work with LLMs locally. To be safe to run as a service for untrusted users, it would probably have to be rewritten from scratch. Not to mention that, as you indicate, computational DoS should be dealt with using computational limits.

BTW, self-attention, which is core to how most LLMs work, is itself a quadratic algorithm in both time and memory. So without limits in place (context length in this case), runaway complexity is unavoidable anyway.

@belladoreai
Contributor Author

belladoreai commented May 25, 2024

AFAIK, text-generation-webui is not designed as, and not suitable for use as, an "API service". It has only rudimentary authentication, no user management, no rate limits etc., and it can serve only a single concurrent request. It's a tool for individuals to work with LLMs locally. To be safe to run as a service for untrusted users, it would probably have to be rewritten from scratch. Not to mention that, as you indicate, computational DoS should be dealt with using computational limits.

API services are not typically deployed in the manner that you imply. For example, a company might have 20 different API services running. Do each of those 20 services also separately contain authentication, user management, rate limits, etc? Of course not. A company might be using, for example, AWS API Gateway to enforce rate limits: the same solution used in front of all those 20 API services, rather than implementing rate limits separately in each of those 20 services. Likewise for user management and user authentication. If you have an API service like TGW running, you would typically have just one backend service which connects to TGW, and TGW already has perfectly adequate ways to do authentication for that. Actual end user authentication, rate limits, etc. would be performed by a different application.

The prospect of using TGW to serve multiple users via the API is not a theoretical prospect. The API features have existed in TGW for a long time and they are used for more than just single-user use cases. I am personally powering a web application for multiple users with TGW, and have been for quite a while. There is no need to "rewrite the whole thing from scratch". It already works great! In fact, TGW is currently the cheapest and most feature complete API option for serving LLMs at low scale. Once you need to scale up it becomes expensive to run, and something like vLLM will be a better option. But for low scale use it is much cheaper than vLLM. And every application initially starts off as a low scale application. Very few of them eventually grow into high scale.

BTW, self-attention, which is core to how most LLMs work, is itself a quadratic algorithm in both time and memory. So without limits in place (context length in this case), runaway complexity is unavoidable anyway.

Different algorithms have different time complexities. Not every algorithm can be optimized to linear time. This one can be.

It's a really odd thing to say that because some different algorithm has quadratic time complexity, this one should as well... when we know that it can run in linear time.

@p-e-w
Contributor

p-e-w commented May 25, 2024

The prospect of using TGW to serve multiple users via the API is not a theoretical prospect. The API features have existed in TGW for a long time and they are used for more than just single-user use cases. I am personally powering a web application for multiple users with TGW, and have been for quite a while.

The TGW API doesn't even support concurrent inference, and AFAICT makes no effort to sanitize input in any way (beyond basic request parsing). I believe you when you say you are using it in a production webservice, but it's quite obviously not designed for that. Not to mention that it has a huge Python application attached to it that loads all kinds of stuff that the API doesn't need, and which could have any number of security implications that nobody ever audits for because 99% of users don't use it to power a production app.

It's a really odd thing to say that because some different algorithm has quadratic time complexity, this one should as well... when we know that it can run in linear time.

That is not at all what I said. What I wrote is that quadratic complexity during LLM inference is unavoidable because inference itself is quadratic in the context length. I didn't say DRY "should" be quadratic. However, the following two things are true:

  1. DRY isn't actually quadratic except for pathological inputs (where matches overlap). The same is true for quicksort, regular expression engines with lookahead, and many other algorithms that are widely deployed even though subquadratic alternatives exist (mergesort, finite automata).
  2. The linear-time implementation using the Z algorithm comes at a cost. That cost is complexity, and it creates other costs concerning maintainability, portability etc. It's not simply "free performance". The list conversion is, which is why I support merging that as soon as possible without reservations.

@belladoreai
Contributor Author

The TGW API doesn't even support concurrent inference, and AFAICT makes no effort to sanitize input in any way (beyond basic request parsing). I believe you when you say you are using it in a production webservice, but it's quite obviously not designed for that. Not to mention that it has a huge Python application attached to it that loads all kinds of stuff that the API doesn't need, and which could have any number of security implications that nobody ever audits for because 99% of users don't use it to power a production app.

This is going way off topic, but...

I'm aware of the security implications. My web backend takes untrusted user input, tokenizes it, and constructs a prompt at token level. Before the web backend sends the prompt to TGW, the prompt tokens are decoded into a string. The reason why I wrote it like that is I need to stay below the context length limit, so I need to be aware of token count during the construction of the prompt. A side effect from this construction is that untrusted user input is effectively sanitized before fed into TGW. The user can only affect the prompt (not other generation parameters), and they can only affect the prompt with text that has a valid tokenization.

Secondly, most web applications are flaming piles of garbage in terms of security. Yes, I'm sure you could audit TGW and find loads of vulnerabilities, just like you find loads of vulnerabilities in every other application. But when you're running TGW in a docker container on a cloud host, and you're sanitizing user input on a separate web backend, the risk of something really bad happening is pretty small.

Third, when you say that TGW is "quite obviously not designed for that"... well, it's not obvious to me. As I said already, I'm not aware of any alternatives which would be better for a low scale web app. Can you name any? Also, TGW has absolutely no warnings like "only for single user purposes! do not use in production!". TGW provides an API and documents that API, and it just happens to be the best API available... but you are here telling me that TGW is "obviously not designed for that". Obvious how?

@belladoreai belladoreai mentioned this pull request May 25, 2024
@belladoreai
Contributor Author

belladoreai commented May 25, 2024

How about spinning off the list conversion and overflow fix into a separate PR? That should be quick to review and merge, and then we can discuss the other change here.

Okay. @p-e-w can you please review: #6053

I also realized that as we apply the cap for max repetition length to prevent overflow error, we can apply the cap during the looping phase, so we never loop more than 1000 tokens backwards. That gives us a very good performance boost for the adversarial "worst case" inputs.
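The idea, sketched (the cap value here is hypothetical): since any match longer than the cap receives the same clamped penalty anyway, the backward scan can stop early:

```python
def bounded_match_length(ids, i, cap=1000):
    # Count how far the sequence ending at position i matches the end of ids,
    # stopping at `cap`: longer matches are penalized identically anyway
    n = len(ids)
    length = 0
    while length < cap and i - length >= 0 and ids[i - length] == ids[n - 1 - length]:
        length += 1
    return length

print(bounded_match_length([7, 6, 7, 5, 6, 7], 2))  # 2
print(bounded_match_length([1] * 50_000, 49_999))   # 1000, not 50000
```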

@Hunterius8

@belladoreai @p-e-w

Tested it by just pasting part of a long article into the webui and generating more text. Context length was set to 21504. Massive improvements over the old version.


Went from about 10.5t/s at the maximum length previously, to 21.5t/s now.

@p-e-w
Contributor

p-e-w commented May 28, 2024

@Hunterius8

Great to see real-world numbers! Could you test the much simpler PR #6053 for comparison?

@belladoreai
Contributor Author

Closing in favor of #6053

@belladoreai
Contributor Author

@p-e-w Now that DRY has finally been merged to main, I feel like maybe I can bring this up again without risking further delays... Are you interested in sequence breakers improvements? Regarding #6047 (comment)

I'm guessing that you wouldn't view the benefit/complexity tradeoff as favorable? I'm asking because I don't want to sink work into a black hole, as you put it. So if you're interested, I'll work on this, but if not, I'll work on something else.

@p-e-w
Contributor

p-e-w commented Jun 22, 2024

@belladoreai

Sorry for the delay, I missed the notification somehow.

I do agree that improving the sequence breakers mechanism is worth exploring. That being said, I'm not completely convinced that there is a real problem here. Of course it is possible to construct examples where something suboptimal happens, as you have done in the linked comment. But sequence breakers, and DRY itself, are only heuristics. The system doesn't fall apart if a repetition isn't prevented in some very specific circumstances. Any change that increases complexity should provide demonstrably better results in non-constructed scenarios to be worth its weight. But I'm definitely open to discussing potential improvements.

p-e-w added a commit to p-e-w/mistral.rs that referenced this pull request Jul 29, 2024
Avoids quadratic runtime and potential DoS with adversarial inputs

Ref oobabooga/text-generation-webui#6047
EricLBuehler pushed a commit to EricLBuehler/mistral.rs that referenced this pull request Jul 29, 2024
* Silence bogus Clippy warning

Clippy's suggestion cannot be implemented because of borrowing issues

* Get rid of unnecessary type annotations

Interesting that Clippy doesn't catch this

* Store default sequence breakers in a slice

It's nicer when the length is not hardcoded

* Make default sequence breakers private

No need to leak this as it's not used elsewhere

* Limit match length

Avoids quadratic runtime and potential DoS with adversarial inputs

Ref oobabooga/text-generation-webui#6047

* "Fix" sequence breaker tokenization

Most tokenizers encode punctuation tokens differently depending on where they occur in the input, and which tokens surround them. With the default sequence breakers, the appropriate encoding usually corresponds to the encoding produced when the token occurs after a word, rather than by itself. To emulate this, prefix the token with "a" before encoding, and extract the final token of the result.

See LostRuins/koboldcpp#982 for a correct solution to this problem.
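The workaround described in that commit message can be sketched as follows (a hypothetical helper; `tokenizer.encode` is assumed to follow the Hugging Face-style interface of returning a list of token ids):

```python
def encode_sequence_breaker(tokenizer, breaker_text):
    """Encode `breaker_text` as it would appear after a word: prefix it with
    'a' before encoding, then keep only the final token of the result."""
    token_ids = tokenizer.encode("a" + breaker_text)
    return token_ids[-1]
```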
EricLBuehler added a commit to EricLBuehler/mistral.rs that referenced this pull request Aug 27, 2024
* Implement dry penalty

* Add dry sampling params to requests

* Handle it

* Clippy

* Review: "Implement DRY penalty" (#645)

* Nicer

* Even better

* Complete merge

* Fix saturating sub

* Handle when no context

* Make context the entire sequence and refactor

* Remove slicing for all

* Fix the bug with penalty

Credit to @p-e-w for finding this!

Co-authored-by: Philipp Emanuel Weidmann <pew@worldwidemann.com>
