179 changes: 170 additions & 9 deletions tests/v1/engine/test_output_processor.py
@@ -470,22 +470,184 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
assert not output_processor.has_unfinished_requests()


@pytest.mark.parametrize(
"include_stop_str_in_output,stop_token_type,ignore_eos,num_sample_logprobs",
[(False, "stop_token_ids", False, None),
(True, "stop_token_ids", False, None),
(False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
(True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
(False, "eos_token_id", False, None), (True, "eos_token_id", False, None),
(False, "eos_token_id", True, None)])
def test_stop_token(include_stop_str_in_output: bool,
num_sample_logprobs: Optional[int], stop_token_type: str,
ignore_eos: bool, dummy_test_vectors):
"""Test output processor EOS/stop token handling.

Send a mock request to the mock engine core and pass the core outputs
to the output processor. Validate the output processor's tokens, text,
and (if enabled) sample logprobs. Batch size is one.

The test emulates a scenario where a model outputs text tokens followed
by two identical control tokens:
<token><token>...<token><control><control>

If EOS is under test, the control tokens are EOS; otherwise, they are
some other token id.

Test behavior:

* If EOS is under test and `ignore_eos=True`, the detokenized string
should be <token><token>...<token><control><control> and the finish
reason should be "length" (i.e. no stop occurs)

* else, if `include_stop_str_in_output==True`, the detokenized
string should be <token><token>...<token><control> and the finish
reason should be "stop" (i.e. first control token causes stop
and is represented in output text)

* else, the detokenized string should be
<token><token>...<token> and the finish reason should be "stop"
(i.e. first control token causes stop but is not represented
in output text.)

Note: some test details are tuned for meta-llama/Llama-3.2-1B;
another model will work only if the test is modified.

Args:
include_stop_str_in_output: if True, the stop token string appears in the output text
num_sample_logprobs: number of sample logprobs (`None` for no logprobs)
stop_token_type: "eos_token_id" for EOS, "stop_token_ids" for stop token
ignore_eos: if True, EOS stops are disabled
dummy_test_vectors: dummy engine core outputs and other data structures
"""
model_id = dummy_test_vectors.tokenizer.name_or_path
if model_id != 'meta-llama/Llama-3.2-1B':
raise AssertionError("Test requires meta-llama/Llama-3.2-1B but "
f"{model_id} is in use.")
do_logprobs = num_sample_logprobs is not None
# EOS under test; if False, stop_token_ids under test
is_eos_test = stop_token_type == "eos_token_id"
# EOS under test but ignore_eos enabled
is_eos_ignore_test = is_eos_test and ignore_eos
eos_token_id = (
dummy_test_vectors.tokenizer.eos_token_id if is_eos_test else None
) # '<|end_of_text|>'
stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>'

output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
log_stats=False)
# Dummy engine core outputs, with control tokens suffixed to test stops
suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
assert suffix_token is not None and isinstance(suffix_token[0], int)
generation_string = dummy_test_vectors.generation_strings[0]
generation_tokens = (dummy_test_vectors.generation_tokens[0] +
2 * suffix_token)
if do_logprobs:
generation_logprobs = (
dummy_test_vectors.generation_logprobs[0] +
2 * [dummy_test_vectors.generation_logprobs[0][-1]])
prompt_string = dummy_test_vectors.prompt_strings[0]
prompt_tokens = dummy_test_vectors.prompt_tokens[0]
engine_core = MockEngineCore(
tokens_list=[generation_tokens],
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
prompt_logprobs_raw=None,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids,
ignore_eos=ignore_eos)

# Make request.
request_id = "request-0"
request = EngineCoreRequest(
request_id=request_id,
prompt=prompt_string,
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=eos_token_id,
lora_request=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=RequestOutputKind.DELTA,
stop=[],
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
ignore_eos=ignore_eos,
))

# Add request to the detokenizer.
output_processor.add_request(request)

# Loop over engine core steps; run output processor
gen_string = ""
gen_tokens = []
gen_logprobs = []
while True:
# Mock output from the EngineCore.
outputs = engine_core.get_outputs()
if len(outputs) == 0:
break

# Step the Detokenizer.
processed_outputs = output_processor.process_outputs(outputs)
request_outputs = processed_outputs.request_outputs
assert len(request_outputs) == 1
# Stop token does not rely on abort
assert not processed_outputs.reqs_to_abort

# Update tracking.
request_output = request_outputs[0]
if request_output.finished:
finish_reason = ("length" if is_eos_ignore_test else "stop")
assert request_output.outputs[0].finish_reason == finish_reason

gen_string += request_output.outputs[0].text
gen_tokens.extend(request_output.outputs[0].token_ids)
if do_logprobs:
gen_logprobs.extend(request_output.outputs[0].logprobs)

# Validate generated text
control_token = '<|end_of_text|>' if is_eos_test else '<|eot_id|>'
if is_eos_ignore_test:
# Length-based stop; expect full string
ref_str = generation_string + 2 * control_token
elif include_stop_str_in_output:
# Stop token triggered; include in output
ref_str = generation_string + control_token
else:
# Stop token triggered but not in output
ref_str = generation_string
assert gen_string == ref_str, (f"{gen_string=}, {ref_str=}")

if do_logprobs:
# Validate number of sample logprobs
num_tokens = len(gen_tokens)
num_logprobs = len(gen_logprobs)
assert num_tokens == num_logprobs, (
f"Token count ({num_tokens}) != logprobs count ({num_logprobs})")

# Check requests are finished
assert output_processor.get_num_unfinished_requests() == 0
assert not output_processor.has_unfinished_requests()


@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
@pytest.mark.parametrize("num_sample_logprobs",
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
@pytest.mark.parametrize("num_prompt_logprobs",
[None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
 def test_stop_string(include_stop_str_in_output: bool,
-                     num_sample_logprobs: Optional[int],
-                     num_prompt_logprobs: Optional[int], dummy_test_vectors):
+                     num_sample_logprobs: Optional[int], dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
if num_sample_logprobs else None,
-        prompt_logprobs_raw=dummy_test_vectors.prompt_logprobs
-        if num_prompt_logprobs else None)
+        prompt_logprobs_raw=None)

# Make N requests.
request_id_list = [
Expand All @@ -510,7 +672,7 @@ def test_stop_string(include_stop_str_in_output: bool,
stop=STOP_STRINGS,
include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs,
-            prompt_logprobs=num_prompt_logprobs,
+            prompt_logprobs=None,
)) for idx, (prompt, prompt_tokens) in enumerate(
zip(dummy_test_vectors.prompt_strings,
dummy_test_vectors.prompt_tokens))
Expand Down Expand Up @@ -594,8 +756,7 @@ def test_stop_string(include_stop_str_in_output: bool,
# Confirmed tracked logprobs match what we expect
_validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
gen_cumulative_logprobs, dummy_test_vectors,
-                       request_id_list, num_sample_logprobs,
-                       num_prompt_logprobs)
+                       request_id_list, num_sample_logprobs, None)

assert output_processor.get_num_unfinished_requests() == 0
assert not output_processor.has_unfinished_requests()
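As a reading aid for the cases enumerated in the docstring above, a minimal sketch of the reference-string logic the new test asserts (the helper name expected_text is hypothetical and not part of this diff):

def expected_text(generation: str, control: str, is_eos_ignore_test: bool,
                  include_stop_str_in_output: bool) -> str:
    # Mirrors the three docstring cases for test_stop_token above.
    if is_eos_ignore_test:
        # EOS ignored: the request runs to its length limit, so both
        # trailing control tokens appear in the detokenized text.
        return generation + 2 * control
    if include_stop_str_in_output:
        # The first control token stops the request and its text is kept.
        return generation + control
    # The first control token stops the request but its text is stripped.
    return generation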
23 changes: 20 additions & 3 deletions tests/v1/engine/utils.py
@@ -20,7 +20,7 @@
# Number of prompt logprobs to request when testing prompt logprobs
NUM_PROMPT_LOGPROBS_UNDER_TEST = 7

-TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
+TOKENIZER_NAME = "meta-llama/Llama-3.2-1B"

FULL_STRINGS = [
"My name is Robert from Neural Magic and I love working on vLLM so much!",
@@ -330,13 +330,21 @@ def __init__(
# each matrix has dimensions
# (num prompt toks) x (num prompt logprobs+1)
prompt_logprobs_raw: Optional[list[LogprobsTensors]] = None,
eos_token_id: Optional[int] = None,
stop_token_ids: Optional[list[int]] = None,
ignore_eos: bool = False,
) -> None:
self.num_requests = len(tokens_list)
self.tokens_list = tokens_list
self.current_idx = 0
self.generated_logprobs_raw = generated_logprobs_raw
self.do_logprobs = generated_logprobs_raw is not None
self.prompt_logprobs_raw = prompt_logprobs_raw
self.do_prompt_logprobs = prompt_logprobs_raw is not None
self.request_finished = [False for _ in range(self.num_requests)]
self.eos_token_id = eos_token_id
self.stop_token_ids = stop_token_ids
self.ignore_eos = ignore_eos

def get_outputs(self) -> list[EngineCoreOutput]:
do_logprobs = self.do_logprobs
@@ -345,7 +353,7 @@ def get_outputs(self) -> list[EngineCoreOutput]:

outputs = []
for req_idx, token_ids in enumerate(self.tokens_list):
-            if len(token_ids) > token_idx:
+            if not self.request_finished[req_idx]:
if do_logprobs:
assert self.generated_logprobs_raw is not None
(logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (
Expand All @@ -365,14 +373,23 @@ def get_outputs(self) -> list[EngineCoreOutput]:
prompt_logprobs = None
else:
prompt_logprobs = None
new_token_id = token_ids[token_idx]
output = EngineCoreOutput(
request_id=f"request-{req_idx}",
-                    new_token_ids=[token_ids[token_idx]],
+                    new_token_ids=[new_token_id],
new_logprobs=logprobs,
new_prompt_logprobs_tensors=prompt_logprobs,
)
if token_idx == len(token_ids) - 1:
output.finish_reason = FinishReason.LENGTH
self.request_finished[req_idx] = True
if not self.ignore_eos and new_token_id == self.eos_token_id:
output.finish_reason = FinishReason.STOP
self.request_finished[req_idx] = True
if new_token_id in (self.stop_token_ids or ()):
output.finish_reason = FinishReason.STOP
output.stop_reason = new_token_id
self.request_finished[req_idx] = True
outputs.append(output)

self.current_idx += 1
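A rough usage sketch of the extended MockEngineCore, based on the constructor arguments shown in this diff (the token ids are illustrative; 128009 is Llama 3's '<|eot_id|>'):

engine_core = MockEngineCore(
    tokens_list=[[1, 2, 3, 128009]],
    generated_logprobs_raw=None,
    prompt_logprobs_raw=None,
    eos_token_id=None,
    stop_token_ids=[128009],
    ignore_eos=False,
)
# Each get_outputs() call emits one token per unfinished request. On the
# step that emits 128009, the output carries finish_reason ==
# FinishReason.STOP and stop_reason == 128009; request_finished then
# masks the request, so a later call returns an empty list and the
# consuming loop (as in the tests above) terminates.
while (outputs := engine_core.get_outputs()):
    ...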
32 changes: 29 additions & 3 deletions vllm/v1/engine/detokenizer.py
@@ -30,6 +30,9 @@ class IncrementalDetokenizer:
read_offset: int = 0

# Parameters for detokenization
eos_token_id: Optional[int] = None
stop_token_ids: Optional[list[int]] = None
ignore_eos: bool = False
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True

@@ -86,20 +89,35 @@ def from_new_request(
prompt_len=len(request.prompt_token_ids),
tokenizer=tokenizer,
stop_buffer_length=stop_buffer_length,
stop_token_ids=request.sampling_params.stop_token_ids,
ignore_eos=request.sampling_params.ignore_eos,
eos_token_id=request.eos_token_id,
)

-    def update(self, new_token_ids: list[int]) -> Optional[str]:
+    def update(self, new_token_ids: list[int],
+               stop_terminated: bool) -> Optional[str]:
"""
Update RequestState for the request_id by:
1) Detokenizing the new token ids incrementally.
2) Evaluating stop criteria.

Return matched stop string or None.
"""

if self.tokenizer is None:
# Skip detokenization if no tokenizer
self.token_ids.extend(new_token_ids)
return None
if not new_token_ids:
# Skip detokenization if no new token ids
return None

if stop_terminated and not self.include_stop_str_in_output:
# If stop-terminated, exclude last token from detokenization
# based on include_stop_str_in_output parameter.
stop_token_id = new_token_ids[-1]
new_token_ids = new_token_ids[:-1]
else:
stop_token_id = None

# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
@@ -127,7 +145,15 @@ def update(self, new_token_ids: list[int]) -> Optional[str]:

self.output_text += decoded_text

-        # 2) Evaluate stop criteria.
+        if stop_terminated:
+            if not self.include_stop_str_in_output:
+                # Cleanup after skipping detokenization
+                assert stop_token_id is not None
+                self.token_ids.append(stop_token_id)
+            # Stop token triggered; skip stop string check
+            return None
+
+        # 2) Evaluate stop strings.
stop_string = None
if self.stop:
stop = StopChecker.check_stop_strings(
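In short, a sketch of the new update() contract, assuming the caller passes stop_terminated=True exactly when the last token of new_token_ids triggered a token-based stop (which is what output_processor.py does below):

# Ordinary step: all new tokens are detokenized, then the accumulated
# text is checked against stop strings; a matched stop string is returned.
stop_string = detokenizer.update(new_token_ids, stop_terminated=False)

# Stop-terminated step: the final token is the EOS/stop token. With
# include_stop_str_in_output=False it is excluded from detokenization
# (its text never reaches output_text) but is still appended to
# token_ids for bookkeeping; the stop-string check is skipped and
# None is returned.
stop_string = detokenizer.update(new_token_ids, stop_terminated=True)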
6 changes: 3 additions & 3 deletions vllm/v1/engine/output_processor.py
@@ -299,9 +299,9 @@ def process_outputs(
# in the EngineCore.
req_state.is_prefilling = not new_token_ids

-            # 2) Detokenize the token ids into text and check for stop
-            # strings.
-            stop_string = req_state.detokenizer.update(new_token_ids)
+            # 2) Detokenize the token ids into text and perform stop checks.
+            stop_string = req_state.detokenizer.update(
+                new_token_ids, finish_reason == FinishReason.STOP)
if stop_string and finish_reason != FinishReason.STOP:
finish_reason = FinishReason.STOP
stop_reason = stop_string
1 change: 1 addition & 0 deletions vllm/version.py
@@ -28,4 +28,5 @@ def _prev_minor_version_was(version_str):
return True

# Note - this won't do the right thing when we release 1.0!
assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
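A worked example of the guarded arithmetic (the rationale is my reading; the diff itself does not state it):

# With __version_tuple__ == (0, 7, 3):
#   _prev_minor_version_was("0.6")  -> "0.6" == "0.6"  -> True
#   _prev_minor_version_was("0.5")  -> "0.5" == "0.6"  -> False
# __version_tuple__ components are typed int | str, so the assert
# narrows the minor component to int before computing
# f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}".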