
Add custom logits processor API #702

Merged
EricLBuehler merged 5 commits into master from custom_logits_processor_api on Aug 21, 2024
Conversation

@EricLBuehler
Owner

New sampling order:

  1. Apply penalties from sampling_params
  2. ⭐ Apply these custom logits processors sequentially
  3. Apply temperature and softmax
  4. Sample the next token (topk, topp, minp, etc)

@Dan-wanna-M would this API work? You can find an example of using it here. Basically, you provide a `Vec` of `Arc<dyn Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync>`, which are applied in sequence.

Please let me know if this works for you or what changes would be necessary, and I can merge this PR. kbnf looks super cool!
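For illustration, here is a minimal sketch of the closure-based API shape, with `Vec<f32>` standing in for candle's `Tensor` so it is self-contained; the `LogitsFn` alias and `apply_all` helper are hypothetical names, not part of the PR:

```rust
use std::sync::Arc;

// Simplified stand-in for the PR's closure type: logits in, modified logits out,
// with the token context (prompt + generated tokens) available for inspection.
type LogitsFn = Arc<dyn Fn(&[f32], &[u32]) -> Vec<f32> + Send + Sync>;

// Processors are applied sequentially; each sees the previous one's output,
// matching step 2 of the sampling order above.
fn apply_all(processors: &[LogitsFn], logits: &[f32], context: &[u32]) -> Vec<f32> {
    processors.iter().fold(logits.to_vec(), |l, p| p(&l, context))
}
```

A processor that bans a token would, for example, set its logit to negative infinity before temperature/softmax runs.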

@EricLBuehler EricLBuehler added the new feature (New feature or request) and processing (Processing related to the model) labels on Aug 21, 2024
@github-actions

Code Metrics Report

```
===============================================================================
 Language            Files        Lines         Code     Comments       Blanks
===============================================================================
 C Header                2           35           28            0            7
 Dockerfile              1           34           25            0            9
 Happy                   1          442          369            0           73
 JSON                   11          102          101            0            1
 Python                 46         2018         1718           62          238
 TOML                   20          617          544           11           62
 YAML                    1            9            8            1            0
-------------------------------------------------------------------------------
 Jupyter Notebooks       4            0            0            0            0
 |- Markdown             2           77           32           31           14
 |- Python               2          196          169            1           26
 (Total)                            273          201           32           40
-------------------------------------------------------------------------------
 Markdown               28         1960            0         1481          479
 |- BASH                 5          101           98            0            3
 |- JSON                 1           12           12            0            0
 |- Python               5           92           82            0           10
 |- Rust                 6          408          365           19           24
 |- TOML                 2           75           63            0           12
 (Total)                           2648          620         1500          528
-------------------------------------------------------------------------------
 Rust                  195        59957        54396         1100         4461
 |- Markdown           100          906           13          843           50
 (Total)                          60863        54409         1943         4511
===============================================================================
 Total                 310        65174        57189         2655         5330
===============================================================================
```

@Dan-wanna-M

@EricLBuehler One thing I am pondering: kbnf is stateful, which means that to use `Arc<dyn Fn()>` I need a `Mutex<T>`. This requirement by itself is fine, but I am wondering about its implications. For example, if logits processors are shared directly across threads and executed in parallel, then some very bad stuff will probably happen.

@EricLBuehler
Owner Author

> @EricLBuehler One thing I am pondering is that kbnf is stateful, which means to use `Arc<dyn Fn()>` I need to use `Mutex`. This requirement by itself is fine, but I am wondering its implications. For example, if logits processors will be directly shared across threads and executed in parallel, then some very bad stuff probably will happen.

@Dan-wanna-M yes, we parallelize the sampling process. Perhaps we could change the API to accept a generic type with a method that does the logits processing? Then you could keep the state inside that type in a `Mutex`.

@EricLBuehler
Owner Author

@Dan-wanna-M I just updated it, so now the API is:

```rust
pub trait CustomLogitsProcessor: Send + Sync {
    /// Process the logits given the sequence context (prompt and generated tokens), returning the modified logits.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}
```

In the example, we have a custom struct:

```rust
struct ThresholdLogitsProcessor {
    threshold: f64,
}

impl CustomLogitsProcessor for ThresholdLogitsProcessor {
    fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
        // Mask is 1 for true, 0 for false.
        let mask = logits.ge(self.threshold)?;
        logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
    }
}
```

Would this work if you put shared state into the custom struct with a `Mutex`?
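As a sketch of the stateful case being discussed: interior mutability behind the trait's `&self` receiver could look like the following, with `Vec<f32>` in place of candle's `Tensor` so the example stands alone (`StatefulProcessor` and `tokens_seen` are hypothetical names, and a real processor would hold e.g. a kbnf engine rather than a counter):

```rust
use std::sync::Mutex;

// Simplified version of the trait, operating on plain slices.
trait CustomLogitsProcessor: Send + Sync {
    fn apply(&self, logits: &[f32], context: &[u32]) -> Vec<f32>;
}

struct StatefulProcessor {
    // `apply` takes &self, so mutable state must live behind a lock.
    tokens_seen: Mutex<usize>,
}

impl CustomLogitsProcessor for StatefulProcessor {
    fn apply(&self, logits: &[f32], context: &[u32]) -> Vec<f32> {
        // Update the internal state from the sequence context, then
        // pass the logits through unchanged.
        let mut seen = self.tokens_seen.lock().unwrap();
        *seen = context.len();
        logits.to_vec()
    }
}
```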

@Dan-wanna-M

> @Dan-wanna-M I just updated it so now the API is: […] Would this work if you could put shared state into the custom struct w/ a mutex?

I spent some time rethinking it, and I think that as long as one generation only uses one logits processor, it should work (so a batch of N generations will need N separate logits processors).
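The one-processor-per-generation point above can be sketched in plain Rust (illustrative only; `SeqProcessor` and `make_batch_processors` are hypothetical names):

```rust
use std::sync::Mutex;

// One stateful processor per sequence: each sequence's state is independent,
// so parallel sampling of the batch never shares mutable state.
struct SeqProcessor {
    calls: Mutex<u32>,
}

// Build N independent processors for a batch of N generations.
fn make_batch_processors(n: usize) -> Vec<SeqProcessor> {
    (0..n).map(|_| SeqProcessor { calls: Mutex::new(0) }).collect()
}
```

Mutating one processor's state then leaves every other sequence's processor untouched.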

@EricLBuehler
Owner Author

> I spent some time rethinking it and I think as long as one generation only uses one logits processor then it should work (so a batch of N generations will need N separate logits processors)

Sounds good, and this API also implements `CustomLogitsProcessor` for the corresponding closure type! I'll merge this if there are no other problems.
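A sketch of what such a blanket impl for closures could look like, simplified to `Vec<f32>` instead of `Tensor` (this is an assumption about the shape, not the PR's actual code):

```rust
// Simplified version of the trait.
trait CustomLogitsProcessor: Send + Sync {
    fn apply(&self, logits: &[f32], context: &[u32]) -> Vec<f32>;
}

// Blanket impl: any matching closure automatically implements the trait,
// so callers can pass either a plain closure or a full stateful struct.
impl<T> CustomLogitsProcessor for T
where
    T: Fn(&[f32], &[u32]) -> Vec<f32> + Send + Sync,
{
    fn apply(&self, logits: &[f32], context: &[u32]) -> Vec<f32> {
        self(logits, context)
    }
}
```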

@Dan-wanna-M

> Sounds good, and this API implements `CustomLogitsProcessor` for the corresponding closure type! I'll merge this if there are no other problems.

Yes, it should be mergeable now.

@EricLBuehler EricLBuehler merged commit 49e3eb9 into master Aug 21, 2024
@EricLBuehler EricLBuehler deleted the custom_logits_processor_api branch August 21, 2024 03:27
@EricLBuehler
Owner Author

Sounds good @Dan-wanna-M! Just merged it.

EricLBuehler added a commit that referenced this pull request Aug 24, 2024
* Add custom logits processor api

* Typos

* Nicer interface and update example

* Fix doctest

* Update docs
EricLBuehler added a commit that referenced this pull request Aug 27, 2024
* Implement dry penalty

* Add dry sampling params to requests

* Handle it

* Clippy

* Review: "Implement DRY penalty" (#645)

* Silence bogus Clippy warning

Clippy's suggestion cannot be implemented because of borrowing issues

* Get rid of unnecessary type annotations

Interesting that Clippy doesn't catch this

* Store default sequence breakers in a slice

It's nicer when the length is not hardcoded

* Make default sequence breakers private

No need to leak this as it's not used elsewhere

* Limit match length

Avoids quadratic runtime and potential DoS with adversarial inputs

Ref oobabooga/text-generation-webui#6047

* "Fix" sequence breaker tokenization

Most tokenizers encode punctuation tokens differently depending on where they occur in the input, and which tokens surround them. With the default sequence breakers, the appropriate encoding usually corresponds to the encoding produced when the token occurs after a word, rather than by itself. To emulate this, prefix the token with "a" before encoding, and extract the final token of the result.

See LostRuins/koboldcpp#982 for a correct solution to this problem.

* Nicer

* Even better

* Complete merge

* Fix saturating sub

* Handle when no context

* Make context the entire sequence and refactor

* Remove slicing for all

* Fix the bug with penalty

Credit to @p-e-w for finding this!

Co-authored-by: Philipp Emanuel Weidmann <pew@worldwidemann.com>

* Add custom logits processor API (#702)

* Add custom logits processor api

* Typos

* Nicer interface and update example

* Fix doctest

* Update docs

* Update exports

* Add Gemma 2 PagedAttention support (#704)

* Add gemma2 paged attn support

* Non cuda support?

* Remove error

* It works

* Faster RmsNorm in gemma/gemma2 (#703)

* Fix bug in metal isq (#706)

* Support GGUF BF16 tensors (#691)

* Support GGUF bf16 tensors

* Fix loading of bf16 ggml tensor

* Fix dequant of bf16

* Use merged rev

* Softcapping, real batching + sliding window support for Flash Attention (#707)

* Flash attention varlen kind of works

* Seems to work

* Now it's nice

* Sliding window support and clippy

* Remove warning

* Support smollm

* Update rev to match merged

* Remove some usages of 'pub' in models (#708)

* Support the Phi 3.5 V model (#710)

* Update image_seq_len

* Update the examples

* Format

* Implement the Phi 3.5 MoE model (#709)

* Copy the model

* Add most of it

* Add the blocksparse moe parts

* Clippy

* Fix mscales

* A batch of fixes

* Correctly cast it

* Handle isq on gate

* Even more progress

* Runs now

* Clippy

* Fix to use layernorm

* Remove unused

* Add docs

* Add more docs

* Apply review comments

* Update readme

---------

Co-authored-by: Philipp Emanuel Weidmann <pew@worldwidemann.com>
