Skip to content

[llm]support dpo pp#9039

Merged
ZHUI merged 10 commits into
PaddlePaddle:developfrom
lugimzzz:dpo
Sep 20, 2024
Merged

[llm]support dpo pp#9039
ZHUI merged 10 commits into
PaddlePaddle:developfrom
lugimzzz:dpo

Conversation

@lugimzzz
Copy link
Copy Markdown
Contributor

@lugimzzz lugimzzz commented Aug 28, 2024

PR types

New features

PR changes

APIs

Description

  1. 重构DPOTrainer与原版逐位对齐loss、metric
  2. 重构DPOTrainer能够支持pp & vpp
  3. 支持LoRA和多种DPO变体(包含KTO、DPO、ORPO、Simpo)
  4. 新增支持多个开源模型

@paddle-bot
Copy link
Copy Markdown

paddle-bot Bot commented Aug 28, 2024

Thanks for your contribution!

@lugimzzz lugimzzz changed the title [llm]support dpo/kto pp WIP [llm]support dpo/kto pp Aug 28, 2024
@codecov
Copy link
Copy Markdown

codecov Bot commented Aug 28, 2024

Codecov Report

Attention: Patch coverage is 9.36281% with 697 lines in your changes missing coverage. Please review.

Project coverage is 53.07%. Comparing base (90cef20) to head (69ad7cf).
Report is 244 commits behind head on develop.

Files with missing lines Patch % Lines
paddlenlp/transformers/tensor_parallel_utils.py 5.55% 255 Missing ⚠️
paddlenlp/trl/dpo_trainer.py 5.72% 214 Missing ⚠️
paddlenlp/trl/dpo_criterion.py 9.09% 140 Missing ⚠️
paddlenlp/transformers/sequence_parallel_utils.py 17.50% 66 Missing ⚠️
paddlenlp/utils/infohub.py 26.66% 22 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9039      +/-   ##
===========================================
- Coverage    53.26%   53.07%   -0.20%     
===========================================
  Files          652      656       +4     
  Lines       105607   106095     +488     
===========================================
+ Hits         56254    56309      +55     
- Misses       49353    49786     +433     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment thread paddlenlp/transformers/sequence_parallel_utils.py
@lugimzzz lugimzzz changed the title WIP [llm]support dpo/kto pp WIP [llm]support dpo pp Aug 30, 2024
Comment thread paddlenlp/trl/dpo_criterion.py Outdated
Comment thread paddlenlp/trl/dpo_criterion.py
Comment thread paddlenlp/trl/dpo_criterion.py Outdated
Comment thread paddlenlp/trl/dpo_criterion.py
Comment thread paddlenlp/trl/dpo_criterion.py
Comment thread paddlenlp/trl/dpo_trainer.py
Comment thread paddlenlp/trl/dpo_trainer.py
Comment on lines +279 to +301
for key in batch.keys():
if key not in "response_indexs":
concatenated_inputs[key] = [
batch[key][i * per_device_train_batch_size : (i + 1) * per_device_train_batch_size]
for i in range(gradient_accumulation_steps)
]
else:
concatenated_inputs["response_indexs"] = [[] for _ in range(gradient_accumulation_steps)]
for i in range(gradient_accumulation_steps):
for response_index in batch[key]:
if response_index[0] in list(
range(i * per_device_train_batch_size, (i + 1) * per_device_train_batch_size)
):
response_index[0] -= i * per_device_train_batch_size
concatenated_inputs["response_indexs"][i].append(response_index)
concatenated_inputs["response_indexs"][i] = paddle.stack(concatenated_inputs["response_indexs"][i])
if model._layers.config.use_sparse_head_and_loss_fn:
last_batch_response_length = concatenated_inputs["response_indexs"][i][0, 1]
concatenated_inputs["response_indexs"][i][:, 1:] -= last_batch_response_length

concatenated_inputs["reference_chosen_logps"] = None
concatenated_inputs["reference_rejected_logps"] = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议这一大堆,封装成一个函数。

Comment thread paddlenlp/trl/dpo_trainer.py Outdated
Comment thread paddlenlp/trl/dpo_trainer.py
@lugimzzz lugimzzz changed the title WIP [llm]support dpo pp [llm]support dpo pp Sep 13, 2024
Comment thread llm/config/qwen/AdvertiseGen/w8a8_ptq_argument.json
Comment thread llm/run_finetune.py
Comment thread paddlenlp/transformers/gemma/modeling.py
Comment thread paddlenlp/transformers/yuan/configuration.py
@CLAassistant
Copy link
Copy Markdown

CLAassistant commented Sep 20, 2024

CLA assistant check
All committers have signed the CLA.

get_last_checkpoint,
set_seed,
from dpo_argument import (
DPOConfig,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DPOConfig,DPOTrainingArguments 这些看是否加到主repo?

Copy link
Copy Markdown
Contributor

@ZHUI ZHUI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ZHUI ZHUI merged commit bc55104 into PaddlePaddle:develop Sep 20, 2024
@lugimzzz lugimzzz deleted the dpo branch December 16, 2024 11:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants