300 changes: 131 additions & 169 deletions python/sglang/srt/lora/lora_manager.py

Large diffs are not rendered by default.

124 changes: 124 additions & 0 deletions python/sglang/srt/lora/lora_registry.py
@@ -0,0 +1,124 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import asyncio
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4


@dataclass(frozen=True, slots=True)
class LoRARef:
"""
Reference record for a LoRA model.

This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
keys (e.g., radix cache).
"""

lora_id: str = field(default_factory=lambda: uuid4().hex)
lora_name: Optional[str] = None
lora_path: Optional[str] = None

def __post_init__(self):
if self.lora_id is None:
raise ValueError("lora_id cannot be None")

def __str__(self) -> str:
parts = [
f"{f.name}={value}"
for f in fields(self)
if (value := getattr(self, f.name)) is not None
]
return f"{self.__class__.__name__}({', '.join(parts)})"


class LoRARegistry:
"""
The central registry to keep track of available LoRA adapters.

TODO (lifuhuang): This registry is intended as the foundation for overlapped LoRA updates. We decided
to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
"""

def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
assert lora_paths is None or all(
isinstance(lora, LoRARef) for lora in lora_paths.values()
), (
"server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
"Please file an issue if you see this error."
)

# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
self._registry: Dict[str, LoRARef] = dict(lora_paths or {})

async def register(self, lora_ref: LoRARef):
"""
Register a new LoRARef object in the registry.

Args:
lora_ref (LoRARef): The LoRARef object to register.
"""
if lora_ref.lora_name in self._registry:
raise ValueError(
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
)
self._registry[lora_ref.lora_name] = lora_ref

async def unregister(self, lora_name: str) -> str:
"""
Unregister a LoRARef object from the registry and return the removed LoRA ID.

Args:
lora_name (str): The name of the LoRA model to unregister.
"""
lora_ref = self._registry.get(lora_name, None)
if lora_ref is None:
raise ValueError(
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
)
del self._registry[lora_name]

return lora_ref.lora_id

async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
"""
Query the registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding
LoRA adapters by incrementing their counters.

TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
"""

async def _acquire_single(name: str) -> str:
lora_ref = self._registry.get(name, None)
if lora_ref is None:
raise ValueError(
f"The following requested LoRA adapters are not loaded: {name}\n"
f"Loaded adapters: {self._registry.keys()}."
)
# await self._counters[lora_ref.lora_id].increment()
return lora_ref.lora_id

if isinstance(lora_name, str):
lora_id = await _acquire_single(lora_name)
return lora_id
elif isinstance(lora_name, list):
lora_ids = await asyncio.gather(
*[_acquire_single(name) for name in lora_name]
)
return lora_ids
else:
raise TypeError("lora_name must be either a string or a list of strings.")
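
As a usage illustration, here is a minimal sketch of the registry API introduced above, assuming a plain asyncio entry point; the adapter name and path are hypothetical.

import asyncio

from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry


async def main():
    registry = LoRARegistry()

    # Register an adapter; LoRARef mints a unique lora_id automatically.
    ref = LoRARef(lora_name="my-adapter", lora_path="/models/my-adapter")
    await registry.register(ref)

    # acquire() resolves a user-facing name (or list of names) to stable IDs.
    lora_id = await registry.acquire("my-adapter")
    assert lora_id == ref.lora_id

    # unregister() removes the adapter and returns its ID.
    removed_id = await registry.unregister("my-adapter")
    assert removed_id == lora_id


asyncio.run(main())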
4 changes: 2 additions & 2 deletions python/sglang/srt/lora/mem_pool.py
@@ -153,7 +153,7 @@ def prepare_lora_batch(
self,
cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter],
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
):
def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch):
@@ -186,7 +186,7 @@ def load_lora_weight_to_buffer(
uid: str,
buffer_id: int,
lora_adapter: LoRAAdapter,
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
):
def load_lora_weight_tensor(
buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
22 changes: 20 additions & 2 deletions python/sglang/srt/managers/io_struct.py
@@ -22,6 +22,7 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
lora_name: str
# The path of loading.
lora_path: str
# The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None

def to_ref(self) -> LoRARef:
return LoRARef(
lora_id=self.lora_id,
lora_name=self.lora_name,
lora_path=self.lora_path,
)


@dataclass
class UnloadLoRAAdapterReqInput:
# The name of the LoRA module to unload.
lora_name: str
lora_name: Optional[str] = None
# The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None

def to_ref(self) -> LoRARef:
return LoRARef(
lora_id=self.lora_id,
lora_name=self.lora_name,
)


@dataclass
class LoRAUpdateResult:
success: bool
error_message: Optional[str] = None
loaded_adapters: Dict[str, str] = field(default_factory=dict)
loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
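
To make the new ID plumbing concrete, the following sketch shows how a load request converts into a LoRARef via to_ref(); the name and path are hypothetical, and note that in the real flow the TokenizerManager (not the caller) assigns lora_id before the request is broadcast.

from uuid import uuid4

from sglang.srt.managers.io_struct import LoadLoRAAdapterReqInput

req = LoadLoRAAdapterReqInput(
    lora_name="my-adapter",          # hypothetical adapter name
    lora_path="/models/my-adapter",  # hypothetical local path
)

# LoRARef.__post_init__ rejects a None lora_id, so the ID must be stamped
# onto the request first; the TokenizerManager does this automatically.
req.lora_id = uuid4().hex

ref = req.to_ref()
print(ref)  # LoRARef(lora_id=..., lora_name=my-adapter, lora_path=/models/my-adapter)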
20 changes: 3 additions & 17 deletions python/sglang/srt/managers/scheduler.py
@@ -247,7 +247,7 @@ def __init__(
self.pp_size = server_args.pp_size
self.dp_size = server_args.dp_size
self.schedule_policy = server_args.schedule_policy
self.lora_paths = server_args.lora_paths
self.enable_lora = server_args.enable_lora
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
@@ -1706,13 +1706,13 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.chunked_req.init_next_round_input()
self.chunked_req = adder.add_chunked_req(self.chunked_req)

if self.lora_paths:
if self.enable_lora:
lora_set = set([req.lora_path for req in self.running_batch.reqs])

# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
self.lora_paths
self.enable_lora
and len(
lora_set
| set([req.lora_path for req in adder.can_run_list])
@@ -2466,12 +2466,6 @@ def load_lora_adapter(
"""In-place loading a new lora adapter from disk or huggingface."""

result = self.tp_worker.load_lora_adapter(recv_req)

if result.success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after loading lora adapter."
else:
logger.error(result.error_message)
return result

def unload_lora_adapter(
@@ -2480,14 +2474,6 @@
"""Unload the lora adapter."""

result = self.tp_worker.unload_lora_adapter(recv_req)

if result.success:
flush_cache_success = self.flush_cache()
assert (
flush_cache_success
), "Cache flush failed after unloading LoRA weights"
else:
logger.error(result.error_message)
return result

def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
47 changes: 24 additions & 23 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -62,6 +62,7 @@
get_tokenizer,
get_tokenizer_from_processor,
)
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
@@ -244,9 +245,7 @@ def __init__(

# Initialize the LoRA registry with the initial lora paths in the server_args.
# The registry will be updated when new LoRA adapters are loaded or unloaded dynamically.
self.loaded_lora_adapters: Dict[str, str] = dict(
self.server_args.lora_paths or {}
)
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})

# Store states
self.no_create_loop = False
@@ -523,6 +522,10 @@ async def _tokenize_one_request(
else:
mm_inputs = None

if self.server_args.enable_lora and obj.lora_path:
# Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)

self._validate_one_request(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -574,8 +577,6 @@ def _validate_one_request(
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
if self.server_args.enable_lora and obj.lora_path:
self._validate_lora_adapters(obj)

def _validate_input_ids_in_vocab(
self, input_ids: List[int], vocab_size: int
@@ -689,21 +690,6 @@ def _validate_batch_tokenization_constraints(
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
)

def _validate_lora_adapters(self, obj: GenerateReqInput):
"""Validate that the requested LoRA adapters are loaded."""
requested_adapters = (
set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
)
loaded_adapters = (
self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
)
unloaded_adapters = requested_adapters - loaded_adapters
if unloaded_adapters:
raise ValueError(
f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
f"Loaded adapters: {loaded_adapters}."
)

def _send_one_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1054,8 +1040,18 @@ async def load_lora_adapter(
)

async with self.model_update_lock.writer_lock:
# Generate a new, uniquely identifiable LoRARef object.
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
)

# Register the new adapter in the registry.
obj.lora_id = new_adapter.lora_id
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
if result.success:
await self.lora_registry.register(new_adapter)

return result

async def unload_lora_adapter(
@@ -1069,6 +1065,10 @@
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)

assert (
obj.lora_name is not None
), "lora_name must be provided to unload LoRA adapter"

# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
@@ -1080,8 +1080,9 @@
)

async with self.model_update_lock.writer_lock:
obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters

return result

async def get_weights_by_name(
Expand Down Expand Up @@ -1309,7 +1310,7 @@ def dump_requests_before_crash(self):
filename = os.path.join(
self.crash_dump_folder,
os.getenv("HOSTNAME", None),
f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
)

os.makedirs(os.path.dirname(filename), exist_ok=True)
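
In outline, the reworked load path above boils down to the condensed sketch below: the unique ID is minted and stamped onto the request before it is broadcast, and the adapter becomes visible to new requests only after the workers report success. This is a simplification; helper names like broadcast_to_workers are stand-ins, not the actual communicator API.

from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry


async def load_adapter(registry: LoRARegistry, obj, broadcast_to_workers):
    # Mint the unique ID up front and stamp it onto the request so every
    # TP worker materializes an identical LoRARef.
    new_adapter = LoRARef(lora_name=obj.lora_name, lora_path=obj.lora_path)
    obj.lora_id = new_adapter.lora_id

    result = await broadcast_to_workers(obj)
    if result.success:
        # Register only after the weights are actually loaded, so a request
        # can never acquire an adapter that is not yet usable.
        await registry.register(new_adapter)
    return result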
6 changes: 2 additions & 4 deletions python/sglang/srt/managers/tp_worker.py
@@ -293,11 +293,9 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
return parameter

def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(
recv_req.lora_name, recv_req.lora_path
)
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
return result

def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result