Skip to content
5 changes: 5 additions & 0 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ def sql_lora_files():
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


@pytest.fixture(scope="session")
def mixtral_lora_files():
return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
Expand Down
41 changes: 29 additions & 12 deletions tests/lora/test_lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,35 @@
RowParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager,
from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, LoRAMapping)
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear

EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}

EMBEDDING_PADDING_MODULES = ["lm_head"]


def test_from_lora_tensors(sql_lora_files):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
new_embeddings = load_file(
os.path.join(sql_lora_files, "new_embeddings.safetensors"))
lora_model = LoRAModel.from_lora_tensors(1,
8,
16,
tensors,
"cuda",
embeddings=new_embeddings)
lora_model = LoRAModel.from_lora_tensors(
1,
8,
16,
tensors,
"cuda",
embeddings=new_embeddings,
embedding_modules=EMBEDDING_MODULES,
embedding_padding_modules=EMBEDDING_PADDING_MODULES)
for module_name, lora in lora_model.loras.items():
assert lora.module_name == module_name
assert lora.rank == 8
Expand Down Expand Up @@ -289,8 +299,9 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_lora_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
torch.device("cuda"))
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)

mapping = LoRAMapping([], [])
Expand Down Expand Up @@ -362,8 +373,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_lora_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
torch.device("cuda"))
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)

mapping = LoRAMapping([], [])
Expand Down Expand Up @@ -444,7 +456,12 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
manager = LoRAModelManager(
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
["gate_up_proj"])
["gate_up_proj"], {
"gate_up_proj": [
"gate_proj",
"up_proj",
],
})
model = manager.model

assert isinstance(model.get_submodule("gate_up_proj"),
Expand Down
53 changes: 53 additions & 0 deletions tests/lora/test_mixtral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
import torch

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"


def do_sample(llm, lora_path: str, lora_id: int):
prompts = [
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]",
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


@pytest.mark.parametrize("tp_size", [4])
def test_mixtral_lora(mixtral_lora_files, tp_size):
if torch.cuda.device_count() < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=tp_size,
worker_use_ray=True)

expected_lora_output = [
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
"inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])",
]

assert do_sample(llm, mixtral_lora_files,
lora_id=1) == expected_lora_output
assert do_sample(llm, mixtral_lora_files,
lora_id=2) == expected_lora_output
100 changes: 48 additions & 52 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,6 @@

logger = logging.getLogger(__name__)

# TODO: The mappings below should be moved to individual model classes.

PACKED_MODULES_CFG = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

TARGET_MODULES_QKV = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]

EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}

EMBEDDING_PADDING_MODULES = ["lm_head"]

_GLOBAL_LORA_ID = 0


Expand Down Expand Up @@ -169,6 +139,8 @@ def from_lora_tensors(
dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and not in_wsl()
Expand All @@ -179,11 +151,11 @@ def from_lora_tensors(
lora_embeddings_tensor = None
if embeddings:
embeddings_module = next(
(k for k in EMBEDDING_MODULES if k in module_name),
(k for k in embedding_modules if k in module_name),
None)
if embeddings_module:
lora_embeddings_tensor = embeddings[
EMBEDDING_MODULES[embeddings_module]].to(
embedding_modules[embeddings_module]].to(
device=device, dtype=dtype)
if pin_memory:
lora_embeddings_tensor = (
Expand All @@ -201,7 +173,7 @@ def from_lora_tensors(
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
if any(name in module_name
for name in EMBEDDING_PADDING_MODULES
for name in embedding_padding_modules
) and target_embedding_padding is not None:
lora_b = loras[module_name].lora_b
assert target_embedding_padding >= lora_b.shape[1]
Expand All @@ -218,12 +190,15 @@ def from_lora_tensors(

@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
target_embedding_padding: Optional[int] = None) -> "LoRAModel":
cls,
lora_dir: str,
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint."""
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
Expand Down Expand Up @@ -260,6 +235,8 @@ def from_local_checkpoint(
dtype=dtype,
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules,
)


Expand All @@ -273,8 +250,8 @@ def __init__(
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
lora_target_modules: Optional[Union[str, List[str]]] = None,
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
):
"""Create a LoRAModelManager and adapter for a given model.

Expand Down Expand Up @@ -320,11 +297,22 @@ def __init__(
self.indices_len = [None] * 4

self.model: nn.Module = model
self.lora_target_modules: List[str] = ([
lora_target_modules
] if isinstance(lora_target_modules, str) else lora_target_modules)
self.lora_target_modules = copy.deepcopy(lora_target_modules)
self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping)
# allow overriding the target modules and mapping with initialization
if lora_target_modules:
self.lora_target_modules: List[str] = ([
lora_target_modules
] if isinstance(lora_target_modules, str) else lora_target_modules)
self.packed_modules_mapping = copy.deepcopy(
packed_modules_mapping) if packed_modules_mapping else {}
elif hasattr(self.model, "supports_lora") and self.model.supports_lora:
assert hasattr(self.model, "lora_target_modules") and hasattr(
self.model, "packed_modules_mapping"), (
f"{self.model} model must have lora_target_modules and "
"packed_modules_mapping to support lora")
self.lora_target_modules = copy.deepcopy(
self.model.lora_target_modules)
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
self._registered_loras: Dict[int, LoRAModel] = {}
Expand Down Expand Up @@ -468,7 +456,11 @@ def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA)
self.modules[module_name] = module

def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel:
def create_dummy_lora(
self,
lora_id: int,
rank: int,
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {})
for module_name, module in self.model.named_modules():
Expand All @@ -477,7 +469,7 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel:
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
if parts[-1] in EMBEDDING_MODULES:
if parts[-1] in embedding_modules:
input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if
hasattr(module.base_layer, "org_vocab_size")
Expand Down Expand Up @@ -586,8 +578,8 @@ def __init__(
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
lora_target_modules: Optional[Union[str, List[str]]] = None,
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
):
super().__init__(model, max_num_seqs, max_num_batched_tokens,
vocab_size, lora_config, lora_target_modules,
Expand Down Expand Up @@ -637,7 +629,8 @@ def create_lora_manager(
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
target_modules: Optional[Union[str, List[str]]] = None,
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
Expand All @@ -649,6 +642,9 @@ def create_lora_manager(
max_num_batched_tokens=max_num_batched_tokens,
vocab_size=vocab_size,
lora_config=lora_config,
lora_target_modules=target_modules,
lora_target_modules=target_modules
if target_modules else model.lora_target_modules,
packed_modules_mapping=packed_modules_mapping
if packed_modules_mapping else model.packed_modules_mapping,
**kwargs)
return lora_manager
Loading