import numpy as np
import pytest
+ import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding

+ from tests.utils import RemoteOpenAIServer
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
HF_PLACEHOLDER = "<|audio|>"

+ CHUNKED_PREFILL_KWARGS = {
+     "enable_chunked_prefill": True,
+     "max_num_seqs": 2,
+     # Use a very small limit to exercise chunked prefill.
+     "max_num_batched_tokens": 16
+ }
+

@pytest.fixture(scope="session")
def audio_assets():
@@ -30,6 +39,26 @@ def audio(request):
    return AudioAsset(request.param)


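+ # Parametrize the server over the default engine config ({}) and
+ # CHUNKED_PREFILL_KWARGS; each kwarg is rendered as the matching --flag
+ # for the OpenAI-compatible server.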
+ @pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS))
+ def server(request, audio_assets):
+     args = [
+         "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
+         f"--limit-mm-per-prompt=audio={len(audio_assets)}"
+     ] + [
+         f"--{key.replace('_','-')}={value}"
+         for key, value in request.param.items()
+     ]
+
+     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+         yield remote_server
+
+
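+ # Async OpenAI client bound to the server fixture above.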
+ @pytest_asyncio.fixture
+ async def client(server):
+     async with server.get_async_client() as async_client:
+         yield async_client
+
+
def _get_prompt(audio_count, question, placeholder):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    placeholder = f"{placeholder}\n" * audio_count
@@ -68,8 +97,7 @@ def run_test(
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
-     tensor_parallel_size: int,
-     distributed_executor_backend: Optional[str] = None,
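+     # Remaining engine kwargs (e.g. CHUNKED_PREFILL_KWARGS) are passed
+     # through to vllm_runner.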
+     **kwargs,
):
    """Inference result should be the same between hf and vllm."""
    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -79,11 +107,8 @@ def run_test(
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

-     with vllm_runner(model,
-                      dtype=dtype,
-                      tensor_parallel_size=tensor_parallel_size,
-                      distributed_executor_backend=distributed_executor_backend,
-                      enforce_eager=True) as vllm_model:
+     with vllm_runner(model, dtype=dtype, enforce_eager=True,
+                      **kwargs) as vllm_model:
        vllm_outputs_per_audio = [
            vllm_model.generate_greedy_logprobs([vllm_prompt],
                                                max_tokens,
@@ -135,18 +160,16 @@ def run_multi_audio_test(
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
-     tensor_parallel_size: int,
-     distributed_executor_backend: Optional[str] = None,
+     **kwargs,
):
    with vllm_runner(model,
                     dtype=dtype,
-                      tensor_parallel_size=tensor_parallel_size,
-                      distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     limit_mm_per_prompt={
                         "audio":
                         max((len(audio) for _, audio in prompts_and_audios))
-                      }) as vllm_model:
+                      },
+                      **kwargs) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            [prompt for prompt, _ in prompts_and_audios],
            max_tokens,
@@ -162,8 +185,9 @@ def run_multi_audio_test(
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
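+ # Run each case with the default scheduler config and with chunked prefill.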
+ @pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
-                 num_logprobs: int) -> None:
+                 num_logprobs: int, vllm_kwargs: dict) -> None:

    vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
    hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
@@ -175,17 +199,18 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
-         tensor_parallel_size=1,
+         **vllm_kwargs,
    )


@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
+ @pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
-                                      max_tokens: int,
-                                      num_logprobs: int) -> None:
+                                      max_tokens: int, num_logprobs: int,
+                                      vllm_kwargs: dict) -> None:

    vllm_prompt = _get_prompt(len(audio_assets),
                              "Describe each of the audios above.",
@@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
-         tensor_parallel_size=1,
+         **vllm_kwargs,
    )
+
+
+ @pytest.mark.asyncio
+ async def test_online_inference(client, audio_assets):
+     """Exercises online inference with/without chunked prefill enabled."""
+
+     messages = [{
+         "role":
+         "user",
+         "content": [
+             *[{
+                 "type": "audio_url",
+                 "audio_url": {
+                     "url": audio.url
+                 }
+             } for audio in audio_assets],
+             {
+                 "type":
+                 "text",
+                 "text":
+                 f"What's happening in these {len(audio_assets)} audio clips?"
+             },
+         ],
+     }]
+
+     chat_completion = await client.chat.completions.create(model=MODEL_NAME,
+                                                             messages=messages,
+                                                             max_tokens=10)
+
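+     # max_tokens=10 forces the completion to stop at the length limit.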
+     assert len(chat_completion.choices) == 1
+     choice = chat_completion.choices[0]
+     assert choice.finish_reason == "length"