@@ -34,9 +34,8 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
     return model_runner


-@pytest.mark.parametrize("batch_size, prompt_embeds_ratio",
-                         list(itertools.product(range(1, 257),
-                                                (0.0, 0.5, 1.0))))
+@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
+@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0))
 def test_prepare_prompt(batch_size, prompt_embeds_ratio):
     model_runner = _create_model_runner(
         "facebook/opt-125m",
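
Note on the parametrization change: stacking two `@pytest.mark.parametrize` decorators still yields the full cross-product of parameter values, so every `prompt_embeds_ratio` is exercised for each sampled `batch_size`; striding the range by 3 only thins out the batch sizes. A minimal sketch of the equivalence (self-contained; nothing here is from this repo):

```python
import itertools

import pytest


# Stacked decorators: pytest expands these into one test per
# (prompt_embeds_ratio, batch_size) pair, i.e. the cross-product.
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0))
def test_stacked(batch_size, prompt_embeds_ratio):
    assert 1 <= batch_size < 257
    assert prompt_embeds_ratio in (0.0, 0.5, 1.0)


# Single-decorator form with an explicit itertools.product, as the
# old code did; it generates the same set of parameter pairs.
@pytest.mark.parametrize(
    "batch_size, prompt_embeds_ratio",
    list(itertools.product(range(1, 257, 3), (0.0, 0.5, 1.0))))
def test_product(batch_size, prompt_embeds_ratio):
    assert 1 <= batch_size < 257
    assert prompt_embeds_ratio in (0.0, 0.5, 1.0)
```
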
@@ -54,11 +53,13 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio):
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([], prompt_embeds=torch.rand(seq_len, 10))
+            seq_data = SequenceData(
+                array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)),
+                torch.rand(seq_len, 10))
             input_embeds_len += seq_len
-        else:
-            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                          range(seq_len)))
+        else:
+            seq_data = SequenceData(
+                array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
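
The substantive change in this hunk is the `SequenceData` call shape: the prompt-embeds branch now passes a typed `array` of placeholder token ids as the first positional argument and the embedding tensor as the second, instead of an empty list plus a `prompt_embeds=` keyword. A sketch of that call shape with a hypothetical stub (this `SequenceData` is not vLLM's class, and `VLLM_TOKEN_ID_ARRAY_TYPE = "l"` is an assumed typecode):

```python
from array import array
from typing import Optional

import torch

VLLM_TOKEN_ID_ARRAY_TYPE = "l"  # assumption: array typecode for token ids


class SequenceData:
    """Hypothetical stub mirroring only the constructor shape in the diff."""

    def __init__(self,
                 prompt_token_ids: array,
                 prompt_embeds: Optional[torch.Tensor] = None) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds


seq_len = 7
# Embeds branch: seq_len placeholder ids alongside a (seq_len, 10) tensor.
with_embeds = SequenceData(
    array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)),
    torch.rand(seq_len, 10))
# Token branch: a typed array of real token ids, no embeddings.
tokens_only = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)))
```
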
@@ -163,7 +164,7 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio):
     torch.testing.assert_close(actual, expected)


-@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
 @pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0))
 def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio):
     model_runner = _create_model_runner(
@@ -185,8 +186,8 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio):
         context_len = i % (model_runner.block_size - 1) + 1
         context_lens.append(context_len)
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([],
-                                    prompt_embeds=torch.rand(context_len, 10))
+            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)),
+                                    torch.rand(context_len, 10))
             input_embeds_len += context_len
         else:
             seq_data = SequenceData(
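
One detail worth flagging: the decode-phase tests pass `range(0)`, i.e. an empty typed array, as the placeholder token ids, whereas the prompt-phase test above fills `range(seq_len)`. A two-line check of what that constructs (typecode again assumed):

```python
from array import array

VLLM_TOKEN_ID_ARRAY_TYPE = "l"  # assumption, as above

empty_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0))
assert len(empty_ids) == 0  # decode sequences carry no placeholder ids here
```
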
@@ -337,7 +338,7 @@ def distributed_init():
     ensure_model_parallel_initialized(1, 1)


-@pytest.mark.parametrize("batch_size", list(range(2, 128)))
+@pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
 @pytest.mark.parametrize("enforce_eager", [True, False])
 @pytest.mark.parametrize('prompt_embeds_ratio', [0.0, 0.5, 1.0])
 def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio,
@@ -366,11 +367,12 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio,
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([], prompt_embeds=torch.rand(seq_len, 10))
+            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)),
+                                    torch.rand(seq_len, 10))
             input_embeds_len += seq_len
         else:
-            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                          range(seq_len)))
+            seq_data = SequenceData(
+                array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -387,8 +389,8 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio,
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([],
-                                    prompt_embeds=torch.rand(context_len, 10))
+            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)),
+                                    torch.rand(context_len, 10))
         else:
             prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
             seq_data = SequenceData(prompt_toks)