7 changes: 6 additions & 1 deletion llm/argument.py
@@ -126,7 +126,12 @@ class ModelArgument:
lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})
lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"})

use_quick_lora: bool = field(
default=False,
metadata={
"help": "Whether to use quick lora, The use of Quick LoRa will only take effect when lora_dropout is set to 0."
},
)
# prefix tuning related parameters
prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
prefix_path: str = field(default=None, metadata={"help": "Initialize prefix state dict."})
1 change: 1 addition & 0 deletions llm/finetune_generation.py
@@ -424,6 +424,7 @@ def neft_post_hook(module, input, output):
dtype=dtype,
do_qat=quant_args.do_qat,
base_model_name_or_path=model_args.model_name_or_path,
use_quick_lora=model_args.use_quick_lora,
)
model = LoRAModel(model, lora_config)
else:
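For context, a minimal sketch of how the new flag flows when wiring up LoRA directly rather than through the training script. Only `use_quick_lora` and `lora_dropout` come from this diff; the checkpoint name, target-module patterns, and other LoRAConfig fields shown here are illustrative assumptions:

from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

# Placeholder base model; any checkpoint whose modules match the LoRA target patterns
# is handled the same way.
model = AutoModelForCausalLM.from_pretrained("facebook/llama-7b")

lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # assumed regex patterns for attention projections
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,       # Quick LoRA only takes effect when dropout is exactly 0
    use_quick_lora=True,
)
model = LoRAModel(model, lora_config)
model.mark_only_lora_as_trainable()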
15 changes: 15 additions & 0 deletions paddlenlp/peft/lora/lora_config.py
@@ -18,6 +18,7 @@
from typing import List, Optional, Union

from ...utils.env import LORA_CONFIG_NAME
from ...utils.log import logger


@dataclass
@@ -75,6 +76,20 @@ class LoRAConfig:
base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name of the base model to use."}
)
use_quick_lora: bool = field(
default=False,
metadata={
"help": "Whether to use quick lora, The use of Quick LoRa will only take effect when lora_dropout is set to 0."
},
)

def __post_init__(self):
if self.use_quick_lora and self.lora_dropout > 0:
logger.warning(
"Quick LoRa is enabled, but lora_dropout is set to a non-zero value. "
"We will automatically set `use_quick_lora` to `False` to avoid potential inconsistencies."
)
self.use_quick_lora = False

@property
def __dict__(self):
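A quick sanity check of the `__post_init__` guard above, assuming the remaining LoRAConfig fields keep their defaults:

from paddlenlp.peft import LoRAConfig

# Dropout disabled: the flag is honored.
cfg = LoRAConfig(lora_dropout=0.0, use_quick_lora=True)
assert cfg.use_quick_lora is True

# Non-zero dropout: a warning is logged and the config falls back to the regular LoRA path.
cfg = LoRAConfig(lora_dropout=0.1, use_quick_lora=True)
assert cfg.use_quick_lora is False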
119 changes: 90 additions & 29 deletions paddlenlp/peft/lora/lora_layers.py
@@ -24,6 +24,8 @@
RowParallelLinear,
)

from .lora_quick_layers import quick_lora


class LoRALinear(nn.Linear):
# LoRA implemented in a dense layer
@@ -35,6 +37,7 @@ def __init__(
lora_alpha: int = 1,
lora_dropout: float = 0.0,
merge_weights: bool = True,
use_quick_lora: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
@@ -68,6 +71,11 @@ def __init__(

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
@@ -86,9 +94,12 @@ def eval(self):
self.merged = True

def forward(self, input: paddle.Tensor, *args, **kwargs):
result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
if not self.merged:
result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling
if self.use_quick_lora:
result = quick_lora(input, self.lora_A, self.lora_B, self.weight, self.bias, self.scaling)
else:
result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
if not self.merged:
result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling
return result

def extra_repr(self):
@@ -105,6 +116,7 @@ def __init__(
lora_alpha: int = 1,
lora_dropout: float = 0.0,
merge_weights: bool = True,
use_quick_lora: bool = False,
**kwargs
):
RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
@@ -146,6 +158,11 @@ def __init__(

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
@@ -169,30 +186,48 @@ def forward(self, x: paddle.Tensor):
else:
input_mp = x

# x @ W : [bz, in_f / ws] ===> [bz, out_f]
result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)

output = mp_ops._mp_allreduce(
result_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)

if not self.merged:
# x @ A: [bz, in_f/ ws] ===> [bz, r]
input_mp = self.lora_dropout(input_mp) @ self.lora_A
# all-reduce to keep LoRA B's gradient consistent across GPUs
input_dup = mp_ops._mp_allreduce(
if self.use_quick_lora:
result_mp = quick_lora(
input_mp,
self.lora_A,
self.lora_B,
self.weight,
self.bias,
self.scaling,
is_row=True,
group=self.model_parallel_group,
world_size=self.world_size,
)
output = mp_ops._mp_allreduce(
result_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
else:
# x @ W : [bz, in_f / ws] ===> [bz, out_f]
result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
output = mp_ops._mp_allreduce(
result_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
# @ B: [bz, r] ===> [bz, out_f]
delta_mp = (input_dup @ self.lora_B) * self.scaling
output += delta_mp
output = output + self.bias if self.bias is not None else output

if not self.merged:
# x @ A: [bz, in_f/ ws] ===> [bz, r]
input_mp = self.lora_dropout(input_mp) @ self.lora_A
# all-reduce to keep LoRA B's gradient consistent across GPUs
input_dup = mp_ops._mp_allreduce(
input_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
# @ B: [bz, r] ===> [bz, out_f]
delta_mp = (input_dup @ self.lora_B) * self.scaling
output += delta_mp
output = output + self.bias if self.bias is not None else output
return output

def extra_repr(self):
@@ -210,6 +245,7 @@ def __init__(
lora_dropout: float = 0.0,
merge_weights: bool = True,
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
use_quick_lora: bool = False,
**kwargs
):
ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
@@ -249,6 +285,11 @@ def __init__(

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
@@ -267,14 +308,34 @@ def eval(self):
self.merged = True

def forward(self, input: paddle.Tensor):
input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
if self.is_mp:
input_mp = mp_ops._c_identity(
input,
group=self.model_parallel_group,
)
else:
input_mp = input
if self.use_quick_lora:
# Use the Quick LoRA implementation
result_mp = quick_lora(
input_mp,
self.lora_A,
self.lora_B,
self.weight,
self.bias,
self.scaling,
is_column=True,
group=self.model_parallel_group,
world_size=self.world_size,
)
else:
result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)

if not self.merged:
input_a = self.lora_dropout(input) @ self.lora_A
input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
delta_mp = (input_a_mp @ self.lora_B) * self.scaling
result_mp += delta_mp
if not self.merged:
input_a = self.lora_dropout(input) @ self.lora_A
input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
delta_mp = (input_a_mp @ self.lora_B) * self.scaling
result_mp += delta_mp

if self.gather_output and self.is_mp:
result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
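The `quick_lora` kernel itself lives in `lora_quick_layers.py`, which is not part of this view. Conceptually, once dropout is disabled the unfused forward used in these layers is equivalent to a single matmul against an implicitly merged weight, which is the identity a fused path can exploit. Below is a reference sketch of that equivalence, not the fused implementation (which may also customize the backward pass):

import paddle
import paddle.nn.functional as F

def reference_lora_forward(x, lora_A, lora_B, weight, bias, scaling):
    # Unfused LoRA forward, as in LoRALinear.forward with lora_dropout == 0:
    # base projection plus the scaled low-rank update.
    result = F.linear(x=x, weight=weight, bias=bias)
    result += (x @ lora_A @ lora_B) * scaling
    # Mathematically this equals F.linear(x, weight + scaling * (lora_A @ lora_B), bias),
    # so the two LoRA matmuls can be folded into the main GEMM when dropout is 0.
    return result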
3 changes: 3 additions & 0 deletions paddlenlp/peft/lora/lora_model.py
@@ -383,6 +383,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
lora_dropout=lora_config.lora_dropout,
merge_weights=lora_config.merge_weights,
bias_attr=False if module.bias is None else None,
use_quick_lora=lora_config.use_quick_lora,
)
if isinstance(module, nn.Conv2D):
lora_module = LoRAConv2D(
@@ -418,6 +419,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
)
),
use_quick_lora=lora_config.use_quick_lora,
)
# Lora column parallel will split lora B matrix
self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)
@@ -438,6 +440,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
merge_weights=lora_config.merge_weights,
use_quick_lora=lora_config.use_quick_lora,
)
# Lora column parallel will split lora A matrix
self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)