Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 86 additions & 53 deletions python/sglang/srt/lora/lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving"

import logging
from typing import Dict, Set, Tuple
from typing import Dict, Iterable, Optional, Set, Tuple

import torch

Expand Down Expand Up @@ -53,6 +53,8 @@ def __init__(
lora_backend: str = "triton",
tp_size: int = 1,
tp_rank: int = 0,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
):
self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config
Expand All @@ -62,6 +64,10 @@ def __init__(
self.device: torch.device = next(self.base_model.parameters()).device
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.max_lora_rank: Optional[int] = max_lora_rank
self.target_modules: Optional[Set[str]] = (
set(target_modules) if target_modules else None
)

# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
Expand Down Expand Up @@ -153,7 +159,9 @@ def load_lora_adapter(
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."

try:
self.configs[lora_name] = LoRAConfig(lora_path)
new_adapter = LoRAConfig(lora_path)
self.validate_new_adapter(lora_name, new_adapter)
self.configs[lora_name] = new_adapter
except Exception as e:
success = False
error_message = (
Expand All @@ -168,6 +176,21 @@ def load_lora_adapter(
error_message=error_message,
)

def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
    """
    Validate that a new adapter can be loaded into the current LoRA memory pool.

    Args:
        lora_name: Name the adapter is being registered under (used in the error).
        lora_config: Parsed LoRA config of the adapter being loaded.

    Raises:
        ValueError: If a memory pool has already been initialized and it cannot
            accommodate this adapter's rank / target modules.
    """

    # The memory pool is created lazily, so if it does not exist yet any
    # adapter shape is acceptable and no check is needed.
    incompatible = self.memory_pool and not self.memory_pool.can_support(
        lora_config
    )
    if incompatible:
        # NOTE: the trailing space on the first fragment is required — these
        # adjacent string literals are implicitly concatenated into one message.
        raise ValueError(
            f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
            "We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, "
            "you can specify expected configs via --max_lora_rank and --enable_lora_modules."
        )

def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
"""
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
Expand Down Expand Up @@ -214,7 +237,7 @@ def transfer_adapter_info(
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
if lora_path is not None:
lora = self.loras[lora_path]
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
lora_ranks[weight_indices[i]] = lora.config.r
scalings[weight_indices[i]] = lora.scaling

# Use pinned memory to avoid synchronizations during host-to-device transfer
Expand Down Expand Up @@ -319,7 +342,7 @@ def update_lora_info(self):
)
else:
weight_name = get_weight_name(
module_name, self.lora_weight_names, LoRAType.LORA_A
module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
)
module.set_lora_info(
self.memory_pool.get_tensor(
Expand Down Expand Up @@ -351,58 +374,66 @@ def init_state(self):
i: {} for i in range(self.base_hf_config.num_hidden_layers)
}

# Initialize memory pool
self.memory_pool = LoRAMemoryPool(
self.base_hf_config,
self.max_loras_per_batch,
self.dtype,
self.tp_size,
self.tp_rank,
)
# The LoRA memory pool that manages the GPU buffers for active LoRA weights.
# It is initialized lazily when the first LoRA adapter is loaded.
self.memory_pool: Optional[LoRAMemoryPool] = None

def update_state_from_configs(self):
"""
Update the internal state of the LoRAManager based on the current `self.configs`. This method
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).

This includes:
- Initializing LoRA adapters if they are not already loaded.
- Collect all LoRA weight names based on the current loaded adapters.
- Lazily monkey-patching the base model to use LoRA layers where applicable.
- Preparing the GPU buffer pool for active LoRA weights.
"""

# Target module names in huggingface lora configs.
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
hf_target_module_names: Set[str] = set()
for config in self.configs.values():
hf_target_module_names.update(config.target_modules)
max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])

# Loads / unloads LoRA adapters based on the latest configs.
self.update_lora_adapters()
# Apply the latest LoRA configurations to the internal state for inferencing.
self.apply_lora_configs()

def apply_lora_configs(self):
"""
Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.

# Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
#
# Please note that the following update operations are "monotonic" by design, meaning that we update
# multiple places to support the new weight names when the first adapter targeting such weight names
# is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
# even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
# list of LoRA weight names is expected to be extremely finite and stable.
self.update_lora_weight_names(hf_target_module_names)
self.update_lora_modules(hf_target_module_names)
self.update_memory_buffers(max_lora_dim)

def update_lora_weight_names(self, hf_target_names: Set[str]):
Notes:
- Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in
early CY25H2.
"""

if self.memory_pool is None:
# Infer max_lora_rank and target_modules if not explicitly specified in server args.
if self.target_modules is None:
self.target_modules = set()
for config in self.configs.values():
self.target_modules.update(config.target_modules)

if self.max_lora_rank is None:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()]
)

self.update_lora_weight_names()
self.update_lora_modules()
self.update_memory_buffers()
else:
# No-op if the memory pool can support the current LoRA configurations.
# TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
# module is changed once FlashInfer backend is deprecated.
assert self.memory_pool.can_support(self.configs.values()), (
"LoRA memory pool cannot support the current LoRA configuration. "
"This should never happen as we should have validated adapter compatibility. "
"Please create a Github issue to report.",
)

def update_lora_weight_names(self):
"""
Add new LoRA weight names if needed based on the current `self.configs`.
"""

# Target lora weight names for lora_a and lora_b modules respectively.
for module in hf_target_names:
lora_A, lora_B = get_normalized_lora_weight_names(module)
self.lora_weight_names[0].update(lora_A)
self.lora_weight_names[1].update(lora_B)
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
self.lora_weight_names[0].update(lora_A)
self.lora_weight_names[1].update(lora_B)

def update_lora_adapters(self):
"""
Expand Down Expand Up @@ -434,33 +465,35 @@ def update_lora_adapters(self):
# Additional checks for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
if self.lora_backend == "flashinfer":
lora_dims = set(x.hf_config["r"] for x in self.configs.values())
lora_dims = set(x.r for x in self.configs.values())
scalings = set(x.scaling for x in self.loras.values())
assert (
len(lora_dims) == 1 and len(scalings) == 1
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

def update_memory_buffers(self, max_lora_dim: int):
"""
Update the LoRA memory pool buffers based on the current LoRA configurations and update
LoRA modules to use the new buffers. This method should be called after the LoRA configurations
are set or updated.
"""

self.memory_pool.init_buffers(
self.lora_weight_names, self.base_model, max_lora_dim
def update_memory_buffers(self):
    """Rebuild the LoRA memory pool from the manager's current configuration."""
    # Gather every constructor argument first so the pool creation reads as a
    # single, explicit hand-off of the manager's state.
    pool_kwargs = dict(
        base_hf_config=self.base_hf_config,
        max_loras_per_batch=self.max_loras_per_batch,
        dtype=self.dtype,
        tp_size=self.tp_size,
        tp_rank=self.tp_rank,
        max_lora_rank=self.max_lora_rank,
        lora_weight_names=self.lora_weight_names,
        base_model=self.base_model,
    )
    self.memory_pool = LoRAMemoryPool(**pool_kwargs)

def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(module, self.lora_backend)
replace_submodule(self.base_model, module_name, lora_module)
return lora_module

def update_lora_modules(self, hf_target_names: Set[str]):
def update_lora_modules(self):
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
hf_target_names, self.base_model
self.target_modules, self.base_model
)

for module_name, module in self.base_model.named_modules():
Expand Down
65 changes: 47 additions & 18 deletions python/sglang/srt/lora/mem_pool.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from typing import Callable, Dict, List, Optional, Set, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch

from sglang.srt.distributed import divide
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.utils import (
ROW_PARALLELISM_LINEAR_LORA_NAMES,
LoRAType,
get_hidden_dim,
get_normalized_lora_weight_names,
get_stacked_multiply,
get_weight_name,
)
Expand All @@ -25,13 +27,20 @@ def __init__(
dtype: torch.dtype,
tp_size: int,
tp_rank: int,
max_lora_rank: int,
lora_weight_names: Tuple[Set[str], Set[str]],
base_model: torch.nn.Module,
):
self.base_hf_config: AutoConfig = base_hf_config
self.num_layer: int = base_hf_config.num_hidden_layers
self.max_loras_per_batch: int = max_loras_per_batch
self.dtype: torch.dtype = dtype
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.max_lora_rank: int = max_lora_rank

# lora weight names for LoRA A and B respectively.
self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names

# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
Expand All @@ -49,6 +58,33 @@ def __init__(
# Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

self.init_buffers(base_model)

def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
    """
    Check if the memory pool can support the given LoRA adapter(s).

    Accepts either a single config or an iterable of configs; for an iterable,
    every adapter must fit.
    """

    def _fits(adapter: LoRAConfig) -> bool:
        # An adapter fits when its rank is within the pool's maximum and all
        # of its normalized weight names are covered by the existing buffers.
        if adapter.r > self.max_lora_rank:
            return False
        names_a, names_b = get_normalized_lora_weight_names(
            adapter.target_modules
        )
        return names_a.issubset(self.lora_weight_names[0]) and names_b.issubset(
            self.lora_weight_names[1]
        )

    # Normalize the single-config case to a one-element list so both call
    # shapes share the same check below.
    adapters = [config] if isinstance(config, LoRAConfig) else config
    return all(_fits(adapter) for adapter in adapters)

def get_lora_A_shape(
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
) -> Tuple[int]:
Expand Down Expand Up @@ -82,25 +118,18 @@ def get_lora_B_shape(
max_lora_dim,
)

def init_buffers(
self,
lora_weight_names: Tuple[Set[str]],
base_model: torch.nn.Module,
max_lora_dim: int,
):
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
def init_buffers(self, base_model: torch.nn.Module):
device = next(base_model.parameters()).device

def update_buffer(
def init_buffer(
buffer: Dict[str, List[torch.Tensor]],
lora_weight_names: Set[str],
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
):
new_weight_names = lora_weight_names - buffer.keys()
for module_name in new_weight_names:
lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
for module_name in lora_weight_names:
lora_shape = get_lora_shape_fn(
module_name, base_model, self.max_lora_rank
)
buffer[module_name] = [
torch.empty(
lora_shape,
Expand All @@ -110,15 +139,15 @@ def update_buffer(
for _ in range(self.num_layer)
]

update_buffer(
init_buffer(
self.A_buffer,
lora_weight_names[0],
self.lora_weight_names[0],
self.get_lora_A_shape,
)

update_buffer(
init_buffer(
self.B_buffer,
lora_weight_names[1],
self.lora_weight_names[1],
self.get_lora_B_shape,
)

Expand Down
17 changes: 12 additions & 5 deletions python/sglang/srt/lora/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set, Tuple
from typing import Iterable, Optional, Set, Tuple

import torch

Expand Down Expand Up @@ -106,9 +106,11 @@ def get_hidden_dim(
raise NotImplementedError()


def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
def get_normalized_lora_weight_names(
target_modules: Iterable[str],
) -> Tuple[set[str], set[str]]:
"""
Mapping a target module name to names of the normalized LoRA weights.
Mapping a list of target module name to names of the normalized LoRA weights.
Returned tuple contains (name for Lora A, name for Lora B)
"""
params_mapping = {
Expand All @@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
}
stacked = params_mapping.get(name, ([name], [name]))
return stacked

result = (set(), set())
for name in target_modules:
lora_a, lora_b = params_mapping.get(name, ([name], [name]))
result[0].update(lora_a)
result[1].update(lora_b)
return result


def get_stacked_multiply(module_name: str) -> int:
Expand Down
Loading
Loading