add Moslora #9331

Merged

lugimzzz merged 5 commits into PaddlePaddle:develop from TranscenderNing:moslora on Nov 22, 2024

Conversation

@TranscenderNing (Contributor)

PR types

New features

PR changes

Add MosLoRA at peft/lora

Description

Add the MosLoRA method (https://arxiv.org/pdf/2406.11909).
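
For context, MosLoRA (Mixture-of-Subspaces LoRA, https://arxiv.org/pdf/2406.11909) places a small trainable r x r mixer matrix between the two LoRA factors, so the low-rank update becomes A @ mixer @ B instead of A @ B. The following is a minimal sketch of the idea in Paddle, not the PR's actual implementation; MosLoRALinearSketch is an illustrative name, and the lora_AB attribute mirrors the state dict key that appears later in this review:

import paddle.nn as nn

class MosLoRALinearSketch(nn.Layer):
    # Illustrative sketch only; the real layer added by this PR lives under peft/lora.
    def __init__(self, in_features, out_features, r=8, lora_alpha=16):
        super().__init__()
        self.weight = self.create_parameter(shape=[in_features, out_features])
        self.lora_A = self.create_parameter(shape=[in_features, r])
        # The r x r mixer is what distinguishes MosLoRA from vanilla LoRA.
        self.lora_AB = self.create_parameter(shape=[r, r])
        # B starts at zero so the adapter is a no-op at initialization.
        self.lora_B = self.create_parameter(
            shape=[r, out_features],
            default_initializer=nn.initializer.Constant(0.0),
        )
        self.scaling = lora_alpha / r

    def forward(self, x):
        base = x @ self.weight
        # MosLoRA update: (x @ A) @ mixer @ B, scaled as in vanilla LoRA.
        delta = (x @ self.lora_A) @ self.lora_AB @ self.lora_B * self.scaling
        return base + delta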


paddle-bot Bot commented Oct 29, 2024

Thanks for your contribution!


codecov Bot commented Oct 29, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 53.10%. Comparing base (6813e40) to head (89fc2cc).
Report is 223 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9331      +/-   ##
===========================================
+ Coverage    52.91%   53.10%   +0.19%     
===========================================
  Files          679      685       +6     
  Lines       108433   108855     +422     
===========================================
+ Hits         57378    57810     +432     
+ Misses       51055    51045      -10     


Comment thread: llm/tools/merge_lora_params.py (Outdated)
return

weight = state_dict.pop(name + ".weight")
lora_use_mixer = (lora_state_dict is not None and name + ".lora_AB" in lora_state_dict) or (
Contributor:

Suggestion: read lora_use_mixer directly from lora_config instead of inferring it from the state dict keys.

Contributor (Author):

resolved
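
A minimal sketch of the resolved check in llm/tools/merge_lora_params.py, assuming the script already loads a lora_config object:

# Before (inferred from state dict keys, as in the diff above):
# lora_use_mixer = lora_state_dict is not None and name + ".lora_AB" in lora_state_dict
# After (read directly from the config, per the review comment):
lora_use_mixer = lora_config.lora_use_mixer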

Comment thread: llm/tools/merge_lora_params.py (Outdated)
lora_AB = lora_AB.astype("float32")
out = (weight + lora_A @ lora_AB @ lora_B * scaling).astype("bfloat16")
else:
out = (weight + lora_A @ lora_B * scaling).astype("bfloat16")
Contributor:

Change the astype here to the dtype from lora_config; the original hard-coded cast was wrong.

Contributor (Author):

resolved
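
A minimal sketch of the corrected merge, assuming lora_config.dtype holds the target dtype as the review comment indicates; merge_weight_sketch is an illustrative helper name, and the variable names follow the diff above:

def merge_weight_sketch(weight, lora_A, lora_B, scaling, lora_config, lora_AB=None):
    # Compute in float32 for numerical stability, then cast back to the
    # dtype recorded in lora_config instead of a hard-coded bfloat16.
    weight = weight.astype("float32")
    lora_A = lora_A.astype("float32")
    lora_B = lora_B.astype("float32")
    if lora_AB is not None:
        # MosLoRA path: fold A @ mixer @ B into the base weight.
        lora_AB = lora_AB.astype("float32")
        out = weight + lora_A @ lora_AB @ lora_B * scaling
    else:
        # Vanilla LoRA path: fold A @ B into the base weight.
        out = weight + lora_A @ lora_B * scaling
    return out.astype(lora_config.dtype)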

Comment thread: llm/utils/argument.py
lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})
lora_use_mixer: bool = field(
default=False, metadata={"help": "Whether to use MosLoRA: https://arxiv.org/pdf/2406.11909"}
Contributor:

Please also update the documentation while you're at it.

Contributor (Author):

resolved
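
For the documentation update, a usage example along these lines would fit; it assumes lora_use_mixer is exposed on LoRAConfig as the argument definition above suggests, and base_model stands in for any previously loaded PaddleNLP model:

from paddlenlp.peft import LoRAConfig, LoRAModel

lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],
    r=8,
    lora_alpha=16,
    lora_use_mixer=True,  # enable MosLoRA (field added in this PR)
)
# base_model: any paddlenlp.transformers model loaded beforehand (placeholder).
model = LoRAModel(base_model, lora_config)
model.mark_only_lora_as_trainable()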

Comment thread:
pissa=lora_config.pissa,
bias_attr=False if module.bias is None else None,
use_quick_lora=lora_config.use_quick_lora,
lora_use_mixer=lora_config.lora_use_mixer,
Contributor:

In LoRAModel's __init__, add: if (tensor_parallel_degree > 1 or pipeline_parallel_degree > 1) and lora_config.lora_use_mixer: raise NotImplementedError("xxx")

Contributor (Author):

resolved
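
A sketch of the requested guard, with NotImplementError corrected to Python's NotImplementedError; how the parallel degrees are obtained inside LoRAModel.__init__ is an assumption, and the message text is illustrative:

# Inside LoRAModel.__init__, once the parallel degrees are known:
if (tensor_parallel_degree > 1 or pipeline_parallel_degree > 1) and lora_config.lora_use_mixer:
    raise NotImplementedError(
        "MosLoRA (lora_use_mixer=True) does not yet support tensor or pipeline parallelism."
    )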

@lugimzzz (Contributor) left a comment:

LGTM

@lugimzzz lugimzzz merged commit 183e012 into PaddlePaddle:develop Nov 22, 2024
