Skip to content

Commit 9cbc72d

Browse files
committed
Add fourierft_ifft2_norm to the parameters and other fixes
1 parent 8926877 commit 9cbc72d

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

src/peft/tuners/fourierft/config.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
from dataclasses import dataclass, field
18-
from typing import Optional, Union
18+
from typing import Optional, Union, Literal
1919

2020
from peft.config import PeftConfig
2121
from peft.utils import PeftType
@@ -175,7 +175,7 @@ class FourierFTConfig(PeftConfig):
175175
},
176176
)
177177

178-
ifft2_norm: Optional[str] = field(
178+
ifft2_norm: Optional[Literal["backward", "forward", "ortho"]] = field(
179179
default_factory='backward',
180180
metadata={
181181
"help": (
@@ -224,3 +224,9 @@ def __post_init__(self):
224224
# check for layers_to_transform and layers_pattern
225225
if self.layers_pattern and not self.layers_to_transform:
226226
raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
227+
228+
if (self.alpha is not None) and (self.n_frequency != 1000):
229+
raise ValueError("Don't set both alpha and n_frequency, as alpha overrides ...")
230+
231+
if (self.alpha is not None) and (self.n_frequency_pattern != {}):
232+
raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides ...")

src/peft/tuners/fourierft/layer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class FourierFTLayer(BaseTunerLayer):
2727
# All names of layers that may contain (trainable) adapter weights
2828
adapter_layer_names = ("fourierft_spectrum",)
2929
# All names of other parameters that may contain adapter-related parameters
30-
other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed")
30+
other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed", "fourierft_ifft2_norm")
3131

3232
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
3333
self.base_layer = base_layer
@@ -39,6 +39,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
3939
# Mark the weight as unmerged
4040
self._disable_adapters = False
4141
self.merged_adapters = []
42+
self.fourierft_ifft2_norm = kwargs['ifft2_norm']
4243
self.kwargs = kwargs
4344

4445
base_layer = self.get_base_layer()
@@ -96,7 +97,7 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
9697
indices = self.indices[adapter].to(spectrum.device)
9798
dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device)
9899
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
99-
delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.kwargs['ifft2_norm']).real * self.fourierft_scaling[adapter]
100+
delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter]
100101
return delta_weight.to(spectrum.dtype)
101102

102103
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
@@ -261,7 +262,7 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
261262
indices = self.indices[adapter].to(spectrum.device)
262263
dense_spectrum = torch.zeros(self.out_features*kH, self.in_features*kW, device=spectrum.device)
263264
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
264-
delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.kwargs['ifft2_norm']).real * self.fourierft_scaling[adapter]
265+
delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter]
265266
return torch.reshape(delta_weight, (self.out_features, self.in_features, kW, kH))
266267

267268
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:

0 commit comments

Comments
 (0)