Merged
Changes from 9 commits
tests/e2e/test_spyre_max_new_tokens.py (4 changes: 3 additions & 1 deletion)

@@ -38,7 +38,9 @@ def test_output(model: ModelInfo, stop_last: bool, max_model_len: int,
    logprobs=0,  # return logprobs of generated tokens only
    ignore_eos=False)

-   vllm_sampling_params = [vllm_sampling_params_normal] * 3
+   vllm_sampling_params = [
+       vllm_sampling_params_normal.clone() for _ in range(3)
+   ]
tjohnson31415 (Collaborator) commented on Oct 22, 2025:
This change is needed now that all tests use GTI (golden token injection) by default; before this PR, only tests using get_engine would use it. Repeating the same reference instead of calling .clone() meant that every sequence shared a single GTI config, even when the prompts differed.

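To make the aliasing pitfall concrete, here is a minimal sketch using a hypothetical stand-in class rather than vLLM's SamplingParams (which exposes the same clone() used in the diff above):

import copy

class Params:
    """Hypothetical stand-in for a mutable per-request params object."""

    def __init__(self):
        self.extra_args = None

    def clone(self):
        # Deep copy, so per-request state is fully independent.
        return copy.deepcopy(self)

# [obj] * 3 repeats the *reference*: one object behind three list slots.
shared = [Params()] * 3
shared[0].extra_args = {"label": "#0"}
assert shared[1].extra_args == {"label": "#0"}  # write is visible everywhere

# Cloning yields independent copies, so per-request config stays per-request.
base = Params()
cloned = [base.clone() for _ in range(3)]
cloned[0].extra_args = {"label": "#0"}
assert cloned[1].extra_args is None
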
hf_max_new_tokens = [max_new_tokens_long] * 3

# stop last or first sequence in batch early
tests/golden_token_injector.py (179 changes: 0 additions & 179 deletions)

This file was deleted.

tests/llm_cache.py (17 changes: 6 additions & 11 deletions)

@@ -1,18 +1,18 @@
"""Contains utilities for caching models (instantiated as vLLM endpoints)
across test cases, to speed up test runtime."""

import os
from typing import Callable, Generic, Optional, TypeVar

import pytest
from golden_token_injector import GoldenTokenInjector
from llm_cache_util import force_engine_shutdown
from spyre_util import (DecodeWarmupShapes, ModelInfo, RemoteOpenAIServer,
patch_environment)
from vllm import LLM, EngineArgs
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor

from vllm_spyre.v1.sample.golden_token_injector import GoldenTokenInjector

T = TypeVar("T")

## class definitions ##########################################
Expand Down Expand Up @@ -128,10 +128,11 @@ def get_cached_llm(
LLM(
    model=model_name,
    tokenizer=model_name,
+   revision=revision,
    max_model_len=max_model_len,
    max_num_seqs=max_num_seqs,
    tensor_parallel_size=tensor_parallel_size,
-   revision=revision,
+   logits_processors=[GoldenTokenInjector],
),
)

@@ -179,12 +180,6 @@ def get_engine(
revision = None
model_name = model

-   # Register golden token injector if not disabled
-   disable_golden_token = \
-       bool(int(os.getenv("VLLM_SPYRE_TEST_DISABLE_GOLDEN_TOKEN", "0")))
-   logits_processors = [] if disable_golden_token else \
-       [GoldenTokenInjector]
-
# 🌶️🌶️🌶️
# Messing with the blocks and context length by either:
# - setting context < 512 tokens
@@ -201,11 +196,11 @@
max_num_seqs_compiled = 1 << (max_num_seqs - 1).bit_length()
engine_args = EngineArgs(model=model_name,
                         tokenizer=model_name,
+                        revision=revision,
                         max_model_len=max(max_model_len, 512),
                         max_num_seqs=max_num_seqs_compiled,
                         num_gpu_blocks_override=None,
-                        revision=revision,
-                        logits_processors=logits_processors)
+                        logits_processors=[GoldenTokenInjector])
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)

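For context, a hedged sketch of how a logits-processor class plugs into the engine as registered above. GoldenTokenInjector's implementation is not part of this diff; NoopInjector below is a made-up stand-in, and the hook names follow vLLM's v1 pluggable logits-processor interface at the time of this PR (signatures may differ across versions):

from typing import Optional

import torch
from vllm import LLM
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor

class NoopInjector(LogitsProcessor):
    """Made-up minimal processor; GoldenTokenInjector is far richer."""

    def __init__(self, vllm_config, device, is_pin_memory):
        # The engine instantiates the class itself, which is why the diff
        # registers the class (GoldenTokenInjector), not an instance.
        pass

    def is_argmax_invariant(self) -> bool:
        # False: a token injector may change which token has the max logit.
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        # Per-request config (e.g. SamplingParams.extra_args) is read here
        # as requests are added to / removed from the batch.
        pass

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        return logits  # a real injector would overwrite rows per request

llm = LLM(model="some-org/some-model",  # placeholder name
          logits_processors=[NoopInjector])
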
tests/output_util.py (17 changes: 3 additions & 14 deletions)

@@ -376,25 +376,14 @@ def setup_golden_token(
    model: ModelInfo,
    sampling_params: Union[SamplingParams, list[SamplingParams]],
    hf_outputs: list[dict[str, Any]],
- ) -> Union[SamplingParams, list[SamplingParams]]:
+ ) -> list[SamplingParams]:
    abs_tol = ISCLOSE_ABS_TOL_QUANTIZATION if model.is_quantized \
        else ISCLOSE_ABS_TOL

-   if isinstance(sampling_params, SamplingParams):
-       # Single Sampling params case
-       hf = hf_outputs[0]
-       sampling_params.extra_args = {
-           "golden_token_injector": {
-               "expected_token_ids": hf['token_ids'],
-               "expected_logprobs": hf['logprobs'],
-               "error_threshold": abs_tol,
-               "label": "#0"
-           }
-       }
-       return sampling_params
+   # golden tokens injection is per request, so we clone SamplingParams
+   sampling_params = [sampling_params.clone() for _ in hf_outputs]

-   # Multiple sampling params case
    assert len(sampling_params) == len(hf_outputs)
    for idx, (param, hf) in enumerate(zip(sampling_params, hf_outputs)):
        param.extra_args = {
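Putting the pieces together, the unified flow after this PR looks roughly like the sketch below; the extra_args field names are copied from the diff, while the function body and signature are a simplified reconstruction, not the file's exact code:

from typing import Any

from vllm import SamplingParams

def setup_golden_token_sketch(
        sampling_params: SamplingParams,
        hf_outputs: list[dict[str, Any]],
        abs_tol: float) -> list[SamplingParams]:
    # One independent clone per request (see the review comment above).
    params = [sampling_params.clone() for _ in hf_outputs]
    for idx, (param, hf) in enumerate(zip(params, hf_outputs)):
        param.extra_args = {
            "golden_token_injector": {
                "expected_token_ids": hf["token_ids"],
                "expected_logprobs": hf["logprobs"],
                "error_threshold": abs_tol,
                "label": f"#{idx}",
            }
        }
    return params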