
[LoRA] add quick_lora#8106

Merged
wawltor merged 12 commits into PaddlePaddle:develop from JunnYu:add_quick_lora
Mar 25, 2024
Conversation

@JunnYu (Member) commented Mar 13, 2024

PR types

New features

PR changes

APIs

Description

  • Optimize LoRA's forward and backward computation.
  • With non-dynamic padding and short sequences, the speedup is about 3%.

Known limitations:

  • lora dropout must be set to 0
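For context, the standard LoRA forward computes the frozen base projection and the scaled low-rank update as separate matmuls; a "quick" variant can reorganize these operations into fewer, larger kernels. A minimal NumPy sketch of the underlying math only (the function names and shapes here are illustrative, not this PR's actual API):

```python
import numpy as np

def lora_forward(x, W, A, B, scaling):
    """Standard LoRA forward: frozen base projection plus scaled low-rank update."""
    return x @ W + scaling * (x @ A) @ B

def lora_forward_merged(x, W, A, B, scaling):
    """Mathematically equivalent forward using a merged weight W + scaling * A @ B."""
    return x @ (W + scaling * A @ B)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16))   # (batch, in_features)
W = rng.standard_normal((16, 32))  # frozen base weight
A = rng.standard_normal((16, 8))   # lora_A, rank r = 8
B = rng.standard_normal((8, 32))   # lora_B
scaling = 2.0                      # lora_alpha / r

# Both formulations produce the same output (up to float rounding).
assert np.allclose(lora_forward(x, W, A, B, scaling),
                   lora_forward_merged(x, W, A, B, scaling))
```

Note that merging is cheaper per step only if the merged weight can be kept around; the memory trade-off of doing so is discussed in the review thread below.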

paddle-bot (Bot) commented Mar 13, 2024

Thanks for your contribution!

@JunnYu JunnYu requested review from gongel and lugimzzz March 13, 2024 03:14
codecov (Bot) commented Mar 13, 2024

Codecov Report

Attention: Patch coverage is 29.93197%, with 103 lines in your changes missing coverage. Please review.

Project coverage is 55.41%. Comparing base (f005084) to head (749b419).

Files                                    | Patch % | Lines
-----------------------------------------|---------|-------------
paddlenlp/peft/lora/lora_quick_layers.py | 28.12%  | 69 Missing ⚠️
paddlenlp/peft/lora/lora_layers.py       | 28.88%  | 32 Missing ⚠️
paddlenlp/peft/lora/lora_config.py       | 66.66%  | 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8106      +/-   ##
===========================================
- Coverage    55.44%   55.41%   -0.03%     
===========================================
  Files          596      597       +1     
  Lines        91464    91587     +123     
===========================================
+ Hits         50713    50754      +41     
- Misses       40751    40833      +82     


gongel
gongel previously approved these changes Mar 14, 2024
gongel (Member) commented Mar 19, 2024

There is a merge conflict.

Comment thread paddlenlp/peft/lora/lora_quick_layers.py Outdated
input_grad = None

if not input.stop_gradient:
input_grad = paddle.addmm(
Contributor

For computing input_grad, could we consider using merged_weight instead: input_grad = paddle.matmul(grad_output, merged_weight, transpose_y=True)?

Member Author

merged_weight cannot be reused from the forward pass; caching it would consume a large amount of GPU memory.
Also, if the gradient were computed in merged form, lora_B_input_grad = paddle.matmul(grad_output, lora_B, transpose_y=True) could no longer be reused and would have to be recomputed.
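The trade-off discussed in this thread can be checked numerically: the merged weight gives the same input_grad, but the factored form lets the intermediate grad_output @ lora_B.T be computed once and reused for lora_A's gradient. A NumPy sketch of the math, with variable names following the thread (this is an illustration under assumed shapes, not the PR's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16))            # forward input
W = rng.standard_normal((16, 32))           # frozen base weight
lora_A = rng.standard_normal((16, 8))       # rank r = 8
lora_B = rng.standard_normal((8, 32))
scaling = 2.0
grad_output = rng.standard_normal((4, 32))  # gradient w.r.t. the layer output

# Factored backward: compute grad_output @ lora_B.T once and reuse it for
# both input_grad and lora_A's gradient (the reuse argued for in the thread).
lora_B_input_grad = grad_output @ lora_B.T  # (4, 8), computed once
input_grad = grad_output @ W.T + scaling * (lora_B_input_grad @ lora_A.T)
lora_A_grad = scaling * (x.T @ lora_B_input_grad)

# Merged alternative: a single matmul for input_grad, but merged_weight must
# be rebuilt here (or cached from the forward at significant memory cost).
merged_weight = W + scaling * lora_A @ lora_B
input_grad_merged = grad_output @ merged_weight.T

# Both strategies yield the same input gradient (up to float rounding).
assert np.allclose(input_grad, input_grad_merged)
```

The difference is therefore not correctness but which intermediates exist to be reused, which is why the author keeps the factored form.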

@JunnYu JunnYu requested review from gongel and lugimzzz March 21, 2024 09:03
lugimzzz (Contributor) left a comment

lgtm

@wawltor wawltor merged commit d577e19 into PaddlePaddle:develop Mar 25, 2024
4 participants