1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- from typing import List , Optional
14+ from typing import List , Optional , Union
1515
1616import pytest
1717import pytest_asyncio
@@ -42,7 +42,14 @@ def get_gpu_to_model_uid(self):
    def get_gpu_to_embedding_model_uids(self):
        """Expose the worker's GPU-index -> embedding/rerank model-uid
        bookkeeping for test assertions.

        Returns the internal ``_gpu_to_embedding_model_uids`` mapping
        (per test usage, keyed by GPU index with a collection of model
        uids per device) — returned directly, not copied.
        """
        return self._gpu_to_embedding_model_uids
4444
    def get_user_specified_gpu_to_model_uids(self):
        """Expose the worker's bookkeeping for models launched with an
        explicit ``gpu_idx``.

        Returns the internal ``_user_specified_gpu_to_model_uids``
        mapping (per test usage, keyed by GPU index; each entry's items
        carry the model uid at position 0 and the model type at
        position 1) — returned directly, not copied.
        """
        return self._user_specified_gpu_to_model_uids
47+
4548 async def is_model_vllm_backend (self , model_uid ):
49+ if model_uid .startswith ("normal_" ):
50+ return False
51+ if model_uid .startswith ("vllm_" ):
52+ return True
4653 for _dev in self ._gpu_to_model_uid :
4754 if model_uid == self ._gpu_to_model_uid [_dev ]:
4855 return True
@@ -57,10 +64,11 @@ async def launch_builtin_model(
5764 quantization : Optional [str ],
5865 model_type : str = "LLM" ,
5966 n_gpu : Optional [int ] = None ,
67+ gpu_idx : Optional [Union [int , List [int ]]] = None ,
6068 ** kwargs ,
6169 ):
6270 subpool_address , devices = await self ._create_subpool (
63- model_uid , model_type , n_gpu = n_gpu
71+ model_uid , model_type , n_gpu = n_gpu , gpu_idx = gpu_idx # type: ignore
6472 )
6573 self ._model_uid_to_addr [model_uid ] = subpool_address
6674
@@ -252,3 +260,178 @@ async def test_launch_embedding_model(setup_pool):
252260 )
253261 for i in range (1 , 6 ):
254262 await worker .terminate_model (f"model_model_{ i } " )
263+
264+
@pytest.mark.asyncio
async def test_launch_model_with_gpu_idx(setup_pool):
    """Exercise model launches that pin GPUs via an explicit ``gpu_idx``.

    Covers: auto-placed vs. user-pinned LLMs sharing/competing for devices,
    conflicts with vLLM-backed models, auto-placement skipping user-pinned
    GPUs, embedding/rerank placement, and that terminating every model
    empties all of the worker's bookkeeping structures.
    """
    pool = setup_pool
    addr = pool.external_address

    worker: xo.ActorRefType["MockWorkerActor"] = await xo.create_actor(
        MockWorkerActor,
        address=addr,
        uid=WorkerActor.uid(),
        supervisor_address="test",
        main_pool=pool,
        cuda_devices=list(range(4)),
    )

    # A normal model launched without gpu_idx lands on the first free GPU.
    await worker.launch_builtin_model(
        "normal_model_model_1", "mock_model_name", None, None, None, "LLM", n_gpu=1
    )
    gpu_map = await worker.get_gpu_to_model_uid()
    assert len(gpu_map) == 1 and 0 in gpu_map

    # Pinning onto GPU 0 succeeds even though a normal model already runs there.
    await worker.launch_builtin_model(
        "model_model_2", "mock_model_name", None, None, None, "LLM", gpu_idx=[0]
    )
    gpu_map = await worker.get_gpu_to_model_uid()
    assert len(gpu_map) == 1 and 0 in gpu_map

    pinned = await worker.get_user_specified_gpu_to_model_uids()
    assert len(pinned) == 1 and 0 in pinned
    assert len(pinned[0]) == 1
    entry = next(iter(pinned[0]))
    assert entry[0] == "model_model_2" and entry[1] == "LLM"

    # A vLLM-backed model claims a whole GPU for itself.
    await worker.launch_builtin_model(
        "vllm_model_model_3", "mock_model_name", None, None, None, "LLM", n_gpu=1
    )
    gpu_map = await worker.get_gpu_to_model_uid()
    assert len(gpu_map) == 2 and 0 in gpu_map and 1 in gpu_map

    # GPU 1 hosts a vLLM model, so pinning another model onto it must fail.
    with pytest.raises(RuntimeError):
        await worker.launch_builtin_model(
            "model_model_4", "mock_model_name", None, None, None, "LLM", gpu_idx=[1]
        )

    await worker.launch_builtin_model(
        "model_model_4", "mock_model_name", None, None, None, "LLM", gpu_idx=[2]
    )
    gpu_map = await worker.get_gpu_to_model_uid()
    assert len(gpu_map) == 2 and 0 in gpu_map and 1 in gpu_map

    pinned = await worker.get_user_specified_gpu_to_model_uids()
    assert len(pinned) == 2 and 0 in pinned and 2 in pinned
    assert len(pinned[2]) == 1
    entry = next(iter(pinned[2]))
    assert entry[0] == "model_model_4" and entry[1] == "LLM"

    # An auto-placed LLM skips the user-pinned GPU 2 and takes GPU 3.
    await worker.launch_builtin_model(
        "normal_model_model_5", "mock_model_name", None, None, None, "LLM", n_gpu=1
    )
    gpu_map = await worker.get_gpu_to_model_uid()
    assert len(gpu_map) == 3
    for dev in (0, 1, 3):
        assert dev in gpu_map

    # Every GPU is occupied now: another auto-placed LLM must raise.
    with pytest.raises(RuntimeError):
        await worker.launch_builtin_model(
            "normal_model_model_6", "mock_model_name", None, None, None, "LLM", n_gpu=1
        )

    # Terminate everything and verify the bookkeeping is cleared.
    for uid in (
        "normal_model_model_1",
        "model_model_2",
        "vllm_model_model_3",
        "model_model_4",
        "normal_model_model_5",
    ):
        await worker.terminate_model(uid)

    gpu_map = await worker.get_gpu_to_model_uid()
    assert len(gpu_map) == 0

    pinned = await worker.get_user_specified_gpu_to_model_uids()
    assert all(not infos for infos in pinned.values())

    # Next, embedding / rerank models interacting with pinned models.
    await worker.launch_builtin_model(
        "embedding_1", "mock_model_name", None, None, None, "embedding", n_gpu=1
    )
    emb_map = await worker.get_gpu_to_embedding_model_uids()
    assert len(emb_map) == 1 and 0 in emb_map

    await worker.launch_builtin_model(
        "vllm_mock_model_2", "mock_model_name", None, None, None, "LLM", gpu_idx=[0]
    )
    emb_map = await worker.get_gpu_to_embedding_model_uids()
    assert len(emb_map) == 1 and 0 in emb_map

    pinned = await worker.get_user_specified_gpu_to_model_uids()
    assert len(pinned[0]) == 1
    entry = next(iter(pinned[0]))
    assert entry[0] == "vllm_mock_model_2" and entry[1] == "LLM"

    # GPU 0 now hosts a vLLM model: pinning there fails ...
    with pytest.raises(RuntimeError):
        await worker.launch_builtin_model(
            "rerank_3", "mock_model_name", None, None, None, "rerank", gpu_idx=[0]
        )
    # ... and auto-placement that would need all 4 GPUs fails as well,
    # since GPU 0 is never chosen again.
    with pytest.raises(RuntimeError):
        await worker.launch_builtin_model(
            "normal_mock_model_3", "mock_model_name", None, None, None, "LLM", n_gpu=4
        )

    # Expected placements: embedding_3 -> GPU 1, rerank_4 -> GPU 1 (pinned),
    # embedding_5 -> GPU 2, rerank_6 -> GPU 3,
    # rerank_7 -> GPU 2 (it carries the fewest models).
    await worker.launch_builtin_model(
        "embedding_3", "mock_model_name", None, None, None, "embedding", n_gpu=1
    )
    await worker.launch_builtin_model(
        "rerank_4", "mock_model_name", None, None, None, "rerank", gpu_idx=[1]
    )
    await worker.launch_builtin_model(
        "embedding_5", "mock_model_name", None, None, None, "embedding", n_gpu=1
    )
    await worker.launch_builtin_model(
        "rerank_6", "mock_model_name", None, None, None, "rerank", n_gpu=1
    )
    await worker.launch_builtin_model(
        "rerank_7", "mock_model_name", None, None, None, "rerank", n_gpu=1
    )
    emb_map = await worker.get_gpu_to_embedding_model_uids()
    pinned = await worker.get_user_specified_gpu_to_model_uids()
    assert "rerank_7" in emb_map[2]
    # Per-device (embedding count, pinned count) expectations.
    expected = {0: (1, 1), 1: (1, 1), 2: (2, 0), 3: (1, 0)}
    for dev, (n_emb, n_pinned) in expected.items():
        assert len(emb_map[dev]) == n_emb
        assert len(pinned[dev]) == n_pinned

    # Cleanup: terminate all and confirm every structure is empty again.
    for uid in (
        "embedding_1",
        "vllm_mock_model_2",
        "embedding_3",
        "rerank_4",
        "embedding_5",
        "rerank_6",
        "rerank_7",
    ):
        await worker.terminate_model(uid)

    emb_map = await worker.get_gpu_to_embedding_model_uids()
    pinned = await worker.get_user_specified_gpu_to_model_uids()
    for mapping in (emb_map, pinned):
        for models in mapping.values():
            assert not models
0 commit comments