-
-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[Model] Activated LoRA #19710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model] Activated LoRA #19710
Changes from 34 commits
8a9610e
a68e70b
b254fb7
3897b1b
24ff376
412eacd
32098e4
fb6d28e
5f62d8b
f9396b0
6f36f6d
c6ffe8f
4a4b568
4cbef84
49a5bdc
a9ac26d
5c2e181
ceae7c7
99b8b60
477ab6e
91f39d1
5abbb78
0a20f2a
438ab6f
51edf96
a9d5986
6fbc108
cb373e9
24dfc4a
6c1b46a
b8444d9
4e513cc
b9df31f
643d893
199ee89
03b6480
76744da
6b83cc4
18397d7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| # LoRA Examples | ||
|
|
||
| This folder contains examples of offline inference using LoRA. | ||
|
|
||
| ## Multi-LoRA | ||
|
|
||
| This example shows how to use the multi-LoRA functionality: | ||
|
|
||
| ```bash | ||
| python examples/offline_inference/lora/multilora_inference.py | ||
| ``` | ||
|
|
||
| ## LoRA with Quantization | ||
|
|
||
| This example shows how to use LoRA with different quantization techniques: | ||
|
|
||
| ```bash | ||
| python examples/offline_inference/lora/lora_with_quantization_inference.py | ||
| ``` | ||
|
|
||
| ## Activated LoRA | ||
|
|
||
| This example shows how to use [activated LoRA](https://arxiv.org/abs/2504.12397): ||
|
|
||
| ```bash | ||
| python examples/offline_inference/lora/activated_lora.py | ||
| ``` |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Offline-inference example for Activated LoRA (aLoRA).

Runs a base Granite model, then re-runs the base model's answer through an
aLoRA uncertainty adapter whose weights only activate after the invocation
string appears in the prompt.
"""
import time

import torch
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

BASE_NAME = "ibm-granite/granite-3.2-8b-instruct"

ALORA_NAME = "ibm-granite/granite-3.2-8b-alora-uncertainty"
# The adapter's weights activate only on tokens after this marker.
invocation_string = "<|start_of_role|>certainty<|end_of_role|>"

# download your LoRA adapter to ~/.cache/huggingface/…
alora_path = snapshot_download(repo_id=ALORA_NAME)

print(alora_path)
#######################################

llm = LLM(
    model=BASE_NAME,
    enable_lora=True,
    enable_activated_lora=True,
    dtype=torch.bfloat16,
    max_lora_rank=64,
)

prompts = [
    (
        "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    ),
]

sampling_params = SamplingParams(temperature=0, max_tokens=600)

# First pass: generate answers with the base model (no adapter).
outputs_base = llm.generate(
    prompts,
    sampling_params,
    use_tqdm=False,
)
base_generations = []
for output in outputs_base:
    prompt = output.prompt
    base_generations.append(output.outputs[0].text)
    print(f"Prompt: {prompt!r}, Generated text: {base_generations[-1]!r}")

# Second pass: append each base answer plus the invocation string so the
# aLoRA adapter activates and scores the model's certainty.
prompts_alora = [
    x + y + "<|end_of_text|>\n" + invocation_string
    for x, y in zip(prompts, base_generations)
]

sampling_params = SamplingParams(temperature=0, max_tokens=10)

t0 = time.time()
outputs = llm.generate(
    prompts_alora,
    sampling_params,
    lora_request=LoRARequest("UQ_adapter", 1, alora_path),
    use_tqdm=False,
)
t = time.time() - t0
print(f"Time: {t}")

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,7 @@ | |
| # pylint: disable=unused-argument | ||
| import math | ||
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING, Optional, Union, cast | ||
| from typing import Optional, Union, cast | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
@@ -19,6 +19,8 @@ | |
| tensor_model_parallel_all_gather, | ||
| tensor_model_parallel_all_reduce) | ||
| from vllm.distributed.utils import divide | ||
| from vllm.forward_context import get_forward_context | ||
| from vllm.lora.punica_wrapper import PunicaWrapperBase | ||
| # yapf: disable | ||
| from vllm.model_executor.layers.linear import (ColumnParallelLinear, | ||
| LinearBase, | ||
|
|
@@ -32,9 +34,6 @@ | |
| VocabParallelEmbedding) | ||
| from vllm.platforms import current_platform | ||
|
|
||
| if TYPE_CHECKING: | ||
| from vllm.lora.punica_wrapper import PunicaWrapperBase | ||
|
|
||
|
|
||
| def _get_lora_device(base_layer: nn.Module) -> torch.device: | ||
| # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 | ||
|
|
@@ -1190,3 +1189,44 @@ def can_replace_layer( | |
| ) -> bool: | ||
| # Special handling for the LogitsProcessor. | ||
| return False | ||
|
|
||
|
|
||
class ActivatedLoRAMixin:
    """Mixin adding the Activated LoRA (aLoRA) forward pass to LoRA layers.

    aLoRA applies the LoRA delta only to token positions *after* the
    invocation sequence; positions before it keep the base-layer output.
    The per-token activation mask is read from the forward context.
    """

    # Attributes supplied by the concrete LoRA layer this is mixed into.
    base_layer: LinearBase
    punica_wrapper: PunicaWrapperBase
    # NOTE: these are types, so ``torch.Tensor`` (the class), not
    # ``torch.tensor`` (the factory function).
    lora_a_stacked: torch.Tensor
    lora_b_stacked: torch.Tensor
    lora_bias_stacked: Optional[tuple[torch.Tensor, ...]]
    output_slices: tuple[int, ...]

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the base layer, then blend in LoRA output per the aLoRA mask.

        Args:
            x: Input activations, ``(seq_len, hidden_dim)`` or, with the
                transformers backend, ``(1, seq_len, hidden_dim)``.
            bias: Optional bias forwarded to the base layer's quant method.

        Returns:
            Per-token mix of base output (masked positions) and
            LoRA-augmented output (unmasked positions).
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        # Extract aLoRA batch metadata from forward context
        alora_metadata = get_forward_context().alora_metadata

        mask1d = alora_metadata.mask1d
        mask2d = mask1d.unsqueeze(1).to(output.dtype)

        # Clone base layer output before running LoRA
        # TODO(tdoublep): pass in mask1d and only operate on valid entries
        orig_out = output.clone()

        # Apply LoRA in-place on `output`:
        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                            self.lora_b_stacked,
                                            self.lora_bias_stacked, 1.0,
                                            self.output_slices)
        # Blend: masked positions keep the base output, others take LoRA.
        final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d)
        return final_output
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs | ||
| from vllm.inputs.parse import split_enc_dec_inputs | ||
| from vllm.inputs.preprocess import InputPreprocessor | ||
| from vllm.lora.peft_helper import PEFTHelper | ||
| from vllm.lora.request import LoRARequest | ||
| from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry | ||
| from vllm.multimodal.cache import processor_cache_from_config | ||
|
|
@@ -429,6 +430,36 @@ def process_inputs( | |
| identifier=decoder_mm_hashes[modality][idx], | ||
| mm_position=decoder_mm_positions[modality][idx])) | ||
|
|
||
| # Handle aLoRA invocation sequence if applicable. | ||
| if (self.lora_config and self.lora_config.activated_lora_enabled | ||
| and lora_request is not None): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe also check if it is a
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can't actually know that until we've called the PEFTHelper below to look at the adapter config |
||
|
|
||
| text_config = self.model_config.hf_config.get_text_config() | ||
|
|
||
| peft_helper = PEFTHelper.from_local_dir( | ||
| lora_request.lora_path, text_config.max_position_embeddings, | ||
| lora_request.tensorizer_config_dict) | ||
|
|
||
| if peft_helper.alora_invocation_tokens is not None: | ||
| invocation_tokens = peft_helper.alora_invocation_tokens | ||
| invocation_start = -1 | ||
| n = len(invocation_tokens) | ||
tdoublep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| token_ids = decoder_inputs["prompt_token_ids"] | ||
| if n > 0 and len(token_ids) >= n: | ||
| # scan backward for the last match | ||
| # (faster than full forward scan+max) | ||
| for idx in range(len(token_ids) - n, -1, -1): | ||
| if token_ids[idx:idx + n] == invocation_tokens: | ||
| # weights activated after start | ||
| invocation_start = idx | ||
| break | ||
| if invocation_start == -1: | ||
| raise ValueError( | ||
| "Invocation sequence not found in prompt " | ||
| f"for request '{request_id}'. aLoRA models require the " | ||
| "invocation tokens to be present in the input.") | ||
| lora_request.invocation_start = invocation_start | ||
|
|
||
| return decoder_inputs.get("prompt"), EngineCoreRequest( | ||
| request_id=request_id, | ||
| prompt_token_ids=decoder_inputs["prompt_token_ids"], | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.