Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.platforms import current_platform


Expand Down Expand Up @@ -98,9 +99,13 @@ def dist_init_torch_only():
backend=backend)


class DummyLoRAModel(nn.Sequential, SupportsLoRA):
    """Sequential container that also inherits from ``SupportsLoRA``.

    It adds no behaviour of its own; the dummy-model fixtures in this file
    instantiate it in place of a plain ``nn.Sequential``.
    """


@pytest.fixture
def dummy_model() -> nn.Module:
model = nn.Sequential(
model = DummyLoRAModel(
OrderedDict([
("dense1", ColumnParallelLinear(764, 100)),
("dense2", RowParallelLinear(100, 50)),
Expand All @@ -121,12 +126,13 @@ def dummy_model() -> nn.Module:
("sampler", Sampler())
]))
model.config = MagicMock()
model.embedding_modules = {"lm_head": "lm_head"}
return model


@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
model = nn.Sequential(
model = DummyLoRAModel(
OrderedDict([
("dense1", ColumnParallelLinear(764, 100)),
("dense2", RowParallelLinear(100, 50)),
Expand All @@ -147,6 +153,13 @@ def dummy_model_gate_up() -> nn.Module:
("sampler", Sampler())
]))
model.config = MagicMock()
model.packed_modules_mapping = {
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
model.embedding_modules = {"lm_head": "lm_head"}
return model


Expand Down
13 changes: 9 additions & 4 deletions tests/lora/test_lora_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
lora_lst = [
"baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
]
BAICHUAN_LORA_MODULES = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]


@pytest.mark.parametrize("lora_name", lora_lst)
Expand All @@ -22,12 +28,11 @@ def test_load_checkpoints(
baichuan_regex_lora_files,
chatglm3_lora_files,
):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
Expand Down Expand Up @@ -90,12 +95,12 @@ def test_load_checkpoints(


def test_lora_weights_mapping(baichuan_lora_files):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules

packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
Expand Down
7 changes: 5 additions & 2 deletions tests/lora/test_lora_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@

# Provide absolute path and huggingface lora ids
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
LLAMA_LORA_MODULES = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]


@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_name = request.getfixturevalue(lora_fixture_name)
supported_lora_modules = LlamaForCausalLM.supported_lora_modules
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
embedding_modules = LlamaForCausalLM.embedding_modules
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
for module in LLAMA_LORA_MODULES:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
Expand Down
26 changes: 8 additions & 18 deletions tests/lora/test_lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform

EMBEDDING_MODULES = {
Expand Down Expand Up @@ -114,28 +113,23 @@ def create_packed_lora(

def test_replace_submodules(dist_init, dummy_model):
model = dummy_model
model.supported_lora_modules = ["dense1", "layer1.dense2"]
model.packed_modules_mapping = {}
manager = LoRAModelManager(
model, 1, 1, 1,
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
torch.device(DEVICES[0]))
model = manager.model

assert isinstance(model.get_submodule("dense1"),
ColumnParallelLinearWithLoRA)
assert isinstance(model.get_submodule("layer1.dense1"),
ColumnParallelLinearWithLoRA)
assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
assert isinstance(model.get_submodule("layer1.dense2"),
RowParallelLinearWithLoRA)


@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
Expand Down Expand Up @@ -190,13 +184,18 @@ def test_lora_model_manager(dist_init, dummy_model, device):

assert manager.device == device
assert manager.punica_wrapper.device == device
assert hasattr(manager, "supported_lora_modules")
assert sorted(manager.supported_lora_modules) == [
"dense1",
"dense2",
"lm_head",
"output",
]


@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
Expand Down Expand Up @@ -289,8 +288,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
# This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
Expand Down Expand Up @@ -572,13 +569,6 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
model = dummy_model_gate_up
model.supported_lora_modules = ["gate_up_proj"]
model.packed_modules_mapping = {
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
model_lora = create_packed_lora(
1,
model,
Expand Down
21 changes: 11 additions & 10 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
get_supported_lora_modules,
is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
Expand Down Expand Up @@ -332,15 +333,15 @@ def __init__(
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}
super().__init__(model)
if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules)
if lora_config.long_lora_scaling_factors:
# We need to replace rotary emb layer to do batch computation
# for long lora.
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
self.supported_lora_modules = get_supported_lora_modules(self.model)
assert self.supported_lora_modules, "No supported LoRA modules found in"
f"{self.model.__class__.__name__}."
if lora_config.long_lora_scaling_factors:
# We need to replace rotary emb layer to do batch computation
# for long lora.
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
Expand Down Expand Up @@ -756,7 +757,7 @@ def create_lora_manager(
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not hasattr(model, "supported_lora_modules"):
if not hasattr(model, "packed_modules_mapping"):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls(
model=model,
Expand Down
26 changes: 26 additions & 0 deletions vllm/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
from vllm.model_executor.layers.linear import LinearBase
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
Expand Down Expand Up @@ -68,6 +69,14 @@ def from_layer(layer: nn.Module,
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret

# The Case for HFCompatibleLinear
if (hasattr(layer, "get_lora_class")
and layer.__class__.__name__ == "HFCompatibleLinear"):
lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras)
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer


Expand Down Expand Up @@ -170,6 +179,23 @@ def is_subset(sub_list, full_list):
return False


def get_supported_lora_modules(model: nn.Module) -> List[str]:
    """Collect the module names of ``model`` that can take LoRA weights.

    In vLLM, all linear layers support LoRA, so every submodule that is a
    ``LinearBase`` instance contributes the last component of its dotted
    name (e.g. ``"layers.0.qkv_proj"`` -> ``"qkv_proj"``). The keys of
    ``model.embedding_modules`` are added as well when that mapping is
    non-empty.

    Args:
        model: The model to inspect. It must expose an
            ``embedding_modules`` attribute.

    Returns:
        The unique supported module names, sorted so the result does not
        depend on set iteration order.
    """
    supported_lora_modules: Set[str] = set()
    # Step 1: traverse the model to collect the suffixes of all linear
    # layers.
    for name, module in model.named_modules():
        if isinstance(module, (LinearBase, )):
            supported_lora_modules.add(name.split(".")[-1])
    # Step 2: add the embedding modules if the model's embedding_modules
    # is not empty.
    if model.embedding_modules:
        supported_lora_modules.update(model.embedding_modules)
    # Sorting keeps the returned order stable across processes; plain
    # list(set) order varies with string hash randomisation.
    return sorted(supported_lora_modules)


def get_adapter_absolute_path(lora_path: str) -> str:
"""
Resolves the given lora_path to an absolute local path.
Expand Down
8 changes: 5 additions & 3 deletions vllm/lora/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ def create_lora_manager(

def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
try:
model = self._adapter_manager.model
supported_lora_modules = model.supported_lora_modules
packed_modules_mapping = model.packed_modules_mapping
supported_lora_modules = (
self._adapter_manager.supported_lora_modules)
packed_modules_mapping = (
self._adapter_manager.packed_modules_mapping)
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
Expand All @@ -107,6 +108,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:

# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
model = self._adapter_manager.model
hf_to_vllm_mapper = None
if (hasattr(model, "hf_to_vllm_mapper")
and model.hf_to_vllm_mapper is not None):
Expand Down
9 changes: 0 additions & 9 deletions vllm/model_executor/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,15 +342,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []

def __init__(
self,
Expand Down
6 changes: 0 additions & 6 deletions vllm/model_executor/models/bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
Expand Down
10 changes: 0 additions & 10 deletions vllm/model_executor/models/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,16 +477,6 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]

embedding_modules = {}
embedding_padding_modules = []

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
Expand Down
4 changes: 0 additions & 4 deletions vllm/model_executor/models/commandr.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
]
embedding_modules = {"embed_tokens": "input_embeddings"}
embedding_padding_modules = []

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
Expand Down
8 changes: 0 additions & 8 deletions vllm/model_executor/models/exaone.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,14 +415,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"out_proj",
"gate_up_proj",
"c_proj",
"wte",
"lm_head",
]
embedding_modules = {
"wte": "input_embeddings",
"lm_head": "output_embeddings",
Expand Down
Loading