
Commit a1b067e

Using the refactored data processing flow in PaddleFormers (#1393)

1 parent b55e437

25 files changed: +173 -50 lines

docs/datasets.md (2 additions & 2 deletions)

@@ -315,9 +315,9 @@ Here is a multi-image example of SFT VL dataset:
 }
 ```

-## chatml Format
+## messages Format

-The chatml Format is used for training thinking models and function call training:
+The messages Format is used for training thinking models and function call training:

 Demo data for thinking models:
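The demo data itself is not included in this diff. For orientation, here is a minimal sketch of what a messages-format record conventionally looks like; the `role`/`content` layout and the `<think>` tag are assumptions based on common chat-data conventions, not taken from docs/datasets.md:

```
# A hypothetical messages-format record for a thinking model; field names
# follow the common chat-messages convention and are not confirmed by this diff.
example = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "<think>2 + 2 = 4.</think>The answer is 4."},
    ]
}
```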

ernie/dataset/text_sft_reader/finetuning.py (2 additions & 1 deletion)

@@ -33,7 +33,8 @@
     sampling_pseudo_examples,
     sampling_pseudo_examples_fc,
 )
-from paddleformers.datasets.data_utils import pad_batch_data, round_up_to_multiple_of_8
+from paddleformers.datasets.collate import pad_batch_data
+from paddleformers.datasets.data_utils import round_up_to_multiple_of_8

 logger = logging.getLogger(__name__)
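Taken together, the import changes in this commit trace one consistent migration of PaddleFormers' data-processing entry points: collate helpers move to `paddleformers.datasets.collate`, dataset constructors move to `paddleformers.datasets.loader`. The mapping below is compiled only from the hunks shown in this diff:

```
# Old import path                                          -> new import path
# paddleformers.datasets.data_utils.pad_batch_data         -> paddleformers.datasets.collate.pad_batch_data
# paddleformers.datasets.finetuning.collate_fn             -> paddleformers.datasets.collate.collate_fn
# paddleformers.datasets.dpo.collate_fn                    -> paddleformers.datasets.collate.dpo_collate_fn
# paddleformers.datasets.finetuning.create_dataset         -> paddleformers.datasets.loader.create_dataset
# paddleformers.datasets.finetuning.create_indexed_dataset -> paddleformers.datasets.loader.create_indexed_dataset
# paddleformers.datasets.dpo.create_dataset                -> paddleformers.datasets.loader.create_dataset
from paddleformers.datasets.collate import collate_fn, dpo_collate_fn, pad_batch_data
from paddleformers.datasets.loader import create_dataset, create_indexed_dataset
```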

ernie/fusion_ops/common_fusion_ops.py (10 additions & 4 deletions)

@@ -71,11 +71,12 @@ def _fusion_flash_attention(
     """

     if attn_mask_startend_row_indices is not None:
-        if attn_mask_startend_row_indices.ndim == 3:
-            attn_mask_startend_row_indices = attn_mask_startend_row_indices.unsqueeze(
-                -1
-            )
         if use_sparse_flash_attn:
+            # attn_mask_startend_row_indices.ndim must be 4
+            if attn_mask_startend_row_indices.ndim == 3:
+                attn_mask_startend_row_indices = (
+                    attn_mask_startend_row_indices.unsqueeze(-1)
+                )
             if rr_flash_attn is None:
                 out = flashmask_attention(
                     q,

@@ -94,6 +95,11 @@ def _fusion_flash_attention(
                     causal=True,
                 )
         else:
+            # attn_mask_startend_row_indices.ndim must be 3
+            if attn_mask_startend_row_indices.ndim == 4:
+                attn_mask_startend_row_indices = attn_mask_startend_row_indices.squeeze(
+                    -1
+                )
             attention_mask = _gen_from_sparse_attn_mask_indices(
                 attn_mask_startend_row_indices, q.dtype
             )
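The net effect of these two hunks is that rank normalization now happens per branch: the sparse flashmask path consumes 4-D `attn_mask_startend_row_indices`, while the dense fallback consumes 3-D indices. A runnable sketch of just that shape logic (the helper name is illustrative, not from the repo):

```
import paddle

def normalize_startend_row_indices(indices, use_sparse_flash_attn):
    """Normalize attn_mask_startend_row_indices to the rank each path expects."""
    if use_sparse_flash_attn:
        # sparse flashmask path: indices must be 4-D, so add a trailing axis
        if indices.ndim == 3:
            indices = indices.unsqueeze(-1)
    else:
        # dense mask-generation path: indices must be 3-D, so drop the trailing axis
        if indices.ndim == 4:
            indices = indices.squeeze(-1)
    return indices

# Illustrative shapes only; the real tensors come from the attention caller.
idx = paddle.zeros([2, 8, 128], dtype="int32")
print(normalize_startend_row_indices(idx, True).shape)   # [2, 8, 128, 1]
print(normalize_startend_row_indices(idx, False).shape)  # [2, 8, 128]
```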

erniekit/eval/eval.py (29 additions & 4 deletions)

@@ -36,6 +36,7 @@
 from paddleformers.trainer.trainer_utils import ShardingOption
 from paddleformers.utils.log import logger
 from paddleformers import __version__ as paddleformers_version
+from paddleformers.datasets.template.template import get_template_and_fix_tokenizer

 from ernie.configuration import Ernie4_5_MoeConfig
 from ernie.modeling_moe import Ernie4_5_MoeForCausalLM

@@ -418,15 +419,39 @@ def run_eval(args: Optional[dict[str, Any]] = None) -> None:
         "encode_one_turn": data_args.encode_one_turn,
         "use_template": data_args.use_template,
         "is_pretraining": True if model_args.stage.lower() == "pt" else False,
+        "truncate_packing": data_args.truncate_packing,
+        "stage": model_args.stage,
+        "is_valid": False,
+        "template_backend": data_args.template_backend,
+        "split_multi_turn": data_args.split_multi_turn,
     }
-    from paddleformers.datasets.finetuning import collate_fn
+    dataset_config.update(
+        {
+            "template": data_args.template,
+            "train_on_prompt": False,
+            "tool_format": None,
+            "default_system": None,
+            "enable_thinking": True,
+        }
+    )
+
+    if dataset_config["template_backend"] == "custom":
+        template_instance = get_template_and_fix_tokenizer(dataset_config)
+    else:
+        template_instance = None
+    dataset_config.update(
+        {
+            "template_instance": template_instance,
+        }
+    )
+    from paddleformers.datasets.collate import collate_fn

     if data_args.dataset_type == "map":
-        from paddleformers.datasets.finetuning import (
+        from paddleformers.datasets.loader import (
             create_indexed_dataset as create_dataset,
         )
     else:
-        from paddleformers.datasets.finetuning import create_dataset
+        from paddleformers.datasets.loader import create_dataset
     dataset_config.update(
         {
             "num_samples_each_epoch": data_args.num_samples_each_epoch,

@@ -440,11 +465,11 @@ def run_eval(args: Optional[dict[str, Any]] = None) -> None:
         eval_file_path = os.path.join(data_args.offline_dataset_path, "eval")
         eval_dataset = create_dataset(data_file_prefix=eval_file_path)
     else:
+        dataset_config["is_valid"] = True
         eval_dataset = create_dataset(
             task_group=data_args.eval_dataset_path,
             task_group_prob=data_args.eval_dataset_prob,
             sub_dataset_type=data_args.eval_dataset_type,
-            is_valid=True,
             **dataset_config,
         )
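The same template plumbing recurs in erniekit/train/dpo/workflow.py and erniekit/train/sft/workflow.py below. A condensed, runnable sketch of the branch; `get_template_and_fix_tokenizer` is stubbed here (the real one lives in `paddleformers.datasets.template.template`) and the config values are illustrative:

```
def get_template_and_fix_tokenizer(config):  # stub for illustration only
    return f"template<{config['template']}>"

dataset_config = {
    "template": "my-template",     # illustrative value
    "train_on_prompt": False,
    "tool_format": None,
    "default_system": None,
    "enable_thinking": True,
    "template_backend": "custom",  # "jinja" skips template construction
}

# A template instance is materialized only for the custom backend; the jinja
# backend relies on the tokenizer's apply_chat_template instead.
if dataset_config["template_backend"] == "custom":
    template_instance = get_template_and_fix_tokenizer(dataset_config)
else:
    template_instance = None
dataset_config["template_instance"] = template_instance

print(dataset_config["template_instance"])  # template<my-template>
```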

erniekit/hparams/data_args.py (22 additions & 0 deletions)

@@ -146,3 +146,25 @@ class DataArguments:
         default=True,
         metadata={"help": "Whether to use cls to predict RM score."},
     )
+    truncate_packing: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to truncate data in packing (only valid in pretrain online dataflow)."
+        },
+    )
+    template: str = field(
+        default=None,
+        metadata={"help": "The chat template used in training."},
+    )
+    split_multi_turn: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to split multi-round dialogues into multiple pieces of data for training"
+        },
+    )
+    template_backend: str = field(
+        default="jinja",
+        metadata={
+            "help": "jinja means using apply_chat_template, custom means using a custom template"
+        },
+    )
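These four knobs surface in run configs like any other DataArguments field. A sketch of setting them programmatically (assumes the remaining DataArguments fields all have defaults; the values shown are the defaults from the diff):

```
from erniekit.hparams.data_args import DataArguments  # module path inferred from the diff

data_args = DataArguments(
    truncate_packing=True,     # only honored by the pretrain online dataflow
    template=None,             # name of the chat template used in training
    split_multi_turn=False,    # split multi-turn dialogues into separate samples
    template_backend="jinja",  # "jinja" -> apply_chat_template; "custom" -> custom template
)
```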

erniekit/hparams/model_args.py (4 additions & 0 deletions)

@@ -95,6 +95,10 @@ class ModelArguments:
             "help": "Under use attn_mask_startend_row_indices=True, whether use sparse flash attention or not."
         },
     )
+    use_global_causal_attn: bool = field(
+        default=False,
+        metadata={"help": "Whether to use global causal attention in packing data"},
+    )
     use_sparse_head_and_loss_fn: bool = field(
         default=False,
         metadata={"help": "Whether to use sparse LM Head and loss function."},

erniekit/train/dpo/dpo_estimate_training.py (1 addition & 1 deletion)

@@ -24,7 +24,7 @@
 # isort: off
 # fmt: off
 # isort: on
-from paddleformers.datasets.dpo import create_dataset
+from paddleformers.datasets.loader import create_dataset


 def calculate_acc_steps(num_samples, train_batch, dataset_world_size, per_device_train_batch_size):

erniekit/train/dpo/workflow.py (26 additions & 2 deletions)

@@ -46,10 +46,12 @@
 from paddleformers.trainer.trainer_utils import ShardingOption
 from paddleformers.utils.log import logger
 from paddleformers import __version__ as paddleformers_version
+from paddleformers.datasets.template.template import get_template_and_fix_tokenizer

 from ernie.callbacks import LayerwiseDropoutCallback
 from ernie.configuration import Ernie4_5_MoeConfig
-from paddleformers.datasets.dpo import collate_fn, create_dataset
+from paddleformers.datasets.collate import dpo_collate_fn as collate_fn
+from paddleformers.datasets.loader import create_dataset
 from ernie.modeling_moe import Ernie4_5_MoeForCausalLM
 from ernie.modeling_moe_pp import Ernie4_5_MoeForCausalLMPipe
 from ernie.tokenizer import Ernie4_5_Tokenizer

@@ -498,7 +500,29 @@ def run_dpo(
         "packing": data_args.packing,
         "mix_strategy": data_args.mix_strategy,
         "encode_one_turn": data_args.encode_one_turn,
+        "stage": model_args.stage,
+        "is_valid": False,
+        "template_backend": data_args.template_backend,
     }
+    dataset_config.update(
+        {
+            "template": data_args.template,
+            "train_on_prompt": False,
+            "tool_format": None,
+            "default_system": None,
+            "enable_thinking": True,
+        }
+    )
+
+    if dataset_config["template_backend"] == "custom":
+        template_instance = get_template_and_fix_tokenizer(dataset_config)
+    else:
+        template_instance = None
+    dataset_config.update(
+        {
+            "template_instance": template_instance,
+        }
+    )

     if finetuning_args.max_steps == -1:
         if data_args.mix_strategy == "random":

@@ -549,11 +573,11 @@ def run_dpo(
         )

     if finetuning_args.do_eval and finetuning_args.should_load_dataset:
+        dataset_config["is_valid"] = True
         eval_dataset = create_dataset(
             task_group=data_args.eval_dataset_path,
             task_group_prob=data_args.eval_dataset_prob,
             sub_dataset_type=data_args.eval_dataset_type,
-            is_valid=True,
             **dataset_config,
         )
     logger.info("Creating dataset successfully ...")
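Note the pattern in the last hunk, which also appears in eval.py above and sft/workflow.py below: `is_valid` used to be passed as a keyword argument to `create_dataset` and now travels inside the shared `dataset_config`. A runnable sketch with `create_dataset` stubbed (the real one is `paddleformers.datasets.loader.create_dataset`):

```
def create_dataset(**kwargs):  # stub for illustration only
    return kwargs

dataset_config = {"stage": "DPO", "is_valid": False}  # illustrative values

# Before this commit: create_dataset(..., is_valid=True).
# After it: the flag is flipped on the shared config before the call.
dataset_config["is_valid"] = True
eval_dataset = create_dataset(
    task_group="./examples/data/eval.jsonl",  # illustrative path
    **dataset_config,
)
print(eval_dataset["is_valid"])  # True
```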

erniekit/train/sft/workflow.py (31 additions & 4 deletions)

@@ -47,6 +47,7 @@
 )
 from paddleformers.trainer.trainer_utils import ShardingOption
 from paddleformers.transformers.model_utils import unwrap_model
+from paddleformers.datasets.template.template import get_template_and_fix_tokenizer
 from paddleformers.data.causal_dataset import (
     build_train_valid_test_datasets,
     check_data_split,

@@ -529,15 +530,41 @@ def run_sft(
         "encode_one_turn": data_args.encode_one_turn,
         "use_template": data_args.use_template,
         "is_pretraining": True if model_args.stage.lower() == "pt" else False,
+        "truncate_packing": data_args.truncate_packing,
+        "stage": model_args.stage,
+        "is_valid": False,
+        "template_backend": data_args.template_backend,
+        "split_multi_turn": data_args.split_multi_turn,
     }
-    from paddleformers.datasets.finetuning import collate_fn
+
+    dataset_config.update(
+        {
+            "template": data_args.template,
+            "train_on_prompt": False,
+            "tool_format": None,
+            "default_system": None,
+            "enable_thinking": True,
+        }
+    )
+
+    if dataset_config["template_backend"] == "custom":
+        template_instance = get_template_and_fix_tokenizer(dataset_config)
+    else:
+        template_instance = None
+    dataset_config.update(
+        {
+            "template_instance": template_instance,
+        }
+    )
+
+    from paddleformers.datasets.collate import collate_fn

     if data_args.dataset_type == "map":
-        from paddleformers.datasets.finetuning import (
+        from paddleformers.datasets.loader import (
             create_indexed_dataset as create_dataset,
         )
     else:
-        from paddleformers.datasets.finetuning import create_dataset
+        from paddleformers.datasets.loader import create_dataset
     dataset_config.update(
         {
             "num_samples_each_epoch": data_args.num_samples_each_epoch,

@@ -570,11 +597,11 @@ def run_sft(
         eval_file_path = os.path.join(data_args.offline_dataset_path, "eval")
         eval_dataset = create_dataset(data_file_prefix=eval_file_path)
     else:
+        dataset_config["is_valid"] = True
         eval_dataset = create_dataset(
             task_group=data_args.eval_dataset_path,
             task_group_prob=data_args.eval_dataset_prob,
             sub_dataset_type=data_args.eval_dataset_type,
-            is_valid=True,
             **dataset_config,
        )
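One more detail worth calling out in this file: the `dataset_type` switch now chooses between two constructors from `paddleformers.datasets.loader`. A runnable sketch of that selection with both loaders stubbed (the real signatures are not shown in this diff):

```
# Stubs for illustration; the real functions live in paddleformers.datasets.loader.
def create_indexed_dataset(**kwargs):
    return ("indexed", kwargs)

def create_iterable_dataset(**kwargs):
    return ("iterable", kwargs)

def pick_loader(dataset_type):
    # "map"-style datasets get the indexed loader; everything else streams.
    return create_indexed_dataset if dataset_type == "map" else create_iterable_dataset

loader = pick_loader("map")
print(loader(data_file_prefix="./offline/train"))  # ("indexed", {...})
```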

examples/configs/ERNIE-4.5-21B-A3B-Thinking/fc/run_fc_8k.yaml (3 additions & 2 deletions)

@@ -1,12 +1,13 @@
 ### data
-train_dataset_type: "chatml"
-eval_dataset_type: "chatml"
+train_dataset_type: "messages"
+eval_dataset_type: "messages"
 train_dataset_path: "./examples/data/function-call-train.jsonl"
 train_dataset_prob: "1.0"
 eval_dataset_path: "./examples/data/function-call-eval.jsonl"
 eval_dataset_prob: "1.0"
 max_seq_len: 8192
 num_samples_each_epoch: 6000000
+split_multi_turn: True

 ### model
 model_name_or_path: baidu/ERNIE-4.5-21B-A3B-Thinking
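The new `split_multi_turn: True` line enables the DataArguments flag added above. A hedged sketch of the idea, splitting one multi-turn dialogue into one training piece per assistant turn (illustrative logic only; the real splitting lives in the PaddleFormers dataflow):

```
def split_multi_turn(messages):
    """Yield one training piece per assistant turn, each with the full history.

    Illustrative only; not the PaddleFormers implementation.
    """
    for i, msg in enumerate(messages):
        if msg["role"] == "assistant":
            yield {"messages": messages[: i + 1]}

dialogue = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Book a flight"},
    {"role": "assistant", "content": "Which city?"},
]
pieces = list(split_multi_turn(dialogue))
print(len(pieces))  # 2
```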
