21 changes: 15 additions & 6 deletions tests/lora/test_chatglm3_tp.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import vllm
import vllm.config
from vllm.lora.request import LoRARequest

from ..utils import create_new_process_for_each_test, multi_gpu_test
@@ -53,9 +54,10 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
max_model_len=512,
enable_lora=True,
max_loras=4,
max_loras=2,
max_num_seqs=16,
max_lora_rank=64,
trust_remote_code=True,
)
@@ -72,13 +74,17 @@ def test_chatglm3_lora(chatglm3_lora_files):
def test_chatglm3_lora_tp4(chatglm3_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
max_model_len=512,
enable_lora=True,
max_loras=4,
max_loras=2,
max_lora_rank=64,
max_num_seqs=16,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=False,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)

output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
@@ -96,14 +102,17 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
# more GPU memory causing vLLM to OOM
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
max_model_len=512,
enable_lora=True,
max_loras=4,
max_loras=2,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=True,
gpu_memory_utilization=0.85,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
9 changes: 8 additions & 1 deletion tests/lora/test_llama_tp.py
@@ -3,7 +3,10 @@
import subprocess
import sys

import pytest

import vllm
import vllm.config
from vllm import LLM
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -100,14 +103,18 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =


@create_new_process_for_each_test()
def test_llama_lora(sql_lora_files):
@pytest.mark.parametrize("cudagraph_specialize_lora", [True, False])
def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
llm = vllm.LLM(
MODEL_PATH,
tokenizer=sql_lora_files,
enable_lora=True,
# also test odd max_num_seqs
max_num_seqs=13,
max_loras=4,
compilation_config=vllm.config.CompilationConfig(
cudagraph_specialize_lora=cudagraph_specialize_lora,
),
)
generate_and_test(llm, sql_lora_files)

8 changes: 8 additions & 0 deletions vllm/config/compilation.py
@@ -366,6 +366,14 @@ class CompilationConfig:
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
FULL_AND_PIECEWISE instead.
"""
cudagraph_specialize_lora: bool = True
"""Whether to create separate cuda graphs for cases with and without active
LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
for all cases, incurring the overhead of running LoRA ops even when no
adapters are active. Setting this to True will remove this overhead at the
cost of increased startup time and slightly higher memory usage.
When `enable_lora` is False, this option has no effect.
"""

use_inductor_graph_partition: bool = False
"""Use inductor graph partition to split the graph at cudagraph_unsafe ops.
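For context, a minimal usage sketch of the new flag, mirroring the test changes above. This is an illustration, not part of the patch; the model path is a placeholder and a CUDA GPU with LoRA support is assumed:

```python
import vllm
import vllm.config

# Hypothetical model path for illustration; any LoRA-capable model works.
llm = vllm.LLM(
    "THUDM/chatglm3-6b",
    enable_lora=True,
    max_loras=2,
    max_lora_rank=64,
    trust_remote_code=True,
    # Opt out of LoRA-specialized cudagraphs: reuse the LoRA-enabled graph
    # for every batch, trading a small per-step overhead for lower capture
    # time and memory (useful to avoid OOM, as in the TP=4 tests above).
    compilation_config=vllm.config.CompilationConfig(
        cudagraph_specialize_lora=False,
    ),
)
```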
8 changes: 7 additions & 1 deletion vllm/forward_context.py
@@ -40,13 +40,19 @@ class BatchDescriptor(NamedTuple):
False can also be used for a uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
"""
has_lora: bool = False
"""
Whether this batch has active LoRA adapters.
"""

@property
def non_uniform(self) -> "BatchDescriptor":
"""
Return a non-uniform version of the current batch descriptor.
"""
return BatchDescriptor(self.num_tokens, uniform_decode=False)
return BatchDescriptor(
self.num_tokens, uniform_decode=False, has_lora=self.has_lora
)


def _compute_sp_num_tokens(
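A rough illustration of the new field, written against the post-change `BatchDescriptor` (a standalone sketch, not part of the diff):

```python
from vllm.forward_context import BatchDescriptor

# After this change, a batch descriptor also records whether LoRA adapters
# are active, so LoRA and non-LoRA batches can map to different cudagraphs.
desc = BatchDescriptor(num_tokens=8, uniform_decode=True, has_lora=True)

# non_uniform drops the uniform-decode property but preserves has_lora.
assert desc.non_uniform == BatchDescriptor(
    num_tokens=8, uniform_decode=False, has_lora=True
)
```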
2 changes: 2 additions & 0 deletions vllm/lora/ops/triton_ops/lora_shrink_op.py
@@ -169,6 +169,8 @@ def _lora_shrink(
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

output_tensor.zero_()

(lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = (
_get_lora_a_ptr(lora_a_weights, inputs.device)
)
34 changes: 21 additions & 13 deletions vllm/lora/punica_wrapper/punica_gpu.py
@@ -205,15 +205,18 @@ def add_lora_linear(

assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)

if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros( # type: ignore
(len(output_slices), x.size(0), r),
dtype=torch.float32,
device=x.device,
)
assert buffer is None, (
"To minimize overhead, the buffer should be created by "
".add_lora_linear() instead of being passed in."
)
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
# Note: buffer is zeroed inside the shrink op
buffer = torch.empty(
(len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
)

self.add_shrink(
buffer, # type: ignore
x,
@@ -260,10 +263,15 @@ def add_lora_logits(
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)

assert buffer is None, (
"To minimize overhead, the buffer should be created by "
".add_lora_linear() instead of being passed in."
)
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
# Note: buffer is zeroed inside the shrink op
buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device)

lora_shrink(
x,
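The allocation change in isolation (a sketch with assumed shapes; the real buffers live on the GPU and are sized from the stacked LoRA weights):

```python
import torch

num_slices, num_tokens, rank = 3, 16, 64

# Before: callers pre-zeroed the float32 shrink buffer themselves.
old_buffer = torch.zeros((num_slices, num_tokens, rank), dtype=torch.float32)

# After: an uninitialized buffer is enough, because _lora_shrink now calls
# output_tensor.zero_() before accumulating partial results into it, and
# passing a caller-provided buffer is rejected with an assertion.
new_buffer = torch.empty((num_slices, num_tokens, rank), dtype=torch.float32)
```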
24 changes: 20 additions & 4 deletions vllm/v1/cudagraph_dispatcher.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import product

from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
@@ -67,14 +68,27 @@ def initialize_cudagraph_keys(
):
# This should be called only after attention backend is initialized.

# LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]

# Note: we create all valid cudagraph keys here but do not guarantee
# that all of them will be used. For example, if we allow lazy
# capturing in a future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs in self.compilation_config.cudagraph_capture_sizes:
for bs, has_lora in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
BatchDescriptor(num_tokens=bs, uniform_decode=False),
BatchDescriptor(
num_tokens=bs, uniform_decode=False, has_lora=has_lora
),
)

# if decode cudagraph mode is FULL, and we don't already have mixed
@@ -92,10 +106,12 @@ def initialize_cudagraph_keys(
for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len
]
for bs in cudagraph_capture_sizes_for_decode:
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
BatchDescriptor(num_tokens=bs, uniform_decode=True),
BatchDescriptor(
num_tokens=bs, uniform_decode=True, has_lora=has_lora
),
)
self.keys_initialized = True

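Roughly, the mixed-batch key space now looks like this (a sketch with assumed capture sizes; the real dispatcher also registers uniform-decode keys when FULL decode graphs are enabled):

```python
from itertools import product

from vllm.forward_context import BatchDescriptor

capture_sizes = [1, 2, 4, 8]  # assumed cudagraph_capture_sizes
lora_cases = [True, False]    # LoRA configured and specialization enabled

# One mixed-batch key per (batch size, LoRA activation) combination.
keys = {
    BatchDescriptor(num_tokens=bs, uniform_decode=False, has_lora=has_lora)
    for bs, has_lora in product(capture_sizes, lora_cases)
}
assert len(keys) == len(capture_sizes) * len(lora_cases)
```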
33 changes: 27 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
@@ -8,6 +8,7 @@
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast

import numpy as np
@@ -2469,7 +2470,9 @@ def execute_model(
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
)
batch_descriptor = BatchDescriptor(
num_tokens=num_input_tokens, uniform_decode=uniform_decode
num_tokens=num_input_tokens,
uniform_decode=uniform_decode,
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
)
cudagraph_runtime_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)
@@ -3193,6 +3196,7 @@ def _dummy_run(
is_profile: bool = False,
create_mixed_batch: bool = False,
remove_lora: bool = True,
activate_lora: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
@@ -3215,6 +3219,7 @@
create_mixed_batch: If True, create a mixed batch with both decode
(1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run
activate_lora: If False, the dummy run is performed without active LoRA adapters.
"""
assert (
cudagraph_runtime_mode is None
@@ -3364,7 +3369,7 @@ def _dummy_run(
attn_metadata[layer_name] = attn_metadata_i

with self.maybe_dummy_run_with_lora(
self.lora_config, num_scheduled_tokens, remove_lora
self.lora_config, num_scheduled_tokens, activate_lora, remove_lora
):
# Make sure padding doesn't exceed max_num_tokens
assert num_tokens_after_padding <= self.max_num_tokens
@@ -3411,6 +3416,7 @@ def _dummy_run(
BatchDescriptor(
num_tokens=num_tokens_after_padding,
uniform_decode=uniform_decode,
has_lora=activate_lora and self.lora_config is not None,
)
)
if not is_profile
@@ -3769,10 +3775,21 @@ def freeze_gc():
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None

if self.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]

if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()

compilation_cases = list(reversed(self.cudagraph_batch_sizes))
compilation_cases = list(
product(reversed(self.cudagraph_batch_sizes), lora_cases)
)
self._capture_cudagraphs(
compilation_cases,
cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -3793,7 +3810,9 @@ def freeze_gc():
for x in self.cudagraph_batch_sizes
if max_num_tokens >= x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
compilation_cases_decode = list(
product(reversed(decode_cudagraph_batch_sizes), lora_cases)
)
self._capture_cudagraphs(
compilation_cases=compilation_cases_decode,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
@@ -3823,7 +3842,7 @@ def freeze_gc():

def _capture_cudagraphs(
self,
compilation_cases: list[int],
compilation_cases: list[tuple[int, bool]],
cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool,
):
@@ -3844,7 +3863,7 @@ def _capture_cudagraphs(
)

# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for num_tokens, activate_lora in compilation_cases:
# We currently only capture ubatched graphs when it's a FULL
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
@@ -3875,6 +3894,7 @@ def _capture_cudagraphs(
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
)
self._dummy_run(
num_tokens,
@@ -3883,6 +3903,7 @@
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
)
self.maybe_remove_all_loras(self.lora_config)

17 changes: 14 additions & 3 deletions vllm/v1/worker/lora_model_runner_mixin.py
@@ -120,7 +120,10 @@ def maybe_setup_dummy_loras(

@contextmanager
def maybe_select_dummy_loras(
self, lora_config: LoRAConfig | None, num_scheduled_tokens: np.ndarray
self,
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
activate_lora: bool = True,
):
if lora_config is None:
yield
@@ -133,7 +136,12 @@ def maybe_select_dummy_loras(

# Make prompt lora mapping
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
if activate_lora:
prompt_lora_mapping = (
np.arange(num_reqs, dtype=np.int32) % num_loras
) + 1
else:
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)

# Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
@@ -159,11 +167,14 @@ def maybe_dummy_run_with_lora(
self,
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
activate_lora: bool = True,
remove_lora: bool = True,
):
with (
self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens),
self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens, activate_lora
),
):
yield

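For illustration, the effect of `activate_lora` on the dummy prompt mapping (a standalone sketch with assumed request and adapter counts):

```python
import numpy as np

num_reqs, num_loras = 5, 2

# activate_lora=True: cycle adapter IDs 1..num_loras across the dummy
# requests to simulate a worst-case multi-LoRA batch.
active_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
# -> [1, 2, 1, 2, 1]

# activate_lora=False: ID 0 means "no adapter", so the captured graph
# corresponds to a batch with LoRA configured but no adapters active.
inactive_mapping = np.zeros(num_reqs, dtype=np.int32)
# -> [0, 0, 0, 0, 0]
```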