
Commit d0c7792

Authored by xiaohongchen1991, robertgshaw2-redhat, dcmaddix, and li2haipeng
[Bugfix][LoRA][Spec Decode] Support LoRA with speculative decoding (#21068)
Signed-off-by: Sean Chen <[email protected]>
Signed-off-by: Robert Shaw <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: Danielle Robinson <[email protected]>
Co-authored-by: Haipeng Li <[email protected]>
Co-authored-by: li2haipeng <[email protected]>
1 parent b158df2 commit d0c7792

7 files changed: 201 additions (+201) and 15 deletions (-15)
New test file · Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script contains:
1. test lora with speculative decoding for batch inference
"""

import random

import numpy as np
import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

LORA_TEST_PROMPT_MAP: dict[str, str] = {}

LORA_TEST_PROMPT_MAP["premjatin/qwen-linear-algebra-coder"] = """
### INSTRUCTION:
You are an AI assistant that generates Python code to solve linear
algebra problems.

### PROBLEM:
Find the eigenvalues and eigenvectors of the following 3x3 matrix:
[[3, 2, 0],
 [2, 3, 0],
 [0, 0, 2]]

### OUTPUT FORMAT (STRICT):
Numbers should be represented as integers only.

### PYTHON SOLUTION:
"""

SEED = 42


@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
@pytest.mark.parametrize(
    "model_setup",
    [
        (
            "eagle3",
            "Qwen/Qwen3-1.7B",
            "AngelSlim/Qwen3-1.7B_eagle3",
            "premjatin/qwen-linear-algebra-coder",
            1,
        )
    ],
)
def test_batch_inference_correctness(
    monkeypatch: pytest.MonkeyPatch,
    model_setup: tuple[str, str, str, str, int],
):
    """
    Compare the outputs of an LLM with only LoRA and an LLM with both
    speculative decoding and LoRA. The outputs should match, with no
    failures, when doing batch inference.
    model_setup: (method, model_name, spec_model_name, lora_path, tp_size)
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Disable randomness
        m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        random.seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        method, model_name, spec_model_name, lora_path, tp_size = model_setup

        # without speculative decoding
        ref_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )

        prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 100
        lora_request = LoRARequest("adapter", 1, lora_path)
        sampling_params = SamplingParams(
            temperature=0.0, top_p=1.0, top_k=-1, seed=SEED, max_tokens=128
        )

        ref_outputs = ref_llm.generate(
            prompts, sampling_params, lora_request=lora_request
        )
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        lora_spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )

        lora_spec_outputs = lora_spec_llm.generate(
            prompts, sampling_params, lora_request=lora_request
        )

        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, lora_spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 90% of the prompts to match exactly.
        # Upon failure, inspect the outputs to check for inaccuracy.
        print(f"match ratio: {matches}/{len(ref_outputs)}")
        assert matches > int(0.90 * len(ref_outputs))
        del lora_spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
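For context, the configuration this test exercises can be condensed to the sketch below. It reuses the models, adapter, and parameters from the parametrization above; the prompt text and the "adapter" request name are placeholders, and several test-only knobs (max_num_seqs, max_loras, tensor_parallel_size) are omitted for brevity:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Target model with both LoRA and EAGLE3 speculative decoding enabled,
# mirroring the lora_spec_llm construction in the test above.
llm = LLM(
    model="Qwen/Qwen3-1.7B",
    enable_lora=True,
    max_lora_rank=16,
    speculative_config={
        "method": "eagle3",
        "model": "AngelSlim/Qwen3-1.7B_eagle3",
        "num_speculative_tokens": 3,
        "max_model_len": 2048,
    },
    max_model_len=2048,
)

outputs = llm.generate(
    ["Write Python code that inverts a 2x2 matrix."],  # placeholder prompt
    SamplingParams(temperature=0.0, max_tokens=128),
    lora_request=LoRARequest("adapter", 1, "premjatin/qwen-linear-algebra-coder"),
)
print(outputs[0].outputs[0].text)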

vllm/engine/arg_utils.py

Lines changed: 14 additions & 0 deletions
@@ -1574,6 +1574,20 @@ def create_engine_config(
             else None
         )
 
+        if (
+            lora_config is not None
+            and speculative_config is not None
+            and scheduler_config.max_num_batched_tokens
+            < (
+                scheduler_config.max_num_seqs
+                * (speculative_config.num_speculative_tokens + 1)
+            )
+        ):
+            raise ValueError(
+                "Consider increasing max_num_batched_tokens or "
+                "decreasing num_speculative_tokens"
+            )
+
         # bitsandbytes pre-quantized model need a specific model loader
         if model_config.quantization == "bitsandbytes":
             self.quantization = self.load_format = "bitsandbytes"
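For reference, here is the capacity bound this new check enforces, sketched with the settings used in the test above (max_num_seqs=4, num_speculative_tokens=3). The max_num_batched_tokens value below is purely illustrative, not the test's actual default:

# Each running sequence may sample its draft tokens plus one bonus token
# in a single step, and every sampled position needs a per-sample LoRA
# metadata slot.
max_num_seqs = 4
num_speculative_tokens = 3
worst_case_sampled_tokens = max_num_seqs * (num_speculative_tokens + 1)  # 16

# create_engine_config() now raises ValueError when LoRA and speculative
# decoding are enabled together and this bound is violated.
max_num_batched_tokens = 2048  # illustrative value
assert max_num_batched_tokens >= worst_case_sampled_tokens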

vllm/lora/punica_wrapper/punica_gpu.py

Lines changed: 5 additions & 1 deletion
@@ -51,8 +51,12 @@ def __init__(
             self.max_loras, max_num_batched_tokens, device=device
         )
 
+        # When speculative decoding is enabled, max_num_samples is
+        # max_batches * (num_speculative_decoding_tokens + 1).
+        # This line could be optimized by replacing max_num_batched_tokens
+        # with max_batches * (num_speculative_decoding_tokens + 1).
         self.prompt_mapping_meta = LoRAKernelMeta.make(
-            self.max_loras, max_batches, device=device
+            self.max_loras, max_num_batched_tokens, device=device
        )
 
    def update_metadata(

vllm/v1/worker/gpu_input_batch.py

Lines changed: 6 additions & 4 deletions
@@ -859,22 +859,24 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
         return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
 
     def make_lora_inputs(
-        self, num_scheduled_tokens: np.ndarray
+        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
     ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
         """
         Given the num_scheduled_tokens for each request in the batch, return
         datastructures used to activate the current LoRAs.
         Returns:
-        1. prompt_lora_mapping: A tuple of size self.num_reqs where,
-           prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
+        1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
+           where, prompt_lora_mapping[i] is the LoRA id to use for the ith
+           sampled token.
         2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
            where, token_lora_mapping[i] is the LoRA id to use for ith token.
         3. lora_requests: Set of relevant LoRA requests.
         """
 
         req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
-        prompt_lora_mapping = tuple(req_lora_mapping)
+        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
         token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
+
         active_lora_requests: set[LoRARequest] = set(
             self.lora_id_to_lora_request.values()
         )
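With speculative decoding, a request can contribute several sampled positions per step, so the per-request LoRA ids are now repeated once per sampled token instead of emitted once per request. A small sketch of the resulting mapping shapes, using illustrative values rather than real batch state:

import numpy as np

# Illustrative: three requests using LoRA ids 1, 1 and 2.
req_lora_mapping = np.array([1, 1, 2], dtype=np.int32)
# Tokens scheduled this step for each request...
num_scheduled_tokens = np.array([5, 4, 4], dtype=np.int32)
# ...and positions actually sampled (draft tokens + 1 bonus token each).
num_sampled_tokens = np.array([4, 4, 4], dtype=np.int32)

token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))

print(prompt_lora_mapping)  # (1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2)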

vllm/v1/worker/gpu_model_runner.py

Lines changed: 15 additions & 3 deletions
@@ -1268,6 +1268,7 @@ def _prepare_inputs(
             logits_indices = query_start_loc[1:] - 1
             num_draft_tokens = None
             spec_decode_metadata = None
+            num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
         else:
             # Get the number of draft tokens for each request.
             # Iterate over the dictionary rather than all requests since not all
@@ -1294,7 +1295,7 @@
                 num_draft_tokens, cu_num_tokens
             )
             logits_indices = spec_decode_metadata.logits_indices
-
+            num_sampled_tokens = num_draft_tokens + 1
             # For DECODE only cuda graph of some attention backends (e.g., GDN).
             self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens
             self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
@@ -1445,7 +1446,13 @@
 
         # Hot-Swap lora model
         if self.lora_config:
-            self.set_active_loras(self.input_batch, num_scheduled_tokens)
+            assert (
+                np.sum(num_sampled_tokens)
+                <= self.vllm_config.scheduler_config.max_num_batched_tokens
+            )
+            self.set_active_loras(
+                self.input_batch, num_scheduled_tokens, num_sampled_tokens
+            )
 
         return (
             attn_metadata,
@@ -3390,6 +3397,7 @@ def _dummy_run(
         assert len(num_scheduled_tokens_list) == num_reqs
         num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
         total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
+        num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
 
         # Disable DP padding when running eager
         allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@@ -3485,7 +3493,11 @@
             attn_metadata[layer_name] = attn_metadata_i
 
         with self.maybe_dummy_run_with_lora(
-            self.lora_config, num_scheduled_tokens, activate_lora, remove_lora
+            self.lora_config,
+            num_scheduled_tokens,
+            num_sampled_tokens,
+            activate_lora,
+            remove_lora,
         ):
             # Make sure padding doesn't exceed max_num_tokens
             assert num_tokens_after_padding <= self.max_num_tokens
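The model runner now carries a per-request count of sampled positions: one each when speculative decoding is off, and num_draft_tokens + 1 each (drafts plus the bonus token) when it is on, and it asserts that the total fits within max_num_batched_tokens before activating LoRAs. A sketch of that bookkeeping with illustrative values:

import numpy as np

num_reqs = 4
max_num_batched_tokens = 2048  # illustrative scheduler limit

# Without speculative decoding: one sampled token per request.
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

# With speculative decoding: each request verifies its draft tokens and
# samples one bonus token (here, 3 drafts per request).
num_draft_tokens = np.full(num_reqs, 3, dtype=np.int32)
num_sampled_tokens = num_draft_tokens + 1  # [4, 4, 4, 4]

# Guard mirrored from _prepare_inputs: per-sample LoRA metadata is sized
# by max_num_batched_tokens, so the sampled total must fit within it.
assert np.sum(num_sampled_tokens) <= max_num_batched_tokens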

vllm/v1/worker/lora_model_runner_mixin.py

Lines changed: 19 additions & 6 deletions
@@ -38,7 +38,6 @@ def load_lora_model(
                 "Regarding multimodal models, vLLM currently "
                 "only supports adding LoRA to language model."
             )
-
         # Add LoRA Manager to the Model Runner
         self.lora_manager = LRUCacheWorkerLoRAManager(
             vllm_config,
@@ -70,13 +69,19 @@ def _ensure_lora_enabled(self) -> None:
             raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")
 
     def set_active_loras(
-        self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray
+        self,
+        input_batch: InputBatch,
+        num_scheduled_tokens: np.ndarray,
+        num_sampled_tokens: np.ndarray | None = None,
     ) -> None:
-        prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
+        if num_sampled_tokens is None:
+            num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
+
+        prompt_lora_mapping: tuple[int, ...]  # of size np.sum(num_sampled_tokens)
         token_lora_mapping: tuple[int, ...]  # of size np.sum(num_scheduled_tokens)
         lora_requests: set[LoRARequest]
         prompt_lora_mapping, token_lora_mapping, lora_requests = (
-            input_batch.make_lora_inputs(num_scheduled_tokens)
+            input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
         )
         return self._set_active_loras(
             prompt_lora_mapping, token_lora_mapping, lora_requests
@@ -123,8 +128,12 @@ def maybe_select_dummy_loras(
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
+        num_sampled_tokens: np.ndarray | None = None,
         activate_lora: bool = True,
     ):
+        if num_sampled_tokens is None:
+            num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
+
         if lora_config is None:
             yield
         else:
@@ -143,6 +152,9 @@ def maybe_select_dummy_loras(
             else:
                 prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
 
+            # Make sample lora mapping
+            sample_lora_mapping = np.repeat(prompt_lora_mapping, num_sampled_tokens)
+
             # Make token lora mapping
             token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
 
@@ -157,7 +169,7 @@ def maybe_select_dummy_loras(
             }
 
             self._set_active_loras(
-                tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests
+                tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
             )
 
             yield
@@ -167,13 +179,14 @@ def maybe_dummy_run_with_lora(
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
+        num_sampled_tokens: np.ndarray,
         activate_lora: bool = True,
         remove_lora: bool = True,
     ):
         with (
             self.maybe_setup_dummy_loras(lora_config, remove_lora),
             self.maybe_select_dummy_loras(
-                lora_config, num_scheduled_tokens, activate_lora
+                lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
             ),
         ):
             yield
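Because num_sampled_tokens defaults to None and falls back to one sampled token per request, callers that do not use speculative decoding keep the previous behavior. A quick sketch of the fallback, with illustrative values in place of real batch state:

import numpy as np

num_scheduled_tokens = np.array([7, 1, 1], dtype=np.int32)

# Default when num_sampled_tokens is not passed: one sampled token per
# request, which reproduces the pre-change mapping size (one entry per
# request).
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)

prompt_lora_mapping = np.array([1, 2, 2], dtype=np.int32)
sample_lora_mapping = np.repeat(prompt_lora_mapping, num_sampled_tokens)
assert tuple(sample_lora_mapping) == (1, 2, 2)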

vllm/v1/worker/tpu_input_batch.py

Lines changed: 1 addition & 1 deletion
@@ -526,7 +526,7 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
         return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
 
     def make_lora_inputs(
-        self, num_scheduled_tokens: np.ndarray
+        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
     ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
         """
         Given the num_scheduled_tokens for each request in the batch, return
