diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index f4f151180dec..c43de9d45afe 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import vllm +import vllm.config from vllm.lora.request import LoRARequest from ..utils import create_new_process_for_each_test, multi_gpu_test @@ -53,9 +54,10 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: def test_chatglm3_lora(chatglm3_lora_files): llm = vllm.LLM( MODEL_PATH, - max_model_len=1024, + max_model_len=512, enable_lora=True, - max_loras=4, + max_loras=2, + max_num_seqs=16, max_lora_rank=64, trust_remote_code=True, ) @@ -72,13 +74,17 @@ def test_chatglm3_lora(chatglm3_lora_files): def test_chatglm3_lora_tp4(chatglm3_lora_files): llm = vllm.LLM( MODEL_PATH, - max_model_len=1024, + max_model_len=512, enable_lora=True, - max_loras=4, + max_loras=2, max_lora_rank=64, + max_num_seqs=16, tensor_parallel_size=4, trust_remote_code=True, fully_sharded_loras=False, + compilation_config=vllm.config.CompilationConfig( # Avoid OOM + cudagraph_specialize_lora=False, + ), ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) @@ -96,14 +102,17 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): # more GPU memory causing vLLM to OOM llm = vllm.LLM( MODEL_PATH, - max_model_len=1024, + max_model_len=512, enable_lora=True, - max_loras=4, + max_loras=2, max_lora_rank=64, tensor_parallel_size=4, trust_remote_code=True, fully_sharded_loras=True, gpu_memory_utilization=0.85, + compilation_config=vllm.config.CompilationConfig( # Avoid OOM + cudagraph_specialize_lora=False, + ), ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index e1d6a8674a01..7bbd1e364d19 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -3,7 +3,10 @@ import subprocess import sys +import pytest + import vllm +import vllm.config from vllm import LLM from vllm.lora.request import LoRARequest from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -100,7 +103,8 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None = @create_new_process_for_each_test() -def test_llama_lora(sql_lora_files): +@pytest.mark.parametrize("cudagraph_specialize_lora", [True, False]) +def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool): llm = vllm.LLM( MODEL_PATH, tokenizer=sql_lora_files, @@ -108,6 +112,9 @@ def test_llama_lora(sql_lora_files): # also test odd max_num_seqs max_num_seqs=13, max_loras=4, + compilation_config=vllm.config.CompilationConfig( + cudagraph_specialize_lora=cudagraph_specialize_lora, + ), ) generate_and_test(llm, sql_lora_files) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 2aaf4ba51f4a..61e73414335a 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -366,6 +366,14 @@ class CompilationConfig: minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode= FULL_AND_PIECEWISE instead. """ + cudagraph_specialize_lora: bool = True + """Whether to create separate cuda graphs for cases with and without active + LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used + for all cases, incurring the overhead of running LoRA ops even when no + adapters are active. 
Setting this to True will remove this overhead at the + cost of increased startup time and slightly higher memory usage. + When `enable_lora` is False, this option has no effect. + """ use_inductor_graph_partition: bool = False """Use inductor graph partition to split the graph at cudagraph_unsafe ops. diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 484de15040c2..ef37cf862c9f 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -40,13 +40,19 @@ class BatchDescriptor(NamedTuple): False can also be used for an uniform decode batch to dispatch to the cudagraph supporting non-uniform batches. """ + has_lora: bool = False + """ + Whether this batch has active LoRA adapters. + """ @property def non_uniform(self) -> "BatchDescriptor": """ Return a non-uniform version of current batch descriptor. """ - return BatchDescriptor(self.num_tokens, uniform_decode=False) + return BatchDescriptor( + self.num_tokens, uniform_decode=False, has_lora=self.has_lora + ) def _compute_sp_num_tokens( diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 8c58915e3f79..8d126197f83e 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -169,6 +169,8 @@ def _lora_shrink( assert lora_ids.size(0) == num_tokens_per_lora.size(0) assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + output_tensor.zero_() + (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = ( _get_lora_a_ptr(lora_a_weights, inputs.device) ) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 44a5443c3065..cdb0e6708290 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -205,15 +205,18 @@ def add_lora_linear( assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if buffer is None: - r = lora_b_stacked[0].size(-1) - # We set the buffer to be float32 by default, refer to: - # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros( # type: ignore - (len(output_slices), x.size(0), r), - dtype=torch.float32, - device=x.device, - ) + assert buffer is None, ( + "To minimize overhead, the buffer should be created by " + ".add_lora_linear() instead of being passed in." + ) + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, refer to: + # https://github.com/triton-lang/triton/issues/1387 + # Note: buffer is zeroed inside the shrink op + buffer = torch.empty( + (len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device + ) + self.add_shrink( buffer, # type: ignore x, @@ -260,10 +263,15 @@ def add_lora_logits( y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) r = lora_b_stacked.size(-1) - if buffer is None: - # We set the buffer to be float32 by default, refer to: - # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) + + assert buffer is None, ( + "To minimize overhead, the buffer should be created by " + ".add_lora_linear() instead of being passed in." 
+ ) + # We set the buffer to be float32 by default, refer to: + # https://github.com/triton-lang/triton/issues/1387 + # Note: buffer is zeroed inside the shrink op + buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device) lora_shrink( x, diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index a12704b664c3..b480ac78f23c 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from itertools import product from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor @@ -67,14 +68,27 @@ def initialize_cudagraph_keys( ): # This should be called only after attention backend is initialized. + # LoRA activation cases to specialize the cuda graphs on + if self.vllm_config.lora_config: + if self.compilation_config.cudagraph_specialize_lora: + lora_cases = [True, False] + else: + lora_cases = [True] + else: + lora_cases = [False] + # Note: we create all valid keys for cudagraph here but do not # guarantee all keys would be used. For example, if we allow lazy # capturing in future PR, some keys may never be triggered. if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - for bs in self.compilation_config.cudagraph_capture_sizes: + for bs, has_lora in product( + self.compilation_config.cudagraph_capture_sizes, lora_cases + ): self.add_cudagraph_key( cudagraph_mode.mixed_mode(), - BatchDescriptor(num_tokens=bs, uniform_decode=False), + BatchDescriptor( + num_tokens=bs, uniform_decode=False, has_lora=has_lora + ), ) # if decode cudagraph mode is FULL, and we don't already have mixed @@ -92,10 +106,12 @@ def initialize_cudagraph_keys( for x in self.compilation_config.cudagraph_capture_sizes if x <= max_num_tokens and x >= uniform_decode_query_len ] - for bs in cudagraph_capture_sizes_for_decode: + for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases): self.add_cudagraph_key( CUDAGraphMode.FULL, - BatchDescriptor(num_tokens=bs, uniform_decode=True), + BatchDescriptor( + num_tokens=bs, uniform_decode=True, has_lora=has_lora + ), ) self.keys_initialized = True diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9a37803a7fc6..5603b05e9918 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,6 +8,7 @@ from collections.abc import Iterator from contextlib import contextmanager from copy import deepcopy +from itertools import product from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast import numpy as np @@ -2469,7 +2470,9 @@ def execute_model( num_scheduled_tokens == self.input_batch.num_reqs * max_query_len ) batch_descriptor = BatchDescriptor( - num_tokens=num_input_tokens, uniform_decode=uniform_decode + num_tokens=num_input_tokens, + uniform_decode=uniform_decode, + has_lora=len(self.input_batch.lora_id_to_lora_request) > 0, ) cudagraph_runtime_mode, batch_descriptor = ( self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn) @@ -3193,6 +3196,7 @@ def _dummy_run( is_profile: bool = False, create_mixed_batch: bool = False, remove_lora: bool = True, + activate_lora: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the @@ -3215,6 +3219,7 @@ def _dummy_run( create_mixed_batch: If True, create a mixed batch with both decode (1 token) and prefill (multiple tokens) 
requests. remove_lora: If False, dummy LoRAs are not destroyed after the run + activate_lora: If False, dummy_run is performed without LoRAs. """ assert ( cudagraph_runtime_mode is None @@ -3364,7 +3369,7 @@ def _dummy_run( attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora( - self.lora_config, num_scheduled_tokens, remove_lora + self.lora_config, num_scheduled_tokens, activate_lora, remove_lora ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_after_padding <= self.max_num_tokens @@ -3411,6 +3416,7 @@ def _dummy_run( BatchDescriptor( num_tokens=num_tokens_after_padding, uniform_decode=uniform_decode, + has_lora=activate_lora and self.lora_config is not None, ) ) if not is_profile @@ -3769,10 +3775,21 @@ def freeze_gc(): start_free_gpu_memory = torch.cuda.mem_get_info()[0] cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None + + if self.lora_config: + if self.compilation_config.cudagraph_specialize_lora: + lora_cases = [True, False] + else: + lora_cases = [True] + else: + lora_cases = [False] + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: cudagraph_runtime_mode = cudagraph_mode.mixed_mode() - compilation_cases = list(reversed(self.cudagraph_batch_sizes)) + compilation_cases = list( + product(reversed(self.cudagraph_batch_sizes), lora_cases) + ) self._capture_cudagraphs( compilation_cases, cudagraph_runtime_mode=cudagraph_runtime_mode, @@ -3793,7 +3810,9 @@ def freeze_gc(): for x in self.cudagraph_batch_sizes if max_num_tokens >= x >= self.uniform_decode_query_len ] - compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes)) + compilation_cases_decode = list( + product(reversed(decode_cudagraph_batch_sizes), lora_cases) + ) self._capture_cudagraphs( compilation_cases=compilation_cases_decode, cudagraph_runtime_mode=CUDAGraphMode.FULL, @@ -3823,7 +3842,7 @@ def freeze_gc(): def _capture_cudagraphs( self, - compilation_cases: list[int], + compilation_cases: list[tuple[int, bool]], cudagraph_runtime_mode: CUDAGraphMode, uniform_decode: bool, ): @@ -3844,7 +3863,7 @@ def _capture_cudagraphs( ) # We skip EPLB here since we don't want to record dummy metrics - for num_tokens in compilation_cases: + for num_tokens, activate_lora in compilation_cases: # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched @@ -3875,6 +3894,7 @@ def _capture_cudagraphs( allow_microbatching=allow_microbatching, skip_eplb=True, remove_lora=False, + activate_lora=activate_lora, ) self._dummy_run( num_tokens, @@ -3883,6 +3903,7 @@ def _capture_cudagraphs( allow_microbatching=allow_microbatching, skip_eplb=True, remove_lora=False, + activate_lora=activate_lora, ) self.maybe_remove_all_loras(self.lora_config) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 3057d3dc00e8..372bc0a05673 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -120,7 +120,10 @@ def maybe_setup_dummy_loras( @contextmanager def maybe_select_dummy_loras( - self, lora_config: LoRAConfig | None, num_scheduled_tokens: np.ndarray + self, + lora_config: LoRAConfig | None, + num_scheduled_tokens: np.ndarray, + activate_lora: bool = True, ): if lora_config is None: yield @@ -133,7 +136,12 @@ def maybe_select_dummy_loras( # Make prompt lora mapping # Assign LoRA IDs cyclically to simulate a worst-case scenario. 
-            prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
+            if activate_lora:
+                prompt_lora_mapping = (
+                    np.arange(num_reqs, dtype=np.int32) % num_loras
+                ) + 1
+            else:
+                prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
 
             # Make token lora mapping
             token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
@@ -159,11 +167,14 @@ def maybe_dummy_run_with_lora(
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
+        activate_lora: bool = True,
         remove_lora: bool = True,
     ):
         with (
             self.maybe_setup_dummy_loras(lora_config, remove_lora),
-            self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens),
+            self.maybe_select_dummy_loras(
+                lora_config, num_scheduled_tokens, activate_lora
+            ),
         ):
             yield
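
Usage example (not part of the diff): a minimal offline-inference sketch showing how the new `cudagraph_specialize_lora` flag is threaded through `CompilationConfig`, mirroring the updated tests above. The base model name and the LoRA adapter name/path are placeholders, not values from this PR.

```python
# Usage sketch only: the model and adapter paths below are placeholders.
import vllm
import vllm.config
from vllm.lora.request import LoRARequest

llm = vllm.LLM(
    "meta-llama/Llama-2-7b-hf",  # any LoRA-capable base model
    enable_lora=True,
    max_loras=2,
    max_lora_rank=64,
    # New flag (default True): capture separate cuda graphs for batches with
    # and without active LoRA adapters. Setting it to False reduces capture
    # time and memory, at the cost of always running the LoRA ops.
    compilation_config=vllm.config.CompilationConfig(
        cudagraph_specialize_lora=False,
    ),
)

outputs = llm.generate(
    ["Write a SQL query that lists all users."],
    vllm.SamplingParams(temperature=0, max_tokens=32),
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql_lora"),
)
print(outputs[0].outputs[0].text)
```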
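To make the dispatcher change easier to review, here is a self-contained sketch of the key enumeration introduced in `cudagraph_dispatcher.py`. It re-implements the `itertools.product` logic in plain Python (with a stand-in `BatchDescriptor`) rather than importing vLLM, so the names here are illustrative only.

```python
# Standalone illustration of the new cudagraph key enumeration; not vLLM code.
from itertools import product
from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    num_tokens: int
    uniform_decode: bool = False
    has_lora: bool = False


def mixed_batch_keys(capture_sizes, lora_enabled, specialize_lora):
    # Mirrors initialize_cudagraph_keys: with LoRA enabled and
    # cudagraph_specialize_lora=True, every capture size gets two graphs
    # (has_lora=True and has_lora=False); otherwise a single LoRA case.
    if lora_enabled:
        lora_cases = [True, False] if specialize_lora else [True]
    else:
        lora_cases = [False]
    return [
        BatchDescriptor(num_tokens=bs, uniform_decode=False, has_lora=has_lora)
        for bs, has_lora in product(capture_sizes, lora_cases)
    ]


# 3 capture sizes -> 6 keys when specialized, 3 when not.
print(len(mixed_batch_keys([1, 8, 16], lora_enabled=True, specialize_lora=True)))
print(len(mixed_batch_keys([1, 8, 16], lora_enabled=True, specialize_lora=False)))
```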