Commit eae482f

afeldman-nm authored and richardsliu committed
[V1] Detokenizer: Respect Stop Tokens + not include_stop_str_in_output (vllm-project#14624)
Signed-off-by: Andrew Feldman <[email protected]>
Signed-off-by: Richard Liu <[email protected]>
1 parent bc0613b commit eae482f

File tree (4 files changed, +215 -18 lines changed)

  tests/v1/engine/test_output_processor.py
  tests/v1/engine/utils.py
  vllm/v1/engine/detokenizer.py
  vllm/v1/engine/output_processor.py

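Taken together, the four diffs below implement and test one rule: when the engine core terminates a request on a stop token (the model's EOS or an id in `stop_token_ids`), the finish reason is "stop" and the stop token's text reaches the output only if `include_stop_str_in_output` is set; with `ignore_eos=True` the EOS case is disabled and generation runs to its length limit. A minimal standalone sketch of that rule follows; the helper names are illustrative and are not part of the vLLM API.

from typing import Optional, Sequence


def is_stop_hit(last_token_id: int, eos_token_id: Optional[int],
                stop_token_ids: Sequence[int], ignore_eos: bool) -> bool:
    """Illustration only: should the newest token finish the request
    with reason "stop"?"""
    if not ignore_eos and last_token_id == eos_token_id:
        return True
    return last_token_id in stop_token_ids


def visible_text(pieces: Sequence[str], stopped: bool,
                 include_stop_str_in_output: bool) -> str:
    """Drop the terminating stop token's text unless the caller asked
    to keep it (mirrors include_stop_str_in_output)."""
    if stopped and not include_stop_str_in_output:
        pieces = pieces[:-1]
    return "".join(pieces)


# Expected results from the new test, with <control> = EOS or a stop id:
#   EOS + ignore_eos            -> "<tok>...<tok><control><control>", "length"
#   include_stop_str_in_output  -> "<tok>...<tok><control>",          "stop"
#   otherwise                   -> "<tok>...<tok>",                   "stop"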

tests/v1/engine/test_output_processor.py

Lines changed: 170 additions & 9 deletions
@@ -470,22 +470,184 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
     assert not output_processor.has_unfinished_requests()
 
 
+@pytest.mark.parametrize(
+    "include_stop_str_in_output,stop_token_type,ignore_eos,num_sample_logprobs",
+    [(False, "stop_token_ids", False, None),
+     (True, "stop_token_ids", False, None),
+     (False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
+     (True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
+     (False, "eos_token_id", False, None), (True, "eos_token_id", False, None),
+     (False, "eos_token_id", True, None)])
+def test_stop_token(include_stop_str_in_output: bool,
+                    num_sample_logprobs: Optional[int], stop_token_type: str,
+                    ignore_eos: bool, dummy_test_vectors):
+    """Test output processor EOS/stop token handling.
+
+    Send mock engine core request to mock engine core and pass core outputs
+    to output processor. Validate output processor tokens, text and
+    (if enabled) sample logprobs. Batch-size one.
+
+    The test emulates a scenario where a model outputs text tokens followed
+    by two identical control tokens:
+    <token><token>...<token><control><control>
+
+    If EOS is under test, the control tokens are EOS; otherwise, they are
+    some other token id.
+
+    Test behavior:
+
+    * If EOS is under test and `ignore_eos=True`, the detokenized string
+      should be <token><token>...<token><control><control> and the finish
+      reason should be "length" (i.e. no stop occurs)
+
+    * else, if `include_stop_str_in_output==True`, the detokenized
+      string should be <token><token>...<token><control> and the finish
+      reason should be "stop" (i.e. first control token causes stop
+      and is represented in output text)
+
+    * else, the detokenized string should be
+      <token><token>...<token> and the finish reason should be "stop"
+      (i.e. first control token causes stop but is not represented
+      in output text.)
+
+    Note: some test details are tuned for meta-llama/Llama-3.2-1B,
+    another model should work only if the test is modified.
+
+    Args:
+      include_stop_str_in_output: stop token str appears in output text
+      num_sample_logprobs: number of sample logprobs (`None` for no logprobs)
+      stop_token_type: "eos_token_id" for EOS, "stop_token_ids" for stop token
+      ignore_eos: if True, EOS stops are disabled
+      dummy_test_vectors: dummy engine core outputs and other data structures
+    """
+    model_id = dummy_test_vectors.tokenizer.name_or_path
+    if model_id != 'meta-llama/Llama-3.2-1B':
+        raise AssertionError("Test requires meta-llama/Llama-3.2-1B but "
+                             f"{model_id} is in use.")
+    do_logprobs = num_sample_logprobs is not None
+    # EOS under test; if False, stop_token_ids under test
+    is_eos_test = stop_token_type == "eos_token_id"
+    # EOS under test but ignore_eos enabled
+    is_eos_ignore_test = is_eos_test and ignore_eos
+    eos_token_id = (
+        dummy_test_vectors.tokenizer.eos_token_id if is_eos_test else None
+    )  # '<|end_of_text|>'
+    stop_token_ids = [128009] if not is_eos_test else None  # '<|eot_id|>'
+
+    output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
+                                       log_stats=False)
+    # Dummy engine core outputs, with control tokens suffixed to test stops
+    suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
+    assert suffix_token is not None and isinstance(suffix_token[0], int)
+    generation_string = dummy_test_vectors.generation_strings[0]
+    generation_tokens = (dummy_test_vectors.generation_tokens[0] +
+                         2 * suffix_token)
+    if do_logprobs:
+        generation_logprobs = (
+            dummy_test_vectors.generation_logprobs[0] +
+            2 * [dummy_test_vectors.generation_logprobs[0][-1]])
+    prompt_string = dummy_test_vectors.prompt_strings[0]
+    prompt_tokens = dummy_test_vectors.prompt_tokens[0]
+    engine_core = MockEngineCore(
+        tokens_list=[generation_tokens],
+        generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
+        prompt_logprobs_raw=None,
+        eos_token_id=eos_token_id,
+        stop_token_ids=stop_token_ids,
+        ignore_eos=ignore_eos)
+
+    # Make request.
+    request_id = "request-0"
+    request = EngineCoreRequest(
+        request_id=request_id,
+        prompt=prompt_string,
+        prompt_token_ids=prompt_tokens,
+        arrival_time=0,
+        mm_inputs=None,
+        mm_hashes=None,
+        mm_placeholders=None,
+        eos_token_id=eos_token_id,
+        lora_request=None,
+        sampling_params=SamplingParams(
+            skip_special_tokens=False,
+            spaces_between_special_tokens=False,
+            output_kind=RequestOutputKind.DELTA,
+            stop=[],
+            stop_token_ids=stop_token_ids,
+            include_stop_str_in_output=include_stop_str_in_output,
+            logprobs=num_sample_logprobs,
+            prompt_logprobs=None,
+            ignore_eos=ignore_eos,
+        ))
+
+    # Add request to the detokenizer.
+    output_processor.add_request(request)
+
+    # Loop over engine core steps; run output processor
+    gen_string = ""
+    gen_tokens = []
+    gen_logprobs = []
+    while True:
+        # Mock output from the EngineCore.
+        outputs = engine_core.get_outputs()
+        if len(outputs) == 0:
+            break
+
+        # Step the Detokenizer.
+        processed_outputs = output_processor.process_outputs(outputs)
+        request_outputs = processed_outputs.request_outputs
+        assert len(request_outputs) == 1
+        # Stop token does not rely on abort
+        assert not processed_outputs.reqs_to_abort
+
+        # Update tracking.
+        request_output = request_outputs[0]
+        if request_output.finished:
+            finish_reason = ("length" if is_eos_ignore_test else "stop")
+            assert request_output.outputs[0].finish_reason == finish_reason
+
+        gen_string += request_output.outputs[0].text
+        gen_tokens.extend(request_output.outputs[0].token_ids)
+        if do_logprobs:
+            gen_logprobs.extend(request_output.outputs[0].logprobs)
+
+    # Validate generated text
+    control_token = '<|end_of_text|>' if is_eos_test else '<|eot_id|>'
+    if is_eos_ignore_test:
+        # Length-based stop; expect full string
+        ref_str = generation_string + 2 * control_token
+    elif include_stop_str_in_output:
+        # Stop token triggered; include in output
+        ref_str = generation_string + control_token
+    else:
+        # Stop token triggered but not in output
+        ref_str = generation_string
+    assert gen_string == ref_str, (f"{gen_string=}, {ref_str=}")
+
+    if do_logprobs:
+        # Validate number of sample logprobs
+        num_tokens = len(gen_tokens)
+        num_logprobs = len(gen_logprobs)
+        assert num_tokens == num_logprobs, (
+            f"Token count ({num_tokens}) != logprobs count ({num_logprobs})")
+
+    # Check requests are finished
+    assert output_processor.get_num_unfinished_requests() == 0
+    assert not output_processor.has_unfinished_requests()
+
+
 @pytest.mark.parametrize("include_stop_str_in_output", [True, False])
 @pytest.mark.parametrize("num_sample_logprobs",
                          [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
-@pytest.mark.parametrize("num_prompt_logprobs",
-                         [None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
 def test_stop_string(include_stop_str_in_output: bool,
-                     num_sample_logprobs: Optional[int],
-                     num_prompt_logprobs: Optional[int], dummy_test_vectors):
+                     num_sample_logprobs: Optional[int], dummy_test_vectors):
     output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
                                        log_stats=False)
     engine_core = MockEngineCore(
         tokens_list=dummy_test_vectors.generation_tokens,
         generated_logprobs_raw=dummy_test_vectors.generation_logprobs
         if num_sample_logprobs else None,
-        prompt_logprobs_raw=dummy_test_vectors.prompt_logprobs
-        if num_prompt_logprobs else None)
+        prompt_logprobs_raw=None)
 
     # Make N requests.
     request_id_list = [

@@ -510,7 +672,7 @@ def test_stop_string(include_stop_str_in_output: bool,
             stop=STOP_STRINGS,
             include_stop_str_in_output=include_stop_str_in_output,
             logprobs=num_sample_logprobs,
-            prompt_logprobs=num_prompt_logprobs,
+            prompt_logprobs=None,
         )) for idx, (prompt, prompt_tokens) in enumerate(
             zip(dummy_test_vectors.prompt_strings,
                 dummy_test_vectors.prompt_tokens))

@@ -594,8 +756,7 @@ def test_stop_string(include_stop_str_in_output: bool,
     # Confirmed tracked logprobs match what we expect
     _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
                        gen_cumulative_logprobs, dummy_test_vectors,
-                       request_id_list, num_sample_logprobs,
-                       num_prompt_logprobs)
+                       request_id_list, num_sample_logprobs, None)
 
     assert output_processor.get_num_unfinished_requests() == 0
     assert not output_processor.has_unfinished_requests()
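The new `test_stop_token` case asserts on `tokenizer.name_or_path`, so running it locally presumably requires access to meta-llama/Llama-3.2-1B on the Hugging Face Hub. A typical invocation, using standard pytest selection (nothing project-specific assumed):

    pytest tests/v1/engine/test_output_processor.py -k test_stop_token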

tests/v1/engine/utils.py

Lines changed: 20 additions & 3 deletions
@@ -20,7 +20,7 @@
 # Number of prompt logprobs to request when testing prompt logprobs
 NUM_PROMPT_LOGPROBS_UNDER_TEST = 7
 
-TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
+TOKENIZER_NAME = "meta-llama/Llama-3.2-1B"
 
 FULL_STRINGS = [
     "My name is Robert from Neural Magic and I love working on vLLM so much!",

@@ -330,13 +330,21 @@ def __init__(
         # each matrix has dimensions
         # (num prompt toks) x (num prompt logprobs+1)
         prompt_logprobs_raw: Optional[list[LogprobsTensors]] = None,
+        eos_token_id: Optional[int] = None,
+        stop_token_ids: Optional[list[int]] = None,
+        ignore_eos: bool = False,
     ) -> None:
+        self.num_requests = len(tokens_list)
         self.tokens_list = tokens_list
         self.current_idx = 0
         self.generated_logprobs_raw = generated_logprobs_raw
         self.do_logprobs = generated_logprobs_raw is not None
         self.prompt_logprobs_raw = prompt_logprobs_raw
         self.do_prompt_logprobs = prompt_logprobs_raw is not None
+        self.request_finished = [False for _ in range(self.num_requests)]
+        self.eos_token_id = eos_token_id
+        self.stop_token_ids = stop_token_ids
+        self.ignore_eos = ignore_eos
 
     def get_outputs(self) -> list[EngineCoreOutput]:
         do_logprobs = self.do_logprobs

@@ -345,7 +353,7 @@ def get_outputs(self) -> list[EngineCoreOutput]:
 
         outputs = []
         for req_idx, token_ids in enumerate(self.tokens_list):
-            if len(token_ids) > token_idx:
+            if not self.request_finished[req_idx]:
                 if do_logprobs:
                     assert self.generated_logprobs_raw is not None
                     (logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (

@@ -365,14 +373,23 @@ def get_outputs(self) -> list[EngineCoreOutput]:
                         prompt_logprobs = None
                 else:
                     prompt_logprobs = None
+                new_token_id = token_ids[token_idx]
                 output = EngineCoreOutput(
                     request_id=f"request-{req_idx}",
-                    new_token_ids=[token_ids[token_idx]],
+                    new_token_ids=[new_token_id],
                     new_logprobs=logprobs,
                     new_prompt_logprobs_tensors=prompt_logprobs,
                 )
                 if token_idx == len(token_ids) - 1:
+                    output.finish_reason = FinishReason.LENGTH
+                    self.request_finished[req_idx] = True
+                if not self.ignore_eos and new_token_id == self.eos_token_id:
                     output.finish_reason = FinishReason.STOP
+                    self.request_finished[req_idx] = True
+                if new_token_id in (self.stop_token_ids or ()):
+                    output.finish_reason = FinishReason.STOP
+                    output.stop_reason = new_token_id
+                    self.request_finished[req_idx] = True
                 outputs.append(output)
 
         self.current_idx += 1
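A hypothetical driver for the extended mock, assuming the names already imported in this file (MockEngineCore, FinishReason) and using 128009 as the stand-in stop id like the new test. It only sketches how the new eos_token_id / stop_token_ids / ignore_eos arguments and the request_finished flags behave; it is not part of the diff.

# Sketch only: exercise the mock's new stop handling.
mock = MockEngineCore(
    tokens_list=[[11, 22, 33, 128009, 128009]],  # two trailing stop tokens
    generated_logprobs_raw=None,
    prompt_logprobs_raw=None,
    eos_token_id=None,
    stop_token_ids=[128009],
    ignore_eos=False,
)

steps = []
while (outputs := mock.get_outputs()):
    steps.append(outputs[0])

# Four steps come back: the fourth carries finish_reason == FinishReason.STOP
# and stop_reason == 128009, and request_finished suppresses the fifth token.
assert len(steps) == 4
assert steps[-1].finish_reason == FinishReason.STOP
assert steps[-1].stop_reason == 128009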

vllm/v1/engine/detokenizer.py

Lines changed: 22 additions & 3 deletions
@@ -88,19 +88,31 @@ def from_new_request(
             stop_buffer_length=stop_buffer_length,
         )
 
-    def update(self, new_token_ids: list[int]) -> Optional[str]:
+    def update(self, new_token_ids: list[int],
+               stop_terminated: bool) -> Optional[str]:
         """
         Update RequestState for the request_id by:
             1) Detokenize the new token ids incrementally.
             2) Evaluate stop criteria.
 
         Return matched stop string or None.
         """
-
+        if not new_token_ids:
+            # Skip detokenization if no new token ids
+            return None
         if self.tokenizer is None:
+            # Skip detokenization if no tokenizer
             self.token_ids.extend(new_token_ids)
             return None
 
+        if stop_terminated and not self.include_stop_str_in_output:
+            # If stop-terminated, exclude last token from detokenization
+            # based on include_stop_str_in_output parameter.
+            skipped_stop_token_id = new_token_ids[-1]
+            new_token_ids = new_token_ids[:-1]
+        else:
+            skipped_stop_token_id = None
+
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.

@@ -127,7 +139,14 @@ def update(self, new_token_ids: list[int]) -> Optional[str]:
 
         self.output_text += decoded_text
 
-        # 2) Evaluate stop criteria.
+        if stop_terminated:
+            if skipped_stop_token_id is not None:
+                # Cleanup after skipping detokenization
+                self.token_ids.append(skipped_stop_token_id)
+            # Stop token triggered; skip stop string check
+            return None
+
+        # 2) Evaluate stop strings.
         stop_string = None
         if self.stop:
             stop = StopChecker.check_stop_strings(
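The heart of the change is the branch added to update(): when the engine core already stop-terminated the request, the final token is withheld from detokenization unless include_stop_str_in_output is set, its raw id is still appended to token_ids, and the stop-string scan is skipped. A minimal standalone sketch of just that trimming step (a plain function, not the vLLM class):

from typing import Optional


def trim_stop_token(new_token_ids: list[int], stop_terminated: bool,
                    include_stop_str_in_output: bool
                    ) -> tuple[list[int], Optional[int]]:
    """Return (ids to detokenize, stop token id withheld from the text).

    The withheld id is still recorded in the request's token_ids by the
    caller; it just never reaches the output text.
    """
    if stop_terminated and not include_stop_str_in_output:
        return new_token_ids[:-1], new_token_ids[-1]
    return new_token_ids, None


# e.g. trim_stop_token([11, 22, 128009], True, False) -> ([11, 22], 128009)
#      trim_stop_token([11, 22, 128009], True, True)  -> ([11, 22, 128009], None)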

vllm/v1/engine/output_processor.py

Lines changed: 3 additions & 3 deletions
@@ -299,9 +299,9 @@ def process_outputs(
             # in the EngineCore.
             req_state.is_prefilling = not new_token_ids
 
-            # 2) Detokenize the token ids into text and check for stop
-            # strings.
-            stop_string = req_state.detokenizer.update(new_token_ids)
+            # 2) Detokenize the token ids into text and perform stop checks.
+            stop_string = req_state.detokenizer.update(
+                new_token_ids, finish_reason == FinishReason.STOP)
             if stop_string and finish_reason != FinishReason.STOP:
                 finish_reason = FinishReason.STOP
                 stop_reason = stop_string
