Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions tests/e2e/test_sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch,
tokenizer = spyre_model.get_tokenizer()
eos_id = tokenizer.eos_token_id

params1 = SamplingParams(min_tokens=19,
params1 = SamplingParams(min_tokens=10,
logit_bias={eos_id: 50},
seed=8780,
max_tokens=20)
Expand All @@ -271,8 +271,12 @@ def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch,
output1 = spyre_model.generate(prompt, params1)[0]
output2 = spyre_model.generate(prompt, params2)[0]

assert len(output1.outputs[0].token_ids) >= 19
assert len(output2.outputs[0].token_ids) < 19
assert len(output1.outputs[0].token_ids) >= 10
# The logit bias should force the EOS token to appear, so we check
# that the logits processor is properly cleared once min_tokens
# has been reached.
assert len(output1.outputs[0].token_ids) < 20
assert len(output2.outputs[0].token_ids) < 10
Copy link
Collaborator

@tjohnson31415 tjohnson31415 Oct 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we increase the eos_id logit bias to force it to be generated, then we can assert on the exact output length, right?

    assert len(output1.outputs[0].token_ids) == 11
    assert len(output2.outputs[0].token_ids) == 1

(the values for those asserts may be off-by-one depending on how EOS is tracked in the outputs 😅)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NP, PTAL



def test_spyre_batch1_ignore_eos(model: ModelInfo, backend, monkeypatch,
Expand Down
3 changes: 3 additions & 0 deletions vllm_spyre/v1/worker/spyre_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,9 @@ def update_states(self, scheduler_output: SchedulerOutput):
# of logitprocs. Refactor so that we can batch removals to the
# `input_batch`
self.input_batch.refresh_metadata()
else:
# The logits processor requires a metadata refresh at every step
self.input_batch.refresh_metadata()

def _get_prompt_logprobs_dict(
self,
Expand Down