
Commit 9350255

[LLM] add llama1-13b pretrain
[LLM] llama1-7b pretrain with callback
1 parent 45c4220 commit 9350255

24 files changed: 580 additions & 616 deletions


Lines changed: 0 additions & 1 deletion

```diff
@@ -1,4 +1,3 @@
 from .base import Driver
-from .callback_paddle import PaddleCallback
 from .event import Event
 from .log_event import LogEventManager
```

training/benchmarks/driver/callback_paddle.py

Lines changed: 0 additions & 92 deletions
This file was deleted.
Lines changed: 92 additions & 82 deletions

```diff
@@ -1,82 +1,24 @@
 import os
 from contextlib import contextmanager
-import random
-import numpy as np
+
 import paddle
 import paddle.distributed as dist
-
+from paddlenlp.trainer import (
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+    TrainingArguments,
+)
+from paddlenlp.trainer.trainer_utils import IntervalStrategy
+
+from .base import Driver
+from .event import Event
+from typing import Dict
 
 def barrier():
     if dist.is_initialized():
         dist.barrier()
 
-def set_seed(args):
-    if args.device == "cpu":
-        idx = 0
-    else:
-        idx = paddle.distributed.get_rank()
-    random.seed(args.seed + idx)
-    np.random.seed(args.seed + idx)
-    paddle.seed(args.seed + idx)
-
-
-def get_rank(default=0):
-    """
-    Gets distributed rank or returns zero if distributed is not initialized.
-    """
-    if dist.is_initialized():
-        rank = dist.get_rank()
-    else:
-        rank = default
-    return rank
-
-
-def get_world_size():
-    """
-    Gets total number of distributed workers or returns one if distributed is
-    not initialized.
-    """
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-    else:
-        world_size = 1
-    return world_size
-
-
-def main_proc_print(*args, **kwargs):
-    if is_main_process():
-        print(*args, **kwargs)
-
-
-def init_dist_training_env(config):
-    if dist.get_world_size() <= 1:
-        config.device = paddle.device.get_device()
-        config.world_size = get_world_size()
-    else:
-        dist.init_parallel_env()
-        config.device = paddle.device.get_device()
-        config.world_size = get_world_size()
-    print("------------------------")
-    print("device numbers:", config.world_size)
-    print("the processing uses", config.device)
-    return
-
-
-def global_batch_size(config):
-
-    return config.per_device_train_batch_size * config.world_size
-
-
-@contextmanager
-def sync_workers():
-    """
-    Yields distributed rank and synchronizes all workers on exit.
-    """
-    rank = get_rank()
-    yield rank
-    barrier()
-
-
 def is_main_process():
     if dist.is_initialized():
         if "PADDLE_TRAINER_ID" in os.environ:
@@ -86,15 +28,83 @@ def is_main_process():
 
     return True
 
-
-def format_step(step):
-    if isinstance(step, str):
-        return step
-    s = ""
-    if len(step) > 0:
-        s += "Training Epoch: {} ".format(step[0])
-    if len(step) > 1:
-        s += "Training Iteration: {} ".format(step[1])
-    if len(step) > 2:
-        s += "Validation Iteration: {} ".format(step[2])
-    return s
+class PaddleCallback(TrainerCallback):
+    def __init__(self, driver: Driver):
+        self.driver = driver
+
+    def on_init_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs
+    ):
+        self.driver.event(Event.INIT_END)
+
+    def on_train_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs
+    ):
+        self.driver.event(Event.TRAIN_START)
+
+    def on_train_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs
+    ):
+        self.driver.event(Event.TRAIN_END)
+
+    def on_epoch_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs
+    ):
+        self.driver.event(Event.EPOCH_BEGIN, epoch=state.epoch)
+
+    def on_epoch_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs
+    ):
+        self.driver.event(Event.EPOCH_END, epoch=state.epoch)
+
+    def on_step_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs
+    ):
+        self.driver.event(Event.STEP_BEGIN, step=state.global_step + 1)
+
+    def on_evaluate(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs
+    ):
+        logs = kwargs["metrics"]
+        logs["global_step"] = state.global_step
+        self.driver.event(Event.EVALUATE, result=logs)
+
+    def on_log(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        logs=None,
+        **kwargs
+    ):
+        _ = logs.pop("total_flos", None)
+        if state.is_local_process_zero:
+            self.driver.logger.log(Event.STEP_END, message=logs)
```
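
PaddleNLP's Trainer follows the Hugging Face callback interface, so a callback like this is attached when the trainer is built. Below is a minimal wiring sketch, not code from this commit: `model`, `train_dataset`, `data_collator`, and `driver` are assumed to come from the surrounding FlagPerf benchmark code, and the argument values are purely illustrative.

```python
# Sketch only: all objects passed in are assumed to be created elsewhere
# by the benchmark; the TrainingArguments values are hypothetical.
from paddlenlp.trainer import Trainer, TrainingArguments


def build_trainer(model, train_dataset, data_collator, driver):
    args = TrainingArguments(
        output_dir="./output",           # hypothetical output location
        per_device_train_batch_size=1,   # hypothetical batch size
        max_steps=100,                   # hypothetical step budget
        logging_steps=10,
    )
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        # Trainer events (init end, train begin/end, epoch, step, evaluate,
        # log) are forwarded to the FlagPerf driver by PaddleCallback.
        callbacks=[PaddleCallback(driver)],
    )
```

The Trainer also mirrors Hugging Face's `add_callback` hook, so the callback could equally be attached after construction.
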
Lines changed: 43 additions & 0 deletions

### Model information
#### Model introduction
We introduce LLaMA, a collection of foundation language models ranging from 7B to 65B parameters. We train our models on trillions of tokens, and show that it is possible to train state-of-the-art models using publicly available datasets exclusively, without resorting to proprietary and inaccessible datasets. In particular, LLaMA-13B outperforms GPT-3 (175B) on most benchmarks, and LLaMA-65B is competitive with the best models, Chinchilla-70B and PaLM-540B. We release all our models to the research community.

Please refer to this paper for a detailed description of LLaMA1:
[LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971)

#### Source of the model code
The Paddle case code comes from:
https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/llama, licensed under the Apache License, Version 2.0.

#### Dataset
##### Test dataset download
The test dataset provides preprocessed training samples built from 100k documents:
```
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_ids.npy
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_idx.npz
```
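
As an aside (not part of the committed README), here is a minimal sketch of inspecting the two downloaded files with NumPy, assuming the `.npy` holds the flat array of token ids and the `.npz` holds the sample index arrays; the archive's actual key names are read rather than assumed.

```python
# Illustrative only: prints the shapes of whatever arrays the files contain.
import numpy as np

ids = np.load("llama_openwebtext_100k_ids.npy", mmap_mode="r")  # token ids
idx = np.load("llama_openwebtext_100k_idx.npz")                 # index arrays

print("token ids:", ids.shape, ids.dtype)
for key in idx.files:  # iterate over every array stored in the .npz archive
    print(key, idx[key].shape, idx[key].dtype)
```
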

##### Preprocessing
> No preprocessing is required.

#### Model implementation
* Loaded automatically at run time

#### Model checkpoint
* Downloaded automatically at run time; parameter count: 13B
* Use of the weights of the Paddle LLaMA model must comply with the [License](../../paddlenlp/transformers/llama/LICENSE)

### Framework and accelerator support
| | Pytorch | Paddle | TensorFlow2 |
| ---- | ---- | ---- | ---- |
| Nvidia GPU | N/A | ✅ | N/A |
Lines changed: 1 addition & 0 deletions

```
/ssd2/laixinyi/projects/FlagPerf/training/benchmarks/llama1_7B/paddle
```
Lines changed: 43 additions & 0 deletions

### Model information
#### Model introduction
We introduce LLaMA, a collection of foundation language models ranging from 7B to 65B parameters. We train our models on trillions of tokens, and show that it is possible to train state-of-the-art models using publicly available datasets exclusively, without resorting to proprietary and inaccessible datasets. In particular, LLaMA-13B outperforms GPT-3 (175B) on most benchmarks, and LLaMA-65B is competitive with the best models, Chinchilla-70B and PaLM-540B. We release all our models to the research community.

Please refer to this paper for a detailed description of LLaMA1:
[LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971)

#### Source of the model code
The Paddle case code comes from:
https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/llama, licensed under the Apache License, Version 2.0.

#### Dataset
##### Test dataset download
The test dataset provides preprocessed training samples built from 100k documents:
```
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_ids.npy
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_idx.npz
```

##### Preprocessing
> No preprocessing is required.

#### Model implementation
* Loaded automatically at run time

#### Model checkpoint
* Downloaded automatically at run time; parameter count: 7B
* Use of the weights of the Paddle LLaMA model must comply with the [License](../../paddlenlp/transformers/llama/LICENSE)

### Framework and accelerator support
| | Pytorch | Paddle | TensorFlow2 |
| ---- | ---- | ---- | ---- |
| Nvidia GPU | N/A | ✅ | N/A |
