@@ -25,14 +25,14 @@
 
 class FourierFTLayer(BaseTunerLayer):
     # All names of layers that may contain (trainable) adapter weights
-    adapter_layer_names = ("fourierft_spectrum",)
+    adapter_layer_names = ("fourierft_spectrum", "fourierft_scaling")
     # All names of other parameters that may contain adapter-related parameters
-    other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed")
+    other_param_names = ("fourierft_n_frequency", "fourierft_random_loc_seed")
 
-    def __init__(self, base_layer: nn.Module, alpha, **kwargs) -> None:
+    def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         self.base_layer = base_layer
         self.fourierft_n_frequency = {}
-        self.fourierft_scaling = {}
+        self.fourierft_scaling = nn.ParameterDict({})
         self.fourierft_spectrum = nn.ParameterDict({})
         self.indices = {}
         self.fourierft_random_loc_seed = {}
@@ -55,7 +55,7 @@ def __init__(self, base_layer: nn.Module, alpha, **kwargs) -> None:
             raise ValueError(f"Unsupported layer type {type(base_layer)}")
 
     def update_layer(
-        self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed, inference_mode: bool = False, **kwargs
+        self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed, dynamic_scaling, inference_mode: bool = False, **kwargs
     ):
         if n_frequency <= 0:
             raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}")
@@ -73,7 +73,7 @@ def update_layer(
         self.indices[adapter_name] = torch.stack(
             [self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0
         )
-        self.fourierft_scaling[adapter_name] = scaling
+        self.fourierft_scaling[adapter_name] = nn.Parameter(torch.tensor(scaling, dtype=torch.float32), requires_grad=dynamic_scaling)
         # Actual trainable parameters
         self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency), requires_grad=True)
 
@@ -107,21 +107,22 @@ def __init__(
         n_frequency: int = 1000,
         alpha: float = None,
         scaling: float = 150.0,
+        dynamic_scaling: bool = False,
         fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
         init_weights: Union[bool, str] = False,
         random_loc_seed: int = 777,
         **kwargs,
     ) -> None:
         super().__init__()
-        FourierFTLayer.__init__(self, base_layer, alpha, **kwargs)
+        FourierFTLayer.__init__(self, base_layer, **kwargs)
 
         # apply alpha patch
         if alpha:
             n_frequency = int(alpha * self.in_features * self.out_features)
 
         self.fan_in_fan_out = fan_in_fan_out
         self._active_adapter = adapter_name
-        self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed)
+        self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed, dynamic_scaling)
 
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         """
@@ -210,29 +211,30 @@ def __init__(
         n_frequency: int = 1000,
         alpha: float = None,
         scaling: float = 150.0,
+        dynamic_scaling: bool = False,
         fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
         init_weights: Union[bool, str] = False,
         random_loc_seed: int = 777,
         **kwargs,
     ) -> None:
         super().__init__()
-        FourierFTLayer.__init__(self, base_layer, alpha, **kwargs)
+        FourierFTLayer.__init__(self, base_layer, **kwargs)
 
         # apply alpha patch
         if alpha:
             n_frequency = int(alpha * self.in_features * self.out_features)
-
+
         self.fan_in_fan_out = fan_in_fan_out
         self._active_adapter = adapter_name
         self.kW = base_layer.kernel_size[0]
         self.kH = base_layer.kernel_size[1]
         self.stride = base_layer.stride
         self.padding = base_layer.padding
-        self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed)
+        self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed, dynamic_scaling)
 
 
     def update_layer(
-        self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed, inference_mode: bool = False, **kwargs
+        self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed, dynamic_scaling, inference_mode: bool = False, **kwargs
     ):
         if n_frequency <= 0:
             raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}")
@@ -241,6 +243,7 @@ def update_layer(
                 f"`n_frequency` should be less than or equal to the product of the input and output dimensions "
                 f"but the value passed is {n_frequency} and the product is {self.in_features * self.out_features}"
             )
+
         self.fourierft_n_frequency[adapter_name] = n_frequency
         self.fourierft_random_loc_seed[adapter_name] = random_loc_seed
         self.indices[adapter_name] = torch.randperm(
@@ -250,7 +253,7 @@ def update_layer(
         self.indices[adapter_name] = torch.stack(
             [self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0
         )
-        self.fourierft_scaling[adapter_name] = scaling
+        self.fourierft_scaling[adapter_name] = nn.Parameter(torch.tensor(scaling, dtype=torch.float32), requires_grad=dynamic_scaling)
         # Actual trainable parameters
         self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency, self.kW, self.kH), requires_grad=True)
 
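For reference, a minimal usage sketch of the new `dynamic_scaling` flag, assuming the diff above is applied on top of peft's FourierFT implementation. The `FourierFTLinear` class name and import path mirror upstream peft but are assumptions here; only `dynamic_scaling` itself comes from this diff.

# Minimal sketch, assuming the diff above is applied to peft's FourierFT
# layers. Import path and class name follow upstream peft (an assumption);
# `dynamic_scaling` is the flag introduced in this diff.
import torch.nn as nn

from peft.tuners.fourierft.layer import FourierFTLinear

base = nn.Linear(128, 128)
layer = FourierFTLinear(
    base,
    adapter_name="default",
    n_frequency=500,
    scaling=150.0,
    dynamic_scaling=True,  # scaling is stored as a trainable nn.Parameter
)

# With dynamic_scaling=True, the per-adapter scaling lives in an
# nn.ParameterDict (now listed in adapter_layer_names), so it receives
# gradients alongside the spectrum instead of staying a fixed float.
print(layer.fourierft_scaling["default"].requires_grad)  # True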