Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 51 additions & 19 deletions tests/entrypoints/openai/test_lora_adapters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
import os
import shutil
from contextlib import suppress

Expand All @@ -17,6 +18,33 @@
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"

# Each case is a tuple of (test name, overrides applied to the adapter's
# adapter_config.json, substring expected in the resulting error message).
BADREQUEST_CASES = [
    ("test_rank", {"r": 1024}, "is greater than max_lora_rank"),
    ("test_bias", {"bias": "all"},
     "Adapter bias cannot be used without bias_enabled"),
    ("test_dora", {"use_dora": True}, "does not yet support DoRA"),
    ("test_modules_to_save", {"modules_to_save": ["lm_head"]},
     "only supports modules_to_save being None"),
]


@pytest.fixture(scope="module")
def zephyr_lora_files():
Expand Down Expand Up @@ -138,32 +166,36 @@ async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
tmp_path, zephyr_lora_files):
invalid_rank = tmp_path / "invalid_rank"

# Copy adapter from zephyr_lora_files to invalid_rank
shutil.copytree(zephyr_lora_files, invalid_rank)

with open(invalid_rank / "adapter_config.json") as f:
@pytest.mark.parametrize("test_name,config_change,expected_error",
BADREQUEST_CASES)
async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI,
tmp_path: str, zephyr_lora_files: str,
test_name: str, config_change: dict,
expected_error: str):
# Create test directory
test_dir = os.path.join(tmp_path, test_name)

# Copy adapter files
shutil.copytree(zephyr_lora_files, test_dir)

# Load and modify configuration
config_path = os.path.join(test_dir, "adapter_config.json")
with open(config_path) as f:
adapter_config = json.load(f)
# Apply configuration changes
adapter_config.update(config_change)

print(adapter_config)

# assert False

# Change rank to invalid value
adapter_config["r"] = 1024
with open(invalid_rank / "adapter_config.json", "w") as f:
# Save modified configuration
with open(config_path, "w") as f:
json.dump(adapter_config, f)

with pytest.raises(openai.BadRequestError,
match="is greater than max_lora_rank"):
# Test loading the adapter
with pytest.raises(openai.BadRequestError, match=expected_error):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "invalid-json",
"lora_path": str(invalid_rank)
"lora_name": test_name,
"lora_path": str(test_dir)
})


Expand Down
16 changes: 16 additions & 0 deletions tests/lora/test_lora_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.utils import WeightsMapper

