
Commit a040239

Squash 5733
Signed-off-by: Jefferson Fialho <[email protected]>
1 parent 2479a20 commit a040239

9 files changed: +442 −14 lines


tests/lora/conftest.py

Lines changed: 5 additions & 0 deletions
@@ -166,6 +166,11 @@ def sql_lora_files(sql_lora_huggingface_id):
     return snapshot_download(repo_id=sql_lora_huggingface_id)
 
 
+@pytest.fixture(scope="session")
+def lora_bias_files():
+    return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")
+
+
 @pytest.fixture(scope="session")
 def mixtral_lora_files():
     # Note: this module has incorrect adapter_config.json to test

tests/lora/test_lora_bias_e2e.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+from typing import List
+
+import pytest
+
+import vllm
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "ibm-granite/granite-3b-code-base"
+
+
+def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
+    prompts = [
+        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
+        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
+    ]
+    sampling_params = vllm.SamplingParams(temperature=0,
+                                          max_tokens=256,
+                                          stop=["[/assistant]"])
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    generated_texts: List[str] = []
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        generated_texts.append(generated_text)
+    return generated_texts
+
+
[email protected]("lora_bias", [True, False])
[email protected]("fully_sharded", [True, False])
+def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
+    llm = vllm.LLM(MODEL_PATH,
+                   enable_lora=True,
+                   max_num_seqs=16,
+                   max_lora_rank=8,
+                   max_loras=1,
+                   enable_lora_bias=lora_bias,
+                   tensor_parallel_size=1,
+                   fully_sharded_loras=fully_sharded)
+
+    print("lora adapter created")
+    output1 = do_sample(llm, lora_bias_files, lora_id=0)
+
+    print("lora")
+    output2 = do_sample(llm, lora_bias_files, lora_id=1)
+
+    if lora_bias:
+        assert output1 != output2
+    else:
+        assert output1 == output2
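
The test exercises `enable_lora_bias` through pytest fixtures, but the same flag is an ordinary `vllm.LLM` constructor argument, so the feature can also be tried from a plain offline-inference script. The sketch below reuses only names that appear in this diff (`enable_lora_bias`, `LoRARequest`, the `followumesh/granite-3b-lora8-bias` adapter); the prompt is a placeholder and the snippet is illustrative, not part of the commit.

# Illustrative offline-inference sketch (not part of this commit).
from huggingface_hub import snapshot_download

import vllm
from vllm.lora.request import LoRARequest

# Fetch the same bias-carrying adapter the lora_bias_files fixture downloads.
adapter_path = snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")

llm = vllm.LLM("ibm-granite/granite-3b-code-base",
               enable_lora=True,
               enable_lora_bias=True,  # flag introduced by this commit
               max_lora_rank=8,
               max_loras=1)

outputs = llm.generate(
    ["[user] Write a SQL query for the question below. [/user] [assistant]"],
    vllm.SamplingParams(temperature=0, max_tokens=64, stop=["[/assistant]"]),
    lora_request=LoRARequest("sql-bias-adapter", 1, adapter_path))

print(outputs[0].outputs[0].text)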

vllm/config.py

Lines changed: 1 addition & 0 deletions
@@ -1517,6 +1517,7 @@ class LoRAConfig:
     # This is a constant.
     lora_vocab_padding_size: ClassVar[int] = 256
     long_lora_scaling_factors: Optional[Tuple[float]] = None
+    bias_enabled: bool = False
 
     def __post_init__(self):
         # Setting the maximum rank to 256 should be able to satisfy the vast

vllm/engine/arg_utils.py

Lines changed: 5 additions & 0 deletions
@@ -133,6 +133,7 @@ class EngineArgs:
     tokenizer_pool_extra_config: Optional[dict] = None
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     enable_lora: bool = False
+    enable_lora_bias: bool = False
     max_loras: int = 1
     max_lora_rank: int = 16
     enable_prompt_adapter: bool = False
@@ -526,6 +527,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument('--enable-lora',
                             action='store_true',
                             help='If True, enable handling of LoRA adapters.')
+        parser.add_argument('--enable-lora-bias',
+                            action='store_true',
+                            help='If True, enable bias for LoRA adapters.')
         parser.add_argument('--max-loras',
                             type=int,
                             default=EngineArgs.max_loras,
@@ -1009,6 +1013,7 @@ def create_engine_config(self) -> EngineConfig:
                                       and parallel_config.use_ray),
         )
         lora_config = LoRAConfig(
+            bias_enabled=self.enable_lora_bias,
             max_lora_rank=self.max_lora_rank,
             max_loras=self.max_loras,
             fully_sharded_loras=self.fully_sharded_loras,
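
The new option follows the same path as the existing LoRA settings: `--enable-lora-bias` sets `EngineArgs.enable_lora_bias`, and `create_engine_config()` copies it into `LoRAConfig.bias_enabled`. Below is a small sketch of that wiring, constructing the dataclasses directly with only the fields shown in this diff; the model name and rank are illustrative values, not part of the commit.

# Sketch of the flag's flow from EngineArgs into LoRAConfig (illustrative).
from vllm.config import LoRAConfig
from vllm.engine.arg_utils import EngineArgs

# Equivalent of passing --enable-lora --enable-lora-bias on the CLI.
engine_args = EngineArgs(model="ibm-granite/granite-3b-code-base",
                         enable_lora=True,
                         enable_lora_bias=True,
                         max_lora_rank=8)

# create_engine_config() (third hunk above) forwards the flag like this:
lora_config = LoRAConfig(bias_enabled=engine_args.enable_lora_bias,
                         max_lora_rank=engine_args.max_lora_rank,
                         max_loras=engine_args.max_loras,
                         fully_sharded_loras=engine_args.fully_sharded_loras)
assert lora_config.bias_enabled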

vllm/lora/fully_sharded_layers.py

Lines changed: 33 additions & 0 deletions
@@ -70,6 +70,14 @@ def apply(self, x: torch.Tensor,
                                           self.lora_b_stacked,
                                           add_input=True)
         # now have column partitioned output
+
+        if self.bias_stacked is not None:
+            self.bias_stacked = self.bias_stacked.view(
+                -1, self.bias_stacked.shape[-1])
+            self.bias_stacked = self.bias_stacked[
+                self.punica_wrapper.token_lora_indices]
+            output += self.bias_stacked
+
         output = output.view(*out_orig_shape)
         return output
 
@@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
     left_offset = 0
     for idx in range(n):
         shard_size = layer.lora_b_stacked[idx].shape[2]
+
+        if layer.bias_stacked is not None:
+            bias = layer.bias_stacked[idx]
+            if bias is not None:
+                bias = bias.view(-1, bias.shape[-1])
+                bias = bias[layer.punica_wrapper.token_lora_indices]
+                bias[layer.punica_wrapper.token_lora_indices == -1] = 0
+                output[:, left_offset:left_offset + shard_size] += bias
+
         layer.punica_wrapper.add_expand_slice(
             output,
             buffers[idx],
@@ -295,6 +312,15 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         lora_b = lora_b[:, start_idx:end_idx]
         return lora_b
 
+    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        if bias is None:
+            return bias
+        shard_size = self.bias_stacked.shape[2]
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
+        bias = bias[start_idx:end_idx]
+        return bias
+
     def apply(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x)
 
@@ -318,6 +344,13 @@ def apply(self, x: torch.Tensor) -> torch.Tensor:
         # reduced before being used
         shard_size = self.lora_b_stacked.shape[2]
         start_idx = self.tp_rank * shard_size
+
+        if self.bias_stacked is not None:
+            bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
+            bias = bias[self.punica_wrapper.token_lora_indices]
+            bias[self.punica_wrapper.token_lora_indices == -1] = 0
+            output += bias
+
         self.punica_wrapper.add_expand_slice(output, buffer,
                                              self.lora_b_stacked, start_idx,
                                              shard_size)
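
The bias hunks above all rely on the same indexing trick: `bias_stacked` stores one bias row per adapter slot, it is flattened to 2-D, gathered per token through `punica_wrapper.token_lora_indices`, rows for tokens without an adapter (index -1) are zeroed where the hunk does so, and the result is added to the shard's output. Here is a self-contained toy reproduction of that gather-and-mask step; the shapes and values are made up for illustration and are not vLLM internals.

# Toy reproduction of the per-token bias gather used in the hunks above.
import torch

max_loras, hidden, num_tokens = 3, 4, 5

# One bias vector per adapter slot, stacked like bias_stacked.
bias_stacked = torch.arange(max_loras * hidden, dtype=torch.float32).view(
    max_loras, 1, hidden)

# Per-token adapter slot; -1 marks tokens that use no LoRA adapter.
token_lora_indices = torch.tensor([0, 2, -1, 1, 0])

output = torch.zeros(num_tokens, hidden)

bias = bias_stacked.view(-1, bias_stacked.shape[-1])  # (max_loras, hidden)
bias = bias[token_lora_indices]                       # (num_tokens, hidden)
bias[token_lora_indices == -1] = 0                    # no adapter -> no bias
output += bias

print(output)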
