fix: logits processor state at each step #544
Conversation
Signed-off-by: Wallas Santos <[email protected]>
Signed-off-by: Wallas Santos <[email protected]>
👋 Hi! Thank you for contributing to vLLM support on Spyre. Now you are good to go 🚀
tests/e2e/test_sampling_params.py
Outdated
# after min tokens reached the logits processor is properly
# cleared.
assert len(output1.outputs[0].token_ids) < 20
assert len(output2.outputs[0].token_ids) < 10
If we increase the eos_id logit bias to force it to be generated, then we can assert on the exact output length, right?
assert len(output1.outputs[0].token_ids) == 11
assert len(output2.outputs[0].token_ids) == 1
(the values for those asserts may be off-by-one depending on how EOS is tracked in the outputs 😅)
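A hedged sketch of what the tightened test could look like, assuming `SamplingParams` accepts `min_tokens` and a `logit_bias` mapping, and that `llm` is an already constructed `LLM` instance (the prompts, bias value, and tokenizer lookup are illustrative, not taken from the actual test):

```python
from vllm import SamplingParams

# Assumed way to obtain the EOS token id for the model under test.
eos_id = llm.get_tokenizer().eos_token_id

# Strong positive bias on EOS: it should be sampled as soon as it is allowed.
params_min10 = SamplingParams(max_tokens=20, min_tokens=10,
                              logit_bias={eos_id: 100.0})
params_min0 = SamplingParams(max_tokens=10,
                             logit_bias={eos_id: 100.0})

output1, output2 = llm.generate(["prompt one", "prompt two"],
                                [params_min10, params_min0])

# Exact lengths become assertable (possibly off by one depending on whether
# EOS is counted in the returned token_ids).
assert len(output1.outputs[0].token_ids) == 11  # 10 forced tokens + EOS
assert len(output2.outputs[0].token_ids) == 1   # EOS alone
```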
NP, PTAL
Signed-off-by: Wallas Santos <[email protected]>
Signed-off-by: Wallas Santos <[email protected]>
tjohnson31415 left a comment
LGTM! Thanks!
bot:test
…545)

# Description

The MinTokensLogitsProcessor needs to get a `batch_update` at each step to detect when enough tokens have been generated. The `LogitProcessorWrapper` copied the typical logic of skipping updates when `batch_update` is None, but this meant that min tokens would not get the needed call to `update_state`. The fix here is to always call `update_state` on each of the wrapped logitsprocs in the batch, with some extra code to not call `update_state` for a particular index more than once.

## Related Issues

Follow-up to #544, which fixed the behavior for static batching.

Cherry-picked improvement to test_sampling_params.py from #536

---------

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: Wallas Santos <[email protected]>
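A minimal sketch of the pattern described above, under assumed names (`logitsprocs_by_index`, `batch_update.indices`); the real vllm-spyre `LogitProcessorWrapper` differs in detail, but the point is that `update_state` is forwarded on every step, with a guard against updating the same index twice:

```python
from typing import Iterable, Optional


class LogitProcessorWrapperSketch:
    """Illustrative only; not the actual vllm-spyre class."""

    def __init__(self, logitsprocs_by_index: dict[int, list]):
        # batch index -> logits processors wrapped for that request
        self.logitsprocs_by_index = logitsprocs_by_index

    def update_state(self, batch_update: Optional[object]) -> None:
        # The old logic returned early when batch_update was None, so
        # stateful processors (e.g. min-tokens) never saw later steps.
        updated: set[int] = set()

        # Indices explicitly touched by this batch update, if any
        # (the attribute name is an assumption for this sketch).
        touched: Iterable[int] = (
            batch_update.indices if batch_update is not None else ())

        for index in touched:
            for proc in self.logitsprocs_by_index.get(index, []):
                proc.update_state(batch_update)
            updated.add(index)

        # Every remaining request still gets exactly one update per step,
        # even when there is no batch update to forward.
        for index, procs in self.logitsprocs_by_index.items():
            if index in updated:
                continue
            for proc in procs:
                proc.update_state(None)
```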
Description
This PR fixes the update of logits processors that need to be updated at each engine step. To validate the change, I updated the existing test for min tokens so that it exposes the wrong behaviour. Note: the bug is reproducible with both continuous batching (CB) and static batching (SB).
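For context, a toy illustration (not the actual vLLM MinTokensLogitsProcessor) of why a min-tokens processor depends on being updated every step: it counts generated tokens and keeps EOS masked until the minimum is reached, so a skipped `update_state` call leaves EOS masked and the output runs on to `max_tokens`.

```python
class MinTokensSketch:
    """Toy stand-in for a min-tokens logits processor."""

    def __init__(self, min_tokens: int, eos_token_id: int):
        self.min_tokens = min_tokens
        self.eos_token_id = eos_token_id
        self.num_generated = 0

    def update_state(self, batch_update) -> None:
        # Must be called once per engine step; if the wrapper skips this
        # call, the counter never advances and EOS stays masked forever.
        self.num_generated += 1

    def apply(self, logits: list[float]) -> list[float]:
        if self.num_generated < self.min_tokens:
            logits[self.eos_token_id] = float("-inf")
        return logits
```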