33from tempfile import TemporaryDirectory
44from typing import List , Tuple
55
6+ import torch
67from huggingface_hub import snapshot_download
78from safetensors .torch import load_file , save_file
89from transformers import AutoTokenizer
910
1011from vllm .lora .request import LoRARequest
1112
13+ from ..models .utils import check_outputs_equal
14+
# HuggingFace model identifiers used throughout this test module.
# NOTE(review): Ultravox v0.3 appears to be built on the Llama 3.1 8B
# backbone, which is presumably why the test compares LoRA outputs
# across the two — confirm against the test body.
ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
1417
@@ -62,19 +65,24 @@ def test_ultravox_lora(vllm_runner):
6265 """
6366 TODO: Train an Ultravox LoRA instead of using a Llama LoRA.
6467 """
68+ # Workaround to prevent device mismatch in Whisper.
69+ # Can be removed when it is fixed upstream in transformer
70+ # https://github.com/huggingface/transformers/pull/35866
71+ torch .set_default_device ("cpu" )
72+
6573 llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path ()
6674 with TemporaryDirectory () as temp_ultravox_lora_dir :
6775 llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora (
6876 llama3_1_8b_chess_lora , temp_ultravox_lora_dir )
6977 with vllm_runner (
7078 ULTRAVOX_MODEL_NAME ,
7179 enforce_eager = True ,
72- max_num_seqs = 128 ,
80+ max_num_seqs = 2 ,
7381 enable_lora = True ,
74- max_loras = 4 ,
82+ max_loras = 1 ,
7583 max_lora_rank = 128 ,
7684 dtype = "bfloat16" ,
77- max_model_len = 4096 ,
85+ max_model_len = 1024 ,
7886 ) as vllm_model :
7987 ultravox_outputs : List [Tuple [
8088 List [int ], str ]] = vllm_model .generate_greedy (
@@ -91,21 +99,23 @@ def test_ultravox_lora(vllm_runner):
9199 with vllm_runner (
92100 LLMA_MODEL_NAME ,
93101 enforce_eager = True ,
94- max_num_seqs = 128 ,
102+ max_num_seqs = 2 ,
95103 enable_lora = True ,
96- max_loras = 4 ,
104+ max_loras = 1 ,
97105 max_lora_rank = 128 ,
98106 dtype = "bfloat16" ,
99- max_model_len = 4096 ,
107+ max_model_len = 1024 ,
100108 ) as vllm_model :
101- llama_outputs_no_lora : List [Tuple [List [int ], str ]] = (
109+ llama_outputs : List [Tuple [List [int ], str ]] = (
102110 vllm_model .generate_greedy (
103111 [_get_prompt (0 , PROMPT , VLLM_PLACEHOLDER , LLMA_MODEL_NAME )],
104112 256 ,
113+ lora_request = LoRARequest (str (1 ), 1 , llama3_1_8b_chess_lora ),
105114 ))
106115
107- _ , llama_no_lora_str = llama_outputs_no_lora [0 ]
108- _ , ultravox_str = ultravox_outputs [0 ]
109-
110- # verify that text don't match with no lora
111- assert llama_no_lora_str != ultravox_str
116+ check_outputs_equal (
117+ outputs_0_lst = ultravox_outputs ,
118+ outputs_1_lst = llama_outputs ,
119+ name_0 = "ultravox" ,
120+ name_1 = "llama" ,
121+ )
0 commit comments