
Commit e90736f

fixup! [Lora][Spec Decode] support LoRA with speculative decoding for v1 on gpu
Signed-off-by: Sean Chen <[email protected]>
1 parent 202c2cc commit e90736f
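For context, a minimal sketch of the API combination this commit exercises: a LoRA adapter served together with EAGLE3 speculative decoding on the v1 GPU path. Model, adapter, and argument names are taken from the test added below; the values are illustrative, not prescriptive, and may need adjusting for other adapters or vLLM versions.

```python
# Sketch only: LoRA + speculative decoding in one vLLM engine.
# Names and values mirror the test in this commit; they are illustrative.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="Qwen/Qwen3-1.7B",
    enable_lora=True,
    max_lora_rank=16,
    speculative_config={
        "method": "eagle3",
        "model": "AngelSlim/Qwen3-1.7B_eagle3",
        "num_speculative_tokens": 3,
    },
)

outputs = llm.generate(
    ["### PROBLEM:\nFind the eigenvalues of [[2, 0], [0, 3]]\n### PYTHON SOLUTION:\n"],
    SamplingParams(temperature=0, max_tokens=128),
    lora_request=LoRARequest("adapter", 1, "premjatin/qwen-linear-algebra-coder"),
)
print(outputs[0].outputs[0].text)
```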

1 file changed: 111 additions & 0 deletions

@@ -0,0 +1,111 @@
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script contains:
1. A test of LoRA with speculative decoding for batch inference.
"""
import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

LORA_TEST_PROMPT_MAP: dict[str, str] = {}

LORA_TEST_PROMPT_MAP["premjatin/qwen-linear-algebra-coder"] = """
### INSTRUCTION:
You are an AI assistant that generates Python code to solve linear
algebra problems.

### PROBLEM:
Find the eigenvalues and eigenvectors of the following 3x3 matrix:
[[4, 0, 1],
[-2, 1, 0],
[-2, 0, 1]]

### PYTHON SOLUTION:
"""


@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="CUDA not available")
@pytest.mark.parametrize(
    "model_setup",
    [("eagle3", "Qwen/Qwen3-1.7B", "AngelSlim/Qwen3-1.7B_eagle3",
      "premjatin/qwen-linear-algebra-coder", 1)])
def test_batch_inference_correctness(
    monkeypatch: pytest.MonkeyPatch,
    model_setup: tuple[str, str, str, str, int],
):
    '''
    Compare the outputs of an LLM with only LoRA against an LLM with both
    speculative decoding and LoRA. The outputs should match, and batch
    inference should complete without failure.
    model_setup: (method, model_name, spec_model_name, lora_path, tp_size)
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        method, model_name, spec_model_name, lora_path, tp_size = model_setup

        # Reference engine: LoRA without speculative decoding.
        ref_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )

        prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 4
        lora_request = LoRARequest("adapter", 1, lora_path)
        sampling_params = SamplingParams(temperature=0, max_tokens=128)

        ref_outputs = ref_llm.generate(prompts,
                                       sampling_params,
                                       lora_request=lora_request)
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        # Same model and LoRA adapter, with speculative decoding enabled.
        lora_spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )

        lora_spec_outputs = lora_spec_llm.generate(prompts,
                                                   sampling_params,
                                                   lora_request=lora_request)

        # With greedy sampling, speculative decoding should not change the
        # generated text; any mismatch is printed and fails the assertion.
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, lora_spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        assert misses == 0
        del lora_spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
```
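On a CUDA machine the test can be invoked with `pytest -q <path-to-this-file>::test_batch_inference_correctness` (the path depends on where the file lands in the vLLM test tree); it is skipped automatically when CUDA is unavailable, and it tears down the first engine and cleans up the distributed environment and GPU memory before constructing the speculative-decoding engine.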
