Skip to content

Commit daa328d

Browse files
committed
FourierFT: add alpha for dynamic n_frequency
1 parent 7cfbb83 commit daa328d

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

src/peft/tuners/fourierft/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,15 @@ class FourierFTConfig(PeftConfig):
185185
},
186186
)
187187

188+
alpha: float = field(
189+
default=None,
190+
metadata={
191+
"help": (
192+
"The alpha value dynamically sets the n_frequency = int(alpha * out_features * in_features)"
193+
)
194+
},
195+
)
196+
188197
def __post_init__(self):
189198
super().__post_init__()
190199
self.peft_type = PeftType.FOURIERFT

src/peft/tuners/fourierft/layer.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class FourierFTLayer(BaseTunerLayer):
2929
# All names of other parameters that may contain adapter-related parameters
3030
other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed")
3131

32-
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
32+
def __init__(self, base_layer: nn.Module, alpha, **kwargs) -> None:
3333
self.base_layer = base_layer
3434
self.fourierft_n_frequency = {}
3535
self.fourierft_scaling = {}
@@ -49,9 +49,14 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
4949
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
5050
)
5151
elif isinstance(base_layer, nn.Conv2d):
52-
pass
52+
self.in_features = base_layer.in_channels
53+
self.out_features = base_layer.out_channels
5354
else:
5455
raise ValueError(f"Unsupported layer type {type(base_layer)}")
56+
57+
# apply alpha patch
58+
if alpha:
59+
kwargs['n_frequency'] = int(alpha * self.in_features * self.out_features)
5560

5661
def update_layer(
5762
self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed, inference_mode: bool = False, **kwargs
@@ -104,14 +109,16 @@ def __init__(
104109
base_layer,
105110
adapter_name: str,
106111
n_frequency: int = 1000,
112+
alpha: float = None,
107113
scaling: float = 150.0,
108114
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
109115
init_weights: Union[bool, str] = False,
110116
random_loc_seed: int = 777,
111117
**kwargs,
112118
) -> None:
113119
super().__init__()
114-
FourierFTLayer.__init__(self, base_layer, **kwargs)
120+
FourierFTLayer.__init__(self, base_layer, alpha, **kwargs)
121+
115122
self.fan_in_fan_out = fan_in_fan_out
116123
self._active_adapter = adapter_name
117124
self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed)
@@ -201,18 +208,17 @@ def __init__(
201208
base_layer,
202209
adapter_name: str,
203210
n_frequency: int = 1000,
211+
alpha: float = None,
204212
scaling: float = 150.0,
205213
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
206214
init_weights: Union[bool, str] = False,
207215
random_loc_seed: int = 777,
208216
**kwargs,
209217
) -> None:
210218
super().__init__()
211-
FourierFTLayer.__init__(self, base_layer, **kwargs)
219+
FourierFTLayer.__init__(self, base_layer, alpha, **kwargs)
212220
self.fan_in_fan_out = fan_in_fan_out
213221
self._active_adapter = adapter_name
214-
self.in_features = base_layer.in_channels
215-
self.out_features = base_layer.out_channels
216222
self.kW = base_layer.kernel_size[0]
217223
self.kH = base_layer.kernel_size[1]
218224
self.stride = base_layer.stride

src/peft/tuners/fourierft/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,12 @@ def _create_and_replace(
9696

9797
n_frequency = fourierft_config.n_frequency_pattern.get(target_name_key, fourierft_config.n_frequency)
9898
scaling = fourierft_config.scaling
99+
alpha = fourierft_config.alpha
99100
random_loc_seed = fourierft_config.random_loc_seed
100101
bias = hasattr(target, "bias") and target.bias is not None
101102
kwargs = {
102103
"n_frequency": n_frequency,
104+
"alpha": alpha,
103105
"scaling": scaling,
104106
"fan_in_fan_out": fourierft_config.fan_in_fan_out,
105107
"init_weights": fourierft_config.init_weights,

0 commit comments

Comments
 (0)