Conversation

BenjaminBossan (Member) commented Sep 5, 2023

This is an alternative to #900, resolves #899.

Thanks @passaglia for figuring out the underlying issue.

Description

Currently, we don't handle setting requires_grad on adapter layers very well. The main issue is that it can be set to True on adapter parameters that are not actually used, e.g. the original_module in ModulesToSaveWrapper or the inactive adapters in LoRA.

Normally, this is not a big issue, except perhaps when we want to correctly count the number of trainable parameters. However, when training with DistributedDataParallel (DDP), it leads to errors: PyTorch expects every parameter with requires_grad=True to participate in the loss computation, but the parameters mentioned above don't. For that reason, training with DDP currently fails when using modules_to_save or multiple adapters.
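
To illustrate, here is a minimal sketch of the situation; the model, task type, and adapter names below are arbitrary choices for the example and are not taken from this PR's code:

```python
# A model with modules_to_save plus a second adapter. Before this fix,
# parameters that never take part in the forward pass (the wrapped original
# classifier, the weights of the inactive adapter) could still end up with
# requires_grad=True.
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
config = LoraConfig(task_type="SEQ_CLS", modules_to_save=["classifier"])
model = get_peft_model(base, config)  # the "default" adapter is active
model.add_adapter("other", LoraConfig(task_type="SEQ_CLS", modules_to_save=["classifier"]))

# Ideally, only the default adapter's LoRA weights and its copy of the
# classifier should show up here, not the wrapped original_module and not the
# inactive "other" adapter.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

# DDP expects every parameter with requires_grad=True to receive a gradient
# during backward; the unused parameters above are what makes DDP error out.
```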

Implementation

This turned out to be more complicated than I initially thought. The logic for setting requires_grad is spread all over the place, so it was hard to encapsulate, and I only partially succeeded. As is, this PR is more complex than the one it tries to supersede, #900, but it is also "more correct".

Tests were added to check whether requires_grad is set correctly. There are (so far) no tests for whether DDP itself works; those could be added once multi-GPU tests are available. I did, however, test an early stage of this PR with DDP, and setting requires_grad correctly does indeed fix the DDP error.
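
Roughly, the property being checked looks like this (a minimal sketch only; the actual tests at the bottom of test_custom.py differ in detail, and the tiny model and adapter names here are made up for illustration):

```python
import torch

from peft import LoraConfig, get_peft_model


class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = torch.nn.Linear(10, 10)
        self.lin1 = torch.nn.Linear(10, 2)

    def forward(self, x):
        return self.lin1(torch.relu(self.lin0(x)))


model = get_peft_model(MLP(), LoraConfig(target_modules=["lin0"]), adapter_name="adapter0")
model.add_adapter("adapter1", LoraConfig(target_modules=["lin0"]))

# After switching adapters, only the active adapter's LoRA parameters should
# require gradients; the inactive adapter's weights must be frozen.
model.set_adapter("adapter1")
for name, param in model.named_parameters():
    if "lora_" in name:
        assert param.requires_grad == ("adapter1" in name), name
```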

DONE/TODO

- [x] ModulesToSaveWrapper
- [x] LoRA
- [x] IA³
- [x] AdaLora

HuggingFaceDocBuilderDev commented Sep 5, 2023

The documentation is not available anymore as the PR was closed or merged.

BenjaminBossan marked this pull request as ready for review September 12, 2023 15:16
pacman100 (Contributor) left a comment

Thank you @BenjaminBossan for fixing this major bug when using DDP/Multiple Adapters with PEFT. LGTM! 🤗

younesbelkada (Contributor) left a comment

Thanks a mile @BenjaminBossan!

pacman100 merged commit 634bd19 into huggingface:main Sep 26, 2023
BenjaminBossan deleted the fix-setting-requires-grad branch September 26, 2023 07:58
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
* [WIP] Fix setting requires_grad on adapter layers

* Refactor: move more requires_grad machinery to ABC

* [skip ci] [WIP] Add requires_grad logic to IA³

* Add AdaLora

* Fix some minor issues

* Make style

Development

Successfully merging this pull request may close these issues.

Modules_to_save can't be used with DDP and gradient checkpointing