Expand Down Expand Up @@ -30,11 +31,14 @@ def test_load_checkpoints(
else:
expected_lora_modules.append(module)
if lora_name == "baichuan7B":
peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
max_position_embeddings=4096)
# For the baichuan7B model, load its LoRA,
# and the test should pass.
LoRAModel.from_local_checkpoint(
baichuan_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
Expand All @@ -43,19 +47,25 @@ def test_load_checkpoints(
# Test target_modules that contain a prefix,
# such as "model.layers.0.self_atten.W_pack";
# the test should pass.
peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files,
max_position_embeddings=4096)
LoRAModel.from_local_checkpoint(
baichuan_zero_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
elif lora_name == "baichuan7B-zero-regex":
# Test `target_modules` given in the form of regular expressions,
# such as `model\\..*(W_pack|o_proj)`; the test should pass.
peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files,
max_position_embeddings=4096)
LoRAModel.from_local_checkpoint(
baichuan_regex_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
Expand All @@ -64,10 +74,13 @@ def test_load_checkpoints(
# For the baichuan7B model, load chatglm3-6b's LoRA,
# and the test should raise the following error.
expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501
peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files,
max_position_embeddings=4096)
with pytest.raises(ValueError, match=expected_error):
LoRAModel.from_local_checkpoint(
chatglm3_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
Expand All @@ -94,9 +107,12 @@ def test_lora_weights_mapping(baichuan_lora_files):
".layers.": ".baichuan_layers.",
},
)
peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
max_position_embeddings=4096)
lora_model = LoRAModel.from_local_checkpoint(
baichuan_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
Expand Down
3 changes: 3 additions & 0 deletions tests/lora/test_lora_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM

Expand All @@ -27,9 +28,11 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_path = get_adapter_absolute_path(lora_name)

# LoRA loading should work for either an absolute path or a Hugging Face id.
peft_helper = PEFTHelper.from_local_dir(lora_path, 4096)
lora_model = LoRAModel.from_local_checkpoint(
lora_path,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
Expand Down
1 change: 1 addition & 0 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
is_engine_errored=False,
exception=e)
self._send_outputs(rpc_err)
return
# Otherwise, send back the successful load message
self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id))
Expand Down
24 changes: 8 additions & 16 deletions vllm/entrypoints/openai/serving_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,24 +157,16 @@ async def load_lora_adapter(
# This will also pre-load it for incoming requests
try:
await self.engine_client.add_lora(lora_request)
except ValueError as e:
# Adapter not found or lora configuration errors
if "No adapter found" in str(e):
return create_error_response(message=str(e),
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
else:
return create_error_response(
message=str(e),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST)
except BaseException as e:
# Some other unexpected problem loading the adapter, e.g. malformed
# input files.
# More detailed error messages for the user would be nicer here
error_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
if isinstance(e, ValueError) and "No adapter found" in str(e):
error_type = "NotFoundError"
status_code = HTTPStatus.NOT_FOUND

return create_error_response(message=str(e),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST)
err_type=error_type,
status_code=status_code)

self.lora_requests.append(lora_request)
logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
Expand Down
9 changes: 2 additions & 7 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
import json
import math
import os
import re
Expand Down Expand Up @@ -180,8 +179,8 @@ def from_local_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: List[str],
peft_helper: PEFTHelper,
*,
max_position_embeddings: Optional[int] = None,
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
Expand All @@ -207,18 +206,14 @@ def from_local_checkpoint(
Returns:
Loaded LoRA Model.
"""
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
# lora_config_path = os.path.join(lora_dir, "adapter_config.json")
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin")
with open(lora_config_path) as f:
config = json.load(f)

config["vllm_max_position_embeddings"] = max_position_embeddings
peft_helper = PEFTHelper.from_dict(config)
unexpected_modules: List[Union[list[str], str]]
if os.path.isfile(lora_tensor_path):
tensors: Dict[str, torch.Tensor] = {}
Expand Down
49 changes: 41 additions & 8 deletions vllm/lora/peft_helper.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py

import json
import math
import os
from dataclasses import MISSING, dataclass, field, fields
from typing import Literal, Optional, Union
from typing import List, Literal, Optional, Union

from vllm.config import LoRAConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclass
class PEFTHelper:
"""
A helper class for PEFT configurations, specifically designed for LoRA.
This class handles configuration validation and compatibility checks for
various LoRA implementations.
"""

# Required fields
r: int
lora_alpha: int
Expand All @@ -29,20 +38,18 @@ class PEFTHelper:
vllm_max_position_embeddings: Optional[int] = field(default=False)
vllm_long_context_scaling_factor: Optional[float] = field(default=None)

def _validate_features(self):
def _validate_features(self) -> List[str]:
"""
Check if there are any unsupported LoRA features.
"""
error_msg = []

if self.modules_to_save:
error_msg.append("vLLM only supports modules_to_save being None.")

if self.use_dora:
error_msg.append("vLLM does not yet support DoRA.")

if error_msg:
raise ValueError(f"{', '.join(error_msg)}")
return error_msg

def __post_init__(self):
self._validate_features()
if self.use_rslora:
logger.info_once("Loading LoRA weights trained with rsLoRA.")
self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
Expand Down Expand Up @@ -78,3 +85,29 @@ def from_dict(cls, config_dict: dict) -> "PEFTHelper":
for k, v in config_dict.items() if k in class_fields
}
return cls(**filtered_dict)

@classmethod
def from_local_dir(cls, lora_path: str,
                   max_position_embeddings: Optional[int]) -> "PEFTHelper":
    """
    Build a helper from the ``adapter_config.json`` inside ``lora_path``.

    Args:
        lora_path: Directory that contains ``adapter_config.json``.
        max_position_embeddings: Stored in the loaded config under the
            ``vllm_max_position_embeddings`` key before construction.

    Returns:
        The object produced by ``cls.from_dict`` for the loaded config.

    Raises:
        OSError: If the config file cannot be opened.
        json.JSONDecodeError: If the config file is not valid JSON.
    """
    lora_config_path = os.path.join(lora_path, "adapter_config.json")

    # Read as UTF-8 explicitly so parsing does not depend on the locale's
    # default text encoding.
    with open(lora_config_path, encoding="utf-8") as f:
        config = json.load(f)
    config["vllm_max_position_embeddings"] = max_position_embeddings
    return cls.from_dict(config)

def validate_legal(self, lora_config: LoRAConfig) -> None:
    """
    Validate this adapter configuration against the given LoRA settings.

    Starts from the messages returned by ``_validate_features()`` and adds
    one message for each of the rank and bias checks below that fails.

    Args:
        lora_config: Supplies ``max_lora_rank`` and ``bias_enabled``.

    Raises:
        ValueError: If any check fails. The message is every collected
            problem joined by single spaces.
    """
    error_msg = self._validate_features()
    if self.r > lora_config.max_lora_rank:
        error_msg.append(
            f"LoRA rank {self.r} is greater than max_lora_rank"
            f" {lora_config.max_lora_rank}.")
    if self.bias != "none" and not lora_config.bias_enabled:
        error_msg.append(
            "Adapter bias cannot be used without bias_enabled.")
    if error_msg:
        # Report every problem at once rather than stopping at the first.
        raise ValueError(" ".join(error_msg))
Loading
Loading