@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+
 import vllm
 from vllm.lora.request import LoRARequest

2829 "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1" , # noqa: E501
2930]
3031
32+ EXPECTED_BASE_MODEL_OUTPUT = [
33+ "SELECT COUNT(Candidate_ID) FROM candidate" ,
34+ "SELECT COUNT(Candidate_ID) FROM candidate" ,
35+ "SELECT Candidate_ID, COUNT(*) as Total_Candidates\n FROM candidate\n INNER JOIN people ON candidate.People_ID = people.People_ID" , # noqa: E501
36+ "SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1" , # noqa: E501
37+ ]
38+
3139
-def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
+def generate_and_test(
+    llm: vllm.LLM, lora_path: str, lora_id: list[int | None] | int | None
+) -> None:
     prompts = [
         PROMPT_TEMPLATE.format(context="How many candidates are there?"),
         PROMPT_TEMPLATE.format(context="Count the number of candidates."),
@@ -40,12 +50,18 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
             context="Return the poll resource associated with the most candidates."
         ),
     ]
+
+    lora_request = None  # stays None for lora_id=None: run the base model
+    if isinstance(lora_id, int):
+        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
+    elif isinstance(lora_id, list):  # one adapter id (or None) per prompt
+        lora_request = [
+            LoRARequest(str(i), i, lora_path) if i is not None else None
+            for i in lora_id
+        ]
+
     sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
-    outputs = llm.generate(
-        prompts,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
-    )
+    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
     # Print the outputs.
     generated_texts: list[str] = []
     for output in outputs:
@@ -55,7 +71,13 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

     for i in range(len(EXPECTED_LORA_OUTPUT)):
-        assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
+        req_lora_id = lora_id[i] if isinstance(lora_id, list) else lora_id
+        expected_output = (
+            EXPECTED_LORA_OUTPUT[i]
+            if req_lora_id is not None
+            else EXPECTED_BASE_MODEL_OUTPUT[i]
+        )
+        assert generated_texts[i].startswith(expected_output)


 def test_olmoe_lora(olmoe_lora_files):
@@ -75,6 +97,34 @@ def test_olmoe_lora(olmoe_lora_files):
     generate_and_test(llm, olmoe_lora_files, lora_id=2)


+def test_olmoe_lora_base_model(olmoe_lora_files):
+    llm = vllm.LLM(
+        MODEL_PATH,
+        max_model_len=1024,
+        enable_lora=True,
+        max_loras=4,
+        enforce_eager=True,
+        trust_remote_code=True,
+        enable_chunked_prefill=True,
+    )
+
+    generate_and_test(llm, olmoe_lora_files, lora_id=None)  # no adapter at all
+
+
+def test_olmoe_lora_mixed(olmoe_lora_files):
+    llm = vllm.LLM(
+        MODEL_PATH,
+        max_model_len=1024,
+        enable_lora=True,
+        max_loras=4,
+        enforce_eager=True,
+        trust_remote_code=True,
+        enable_chunked_prefill=True,
+    )
+
+    generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])  # LoRA and base prompts in one batch
+
+
 @multi_gpu_test(num_gpus=2)
 def test_olmoe_lora_tp2(olmoe_lora_files):
     llm = vllm.LLM(
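For reference, a minimal standalone sketch (not part of the commit) of the per-request LoRA pattern the new tests exercise. Every call used here appears in the diff above; the model name and adapter path are illustrative placeholders:

import vllm
from vllm.lora.request import LoRARequest

# Placeholder model; the tests use whatever MODEL_PATH is defined earlier in the file.
llm = vllm.LLM("allenai/OLMoE-1B-7B-0924", enable_lora=True, max_loras=4)
prompts = ["How many candidates are there?", "Count the number of candidates."]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
# One entry per prompt: a LoRARequest routes that prompt through the adapter,
# while None leaves it on the base model weights.
lora_requests = [LoRARequest("sql-lora", 1, "/path/to/sql_lora_adapter"), None]
outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)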