
Commit 1a774fd

Change seeded generate test to use mixed batch
1 parent: a0eebef
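This changes the test from issuing six separate generate() calls over the full prompt set to enqueueing, for each prompt, six requests in a single mixed batch: unseeded, seed 100, seed 200, then the same three again. The engine is run once, and each group of six consecutive outputs is checked: the first four (both unseeded runs plus the first run of each seed) must all differ pairwise, while each repeated seeded request must reproduce its first run exactly.

A minimal sketch of the property under test, using vLLM's public offline API rather than the private engine hooks the test uses; the model, prompt, and seed value here are illustrative:

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")

    # A fixed per-request seed makes random sampling reproducible.
    seeded = SamplingParams(temperature=2.0, seed=100, max_tokens=8)

    out_a = llm.generate(["The future of AI is"], seeded)
    out_b = llm.generate(["The future of AI is"], seeded)

    # Same seed, same sampled tokens.
    assert out_a[0].outputs[0].token_ids == out_b[0].outputs[0].token_ids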

1 file changed: +36 −22 lines


tests/samplers/test_seeded_generate.py

Lines changed: 36 additions & 22 deletions
@@ -12,7 +12,7 @@
 from vllm import SamplingParams
 
 MODEL = "facebook/opt-125m"
-RANDOM_SEEDS = list(range(3))
+RANDOM_SEEDS = list(range(5))
 
 
 @pytest.fixture
@@ -37,7 +37,7 @@ def test_random_sample_with_seed(
         top_k=random.randint(5, 20),
         n=random.randint(1, 10),
         presence_penalty=random.randint(0, 1),
-        max_tokens=4,
+        max_tokens=8,
         ignore_eos=True,
     )
 
@@ -46,23 +46,37 @@ def test_random_sample_with_seed(
     sampling_params_seed_2 = copy.deepcopy(sampling_params)
     sampling_params_seed_2.seed = 200
 
-    vllm_outputs_no_seed_1 = vllm_model.generate(example_prompts,
-                                                 sampling_params)
-    vllm_outputs_seed_1_1 = vllm_model.generate(example_prompts,
-                                                sampling_params_seed_1)
-    vllm_outputs_seed_2_1 = vllm_model.generate(example_prompts,
-                                                sampling_params_seed_2)
-    vllm_outputs_no_seed_2 = vllm_model.generate(example_prompts,
-                                                 sampling_params)
-    vllm_outputs_seed_1_2 = vllm_model.generate(example_prompts,
-                                                sampling_params_seed_1)
-    vllm_outputs_seed_2_2 = vllm_model.generate(example_prompts,
-                                                sampling_params_seed_2)
-
-    for output_a, output_b in combinations(
-        (vllm_outputs_no_seed_1, vllm_outputs_no_seed_2, vllm_outputs_seed_1_1,
-         vllm_outputs_seed_2_1), 2):
-        assert output_a != output_b
-
-    assert vllm_outputs_seed_1_1 == vllm_outputs_seed_1_2
-    assert vllm_outputs_seed_2_1 == vllm_outputs_seed_2_2
+    llm = vllm_model.model
+
+    for prompt in example_prompts:
+        for params in (
+                sampling_params,
+                sampling_params_seed_1,
+                sampling_params_seed_2,
+                sampling_params,
+                sampling_params_seed_1,
+                sampling_params_seed_2,
+        ):
+            llm._add_request(
+                prompt=prompt,
+                prompt_token_ids=None,
+                sampling_params=params,
+            )
+
+    results = llm._run_engine(use_tqdm=False)
+    all_outputs = [[out.token_ids for out in output.outputs]
+                   for output in results]
+
+    for i in range(0, len(example_prompts), 6):
+        outputs = all_outputs[i:i + 6]
+
+        # verify all non-seeded requests differ
+        for output_a, output_b in combinations(
+                (outputs[0], outputs[1], outputs[2], outputs[3]),
+                2,
+        ):
+            assert output_a != output_b
+
+        # verify requests with the same seed match
+        assert outputs[1] == outputs[4]
+        assert outputs[2] == outputs[5]
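Because requests are added prompt by prompt and _run_engine returns results ordered by request id, every group of six consecutive outputs belongs to a single prompt. Running seeded and unseeded requests together in one mixed batch exercises per-request seed isolation during batched sampling, which the previous version, with one generate() call (and hence one batch) per parameter set, could not.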
