@@ -27,7 +27,7 @@ class FourierFTLayer(BaseTunerLayer):
2727 # All names of layers that may contain (trainable) adapter weights
2828 adapter_layer_names = ("fourierft_spectrum" ,)
2929 # All names of other parameters that may contain adapter-related parameters
30- other_param_names = ("fourierft_n_frequency" , "fourierft_scaling" , "fourierft_random_loc_seed" )
30+ other_param_names = ("fourierft_n_frequency" , "fourierft_scaling" , "fourierft_random_loc_seed" , "fourierft_ifft2_norm" )
3131
3232 def __init__ (self , base_layer : nn .Module , ** kwargs ) -> None :
3333 self .base_layer = base_layer
@@ -39,6 +39,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
3939 # Mark the weight as unmerged
4040 self ._disable_adapters = False
4141 self .merged_adapters = []
42+ self .fourierft_ifft2_norm = kwargs ['ifft2_norm' ]
4243 self .kwargs = kwargs
4344
4445 base_layer = self .get_base_layer ()
@@ -96,7 +97,7 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
9697 indices = self .indices [adapter ].to (spectrum .device )
9798 dense_spectrum = torch .zeros (self .out_features , self .in_features , device = spectrum .device )
9899 dense_spectrum [indices [0 , :], indices [1 , :]] = spectrum .float ()
99- delta_weight = torch .fft .ifft2 (dense_spectrum , norm = self .kwargs [ 'ifft2_norm' ] ).real * self .fourierft_scaling [adapter ]
100+ delta_weight = torch .fft .ifft2 (dense_spectrum , norm = self .fourierft_ifft2_norm ).real * self .fourierft_scaling [adapter ]
100101 return delta_weight .to (spectrum .dtype )
101102
102103 def merge (self , safe_merge : bool = False , adapter_names : Optional [list [str ]] = None ) -> None :
@@ -261,7 +262,7 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
261262 indices = self .indices [adapter ].to (spectrum .device )
262263 dense_spectrum = torch .zeros (self .out_features * kH , self .in_features * kW , device = spectrum .device )
263264 dense_spectrum [indices [0 , :], indices [1 , :]] = spectrum .float ()
264- delta_weight = torch .fft .ifft2 (dense_spectrum , norm = self .kwargs [ 'ifft2_norm' ] ).real * self .fourierft_scaling [adapter ]
265+ delta_weight = torch .fft .ifft2 (dense_spectrum , norm = self .fourierft_ifft2_norm ).real * self .fourierft_scaling [adapter ]
265266 return torch .reshape (delta_weight , (self .out_features , self .in_features , kW , kH ))
266267
267268 def forward (self , x : torch .Tensor , * args : Any , ** kwargs : Any ) -> torch .Tensor :
0 commit comments