Commit 634bd19
FIX: setting requires_grad on adapter layers (#905)
* [WIP] Fix setting requires_grad on adapter layers

This is an alternative to #900 and resolves #899.

Description

Currently, we don't handle setting requires_grad on adapter layers very well. The main issue is that it can be left as True on adapter parameters that are not being used, e.g. the original_module in ModulesToSaveWrapper or inactive adapters in LoRA. Normally, this is not a big issue, except perhaps when we want to correctly count the number of trainable parameters. However, when training with DistributedDataParallel, it results in errors: PyTorch expects every parameter with requires_grad=True to participate in the loss computation, and those parameters don't. For that reason, training with DDP currently fails when using modules_to_save or multiple adapters.

Implementation

This turned out to be more complicated than initially thought. The logic for setting requires_grad is spread all over the place; it was hard to encapsulate, and I only partially succeeded. As is, this PR is more complex than the one it tries to supersede, #900, but it is also "more correct". Tests were added to check whether requires_grad is set correctly. There are (so far) no tests for whether DDP itself works; those could be added with multi-GPU setups. I did, however, test an early stage of this PR with DDP, and setting requires_grad correctly does fix the DDP error.

DONE/TODO

- [x] ModulesToSaveWrapper
- [x] LoRA
- [ ] IA³
- [ ] AdaLora

Since some tuners are not implemented yet, tests are expected to fail. Check the new tests at the bottom of test_custom.py; those should pass.

* Refactor: move more requires_grad machinery to ABC

* [skip ci] [WIP] Add requires_grad logic to IA³

* Add AdaLora

* Fix some minor issues

* Make style
1 parent 1af8ca4 commit 634bd19
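
For readers skimming the diff: the change gives every tuner layer a class attribute, adapter_layer_names, listing the sub-modules that hold adapter weights, and routes adapter activation through set_adapter, which also flips requires_grad so that only the active adapter remains trainable. The snippet below is a minimal sketch of that idea under simplified assumptions; ToyAdapterLayer and its attributes are illustrative stand-ins, not the actual PEFT classes.

# Minimal sketch (illustration only, not the PEFT implementation) of how
# set_adapter can toggle requires_grad based on adapter_layer_names.
import torch.nn as nn


class ToyAdapterLayer(nn.Module):
    # names of the sub-modules that hold per-adapter weights
    adapter_layer_names = ["lora_A", "lora_B"]

    def __init__(self, in_features: int, out_features: int, r: int = 8):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.base.weight.requires_grad = False  # frozen pre-trained weight
        self.lora_A = nn.ModuleDict()
        self.lora_B = nn.ModuleDict()
        self.r = r
        self.active_adapters = []

    def update_layer(self, adapter_name: str) -> None:
        # add the weights for a new adapter
        self.lora_A[adapter_name] = nn.Linear(self.base.in_features, self.r, bias=False)
        self.lora_B[adapter_name] = nn.Linear(self.r, self.base.out_features, bias=False)

    def set_adapter(self, adapter_names) -> None:
        # only parameters of the active adapter(s) stay trainable; all other
        # adapters are frozen, so DDP never sees trainable parameters that
        # do not take part in the loss computation
        if isinstance(adapter_names, str):
            adapter_names = [adapter_names]
        for layer_name in self.adapter_layer_names:
            module_dict = getattr(self, layer_name)
            for key, module in module_dict.items():
                module.requires_grad_(key in adapter_names)
        self.active_adapters = adapter_names


layer = ToyAdapterLayer(16, 16)
layer.update_layer("default")
layer.update_layer("other")
layer.set_adapter("default")
assert all(p.requires_grad for p in layer.lora_A["default"].parameters())
assert not any(p.requires_grad for p in layer.lora_B["other"].parameters())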

File tree

15 files changed: +513 -60 lines changed

src/peft/tuners/adalora/bnb.py

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ def __init__(

         init_lora_weights = kwargs.pop("init_lora_weights", True)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
+        self.set_adapter(adapter_name)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         result = super().forward(x)
@@ -112,7 +112,7 @@ def __init__(

         init_lora_weights = kwargs.pop("init_lora_weights", True)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
+        self.set_adapter(adapter_name)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         result = super().forward(x)

src/peft/tuners/adalora/gptq.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def __init__(
         self.weight = quant_linear_module.qweight
         init_lora_weights = kwargs.pop("init_lora_weights", True)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
+        self.set_adapter(adapter_name)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         result = self.quant_linear_module(x)

src/peft/tuners/adalora/layer.py

Lines changed: 6 additions & 1 deletion
@@ -24,6 +24,10 @@


 class AdaLoraLayer(LoraLayer):
+    # List all names of layers that may contain adapter weights
+    # Note: ranknum doesn't need to be included as it is not an nn.Module
+    adapter_layer_names = ["lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B"]
+
     def __init__(
         self,
         in_features: int,
@@ -59,6 +63,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         if init_lora_weights:
             self.reset_lora_parameters(adapter_name)
         self.to(self.weight.device)
+        self.set_adapter(self.active_adapters)

     def reset_lora_parameters(self, adapter_name):
         if adapter_name in self.lora_A.keys():
@@ -92,7 +97,7 @@ def __init__(

         nn.Linear.reset_parameters(self)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
+        self.set_adapter(adapter_name)

     def merge(self) -> None:
         if self.merged:

src/peft/tuners/adalora/model.py

Lines changed: 3 additions & 0 deletions
@@ -141,6 +141,9 @@ def _create_and_replace(
         # If it is not a LoraLayer, create a new module, else update it with new adapters
         if not isinstance(target, AdaLoraLayer):
             new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
+            if adapter_name != self.active_adapter:
+                # adding an additional adapter: it is not automatically trainable
+                new_module.requires_grad_(False)
             self._replace_module(parent, target_name, new_module, target)
         else:
             target.update_layer(

src/peft/tuners/ia3/bnb.py

Lines changed: 2 additions & 2 deletions
@@ -40,14 +40,14 @@ def __init__(
             index=kwargs.get("index", None),
         )
         IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
+        self.is_feedforward = is_feedforward

         # Freezing the pre-trained weight matrix
         self.weight.requires_grad = False

         init_ia3_weights = kwargs.pop("init_ia3_weights", True)
         self.update_layer(adapter_name, init_ia3_weights)
-        self.active_adapter = adapter_name
-        self.is_feedforward = is_feedforward
+        self.set_adapter(adapter_name)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.disable_adapters:

src/peft/tuners/ia3/layer.py

Lines changed: 7 additions & 4 deletions
@@ -24,6 +24,9 @@


 class IA3Layer(BaseTunerLayer):
+    # List all names of layers that may contain adapter weights
+    adapter_layer_names = ["ia3_l"]
+
     def __init__(
         self,
         in_features: int,
@@ -34,8 +37,8 @@ def __init__(
         self.ia3_l = nn.ParameterDict({})
         # Mark the weight as unmerged
         self.merged = False
+        self._disable_adapters = False
         self.merged_adapters = []
-        self.disable_adapters = False
         self.in_features = in_features
         self.out_features = out_features
         self.is_feedforward = is_feedforward
@@ -50,6 +53,7 @@ def update_layer(self, adapter_name, init_ia3_weights):
         if init_ia3_weights:
             self.reset_ia3_parameters(adapter_name)
         self.to(self.weight.device)
+        self.set_adapter(self.active_adapters)

     def reset_ia3_parameters(self, adapter_name):
         if adapter_name in self.ia3_l.keys():
@@ -72,6 +76,7 @@ def __init__(

         nn.Linear.__init__(self, in_features, out_features, **kwargs)
         IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
+        self.is_feedforward = is_feedforward
         # Freezing the pre-trained weight matrix
         self.weight.requires_grad = False

@@ -81,9 +86,7 @@ def __init__(

         nn.Linear.reset_parameters(self)
         self.update_layer(adapter_name, init_ia3_weights)
-        self.active_adapter = adapter_name
-
-        self.is_feedforward = is_feedforward
+        self.set_adapter(adapter_name)

     def merge(self) -> None:
         if self.merged:

src/peft/tuners/ia3/model.py

Lines changed: 6 additions & 5 deletions
@@ -178,6 +178,9 @@ def _create_and_replace(
             )
         else:
             new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs)
+            if adapter_name != self.active_adapter:
+                # adding an additional adapter: it is not automatically trainable
+                new_module.requires_grad_(False)
             self._replace_module(parent, target_name, new_module, target)

     @staticmethod
@@ -213,10 +216,8 @@ def get_peft_config_as_dict(self, inference: bool = False):

     def _set_adapter_layers(self, enabled=True):
         for module in self.model.modules():
-            if isinstance(module, IA3Layer):
-                module.disable_adapters = False if enabled else True
-            elif isinstance(module, ModulesToSaveWrapper):
-                module.disable_adapters = False if enabled else True
+            if isinstance(module, (IA3Layer, ModulesToSaveWrapper)):
+                module.enable_adapters(enabled)

     def enable_adapter_layers(self):
         self._set_adapter_layers(enabled=True)
@@ -230,7 +231,7 @@ def set_adapter(self, adapter_name):
                 if module.merged:
                     warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                     module.unmerge()
-                module.active_adapter = adapter_name
+                module.set_adapter(adapter_name)

     def _prepare_adapter_config(self, peft_config, model_config):
         if peft_config.target_modules is None:

src/peft/tuners/lora/bnb.py

Lines changed: 2 additions & 2 deletions
@@ -54,7 +54,7 @@ def __init__(
         self.weight.requires_grad = False
         init_lora_weights = kwargs.pop("init_lora_weights", True)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
+        self.set_adapter(adapter_name)

     def merge(self):
         if self.merged:
@@ -195,7 +195,7 @@ def __init__(

         init_lora_weights = kwargs.pop("init_lora_weights", True)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
+        self.set_adapter(adapter_name)

     def merge(self):
         if self.merged:

src/peft/tuners/lora/gptq.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def __init__(
         self.weight = quant_linear_module.qweight
         init_lora_weights = kwargs.pop("init_lora_weights", True)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
+        self.set_adapter(adapter_name)

     def forward(self, x: torch.Tensor):
         # note: logic differs from default Linear because merging is not supported

src/peft/tuners/lora/layer.py

Lines changed: 8 additions & 4 deletions
@@ -26,6 +26,9 @@


 class LoraLayer(BaseTunerLayer):
+    # List all names of layers that may contain adapter weights
+    adapter_layer_names = ["lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B"]
+
     def __init__(self, in_features: int, out_features: int, **kwargs):
         self.r = {}
         self.lora_alpha = {}
@@ -38,8 +41,8 @@ def __init__(self, in_features: int, out_features: int, **kwargs):
         self.lora_embedding_B = nn.ParameterDict({})
         # Mark the weight as unmerged
         self.merged = False
+        self._disable_adapters = False
         self.merged_adapters = []
-        self.disable_adapters = False
         self.in_features = in_features
         self.out_features = out_features
         self.kwargs = kwargs
@@ -82,6 +85,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
                 self.to(weight.device, dtype=weight.dtype)
             else:
                 self.to(weight.device)
+        self.set_adapter(self.active_adapters)

     def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         if r <= 0:
@@ -197,8 +201,8 @@ def __init__(
         self.fan_in_fan_out = fan_in_fan_out

         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
         self.is_target_conv_1d_layer = is_target_conv_1d_layer
+        self.set_adapter(adapter_name)

     def merge(self) -> None:
         if self.merged:
@@ -275,7 +279,7 @@ def __init__(
         self._init_empty_weights(nn.Embedding, num_embeddings, embedding_dim, **kwargs)
         LoraLayer.__init__(self, in_features=num_embeddings, out_features=embedding_dim)
         self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
+        self.set_adapter(adapter_name)

     def merge(self) -> None:
         if self.merged:
@@ -364,7 +368,7 @@ def __init__(
         )

         self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
+        self.set_adapter(adapter_name)

     def merge(self) -> None:
         if self.merged:
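
A quick way to observe the effect described in the commit message (and checked by the new tests in test_custom.py) is to inspect requires_grad on a model with two adapters. The sketch below assumes a locally available "gpt2" checkpoint and targets its "c_attn" projection; both are example choices, and the assertions express the behavior this fix is meant to guarantee.

# Hedged usage sketch: after this fix, only parameters of the active adapter
# should have requires_grad=True.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")  # example base model
model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"]), adapter_name="default")
model.add_adapter("other", LoraConfig(r=8, target_modules=["c_attn"]))


def trainable_params(m):
    # names of all parameters that would receive gradients
    return [name for name, param in m.named_parameters() if param.requires_grad]


# "default" is still the active adapter, so nothing belonging to "other" trains
assert all(".other." not in name for name in trainable_params(model))

model.set_adapter("other")
# after switching, the "default" adapter should be frozen instead
assert all(".default." not in name for name in trainable_params(model))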
