1212from vllm import SamplingParams
1313
1414MODEL = "facebook/opt-125m"
15- RANDOM_SEEDS = list (range (3 ))
15+ RANDOM_SEEDS = list (range (5 ))
1616
1717
1818@pytest .fixture
@@ -37,7 +37,7 @@ def test_random_sample_with_seed(
3737 top_k = random .randint (5 , 20 ),
3838 n = random .randint (1 , 10 ),
3939 presence_penalty = random .randint (0 , 1 ),
40- max_tokens = 4 ,
40+ max_tokens = 8 ,
4141 ignore_eos = True ,
4242 )
4343
@@ -46,23 +46,37 @@ def test_random_sample_with_seed(
4646 sampling_params_seed_2 = copy .deepcopy (sampling_params )
4747 sampling_params_seed_2 .seed = 200
4848
49- vllm_outputs_no_seed_1 = vllm_model .generate (example_prompts ,
50- sampling_params )
51- vllm_outputs_seed_1_1 = vllm_model .generate (example_prompts ,
52- sampling_params_seed_1 )
53- vllm_outputs_seed_2_1 = vllm_model .generate (example_prompts ,
54- sampling_params_seed_2 )
55- vllm_outputs_no_seed_2 = vllm_model .generate (example_prompts ,
56- sampling_params )
57- vllm_outputs_seed_1_2 = vllm_model .generate (example_prompts ,
58- sampling_params_seed_1 )
59- vllm_outputs_seed_2_2 = vllm_model .generate (example_prompts ,
60- sampling_params_seed_2 )
61-
62- for output_a , output_b in combinations (
63- (vllm_outputs_no_seed_1 , vllm_outputs_no_seed_2 , vllm_outputs_seed_1_1 ,
64- vllm_outputs_seed_2_1 ), 2 ):
65- assert output_a != output_b
66-
67- assert vllm_outputs_seed_1_1 == vllm_outputs_seed_1_2
68- assert vllm_outputs_seed_2_1 == vllm_outputs_seed_2_2
49+ llm = vllm_model .model
50+
51+ for prompt in example_prompts :
52+ for params in (
53+ sampling_params ,
54+ sampling_params_seed_1 ,
55+ sampling_params_seed_2 ,
56+ sampling_params ,
57+ sampling_params_seed_1 ,
58+ sampling_params_seed_2 ,
59+ ):
60+ llm ._add_request (
61+ prompt = prompt ,
62+ prompt_token_ids = None ,
63+ sampling_params = params ,
64+ )
65+
66+ results = llm ._run_engine (use_tqdm = False )
67+ all_outputs = [[out .token_ids for out in output .outputs ]
68+ for output in results ]
69+
70+ for i in range (0 , len (example_prompts ), 6 ):
71+ outputs = all_outputs [i :i + 6 ]
72+
73+ # verify all non-seeded requests differ
74+ for output_a , output_b in combinations (
75+ (outputs [0 ], outputs [1 ], outputs [2 ], outputs [3 ]),
76+ 2 ,
77+ ):
78+ assert output_a != output_b
79+
80+ # verify requests with the same seed match
81+ assert outputs [1 ] == outputs [4 ]
82+ assert outputs [2 ] == outputs [5 ]
0 commit comments