Skip to content

Commit c8ebb53

Browse files
authored
Merge pull request PaddlePaddle#17 from LiuChiachi/improve-data-augment-for-distill-lstm
Improve data augmentation for distilling Bi-LSTM
2 parents cb6e992 + f0d7e16 commit c8ebb53

File tree

10 files changed

+261
-197
lines changed

10 files changed

+261
-197
lines changed

examples/distill/distill_lstm/README.md renamed to examples/model_compression/distill_lstm/README.md

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,24 @@
1616

1717
在模型蒸馏中,较大的模型(在本例中是BERT)通常被称为教师模型,较小的模型(在本例中是Bi-LSTM)通常被成为学生模型。知识的蒸馏通常是通过模型学习蒸馏相关的损失函数实现,在本实验中,损失函数是均方误差损失函数,传入函数的两个参数分别是学生模型的输出和教师模型的输出。
1818

19-
[论文](https://arxiv.org/abs/1903.12136)的模型蒸馏阶段,作者为了能让教师模型表达出更多的知识供学生模型学习,对训练数据进行了数据增强。作者使用了三种数据增强方式,分别是:1.Masking,即以一定的概率将原数据中的word token替换成`[MASK]`;2. POS—guided word replacement,即以一定的概率将原数据中的词用与其有相同POS tag的词替换;3. n-gram sampling,即以一定的概率,从每条数据中采样n-gram,其中n的范围可通过人工设置。通过数据增强,可以产生更多无标签的训练数据,在训练过程中,学生模型可借助教师模型的“暗知识”,在更大的数据集上进行训练,产生更好的蒸馏效果。需要指出的是,实验只使用了第1和第3种数据增强方式。
19+
[论文](https://arxiv.org/abs/1903.12136)的模型蒸馏阶段,作者为了能让教师模型表达出更多的知识供学生模型学习,对训练数据进行了数据增强。作者使用了三种数据增强方式,分别是:
20+
21+
1. Masking,即以一定的概率将原数据中的word token替换成`[MASK]`
22+
23+
2. POS—guided word replacement,即以一定的概率将原数据中的词用与其有相同POS tag的词替换;
24+
25+
3. n-gram sampling,即以一定的概率,从每条数据中采样n-gram,其中n的范围可通过人工设置。通过数据增强,可以产生更多无标签的训练数据,在训练过程中,学生模型可借助教师模型的“暗知识”,在更大的数据集上进行训练,产生更好的蒸馏效果。需要指出的是,实验只使用了第1和第3种数据增强方式。
26+
在英文数据集任务上,本文使用了Google News语料[预训练的Word Embedding](https://code.google.com/archive/p/word2vec/)初始化小模型的Embedding层。
2027

2128
本实验分为三个训练过程:在特定任务上对BERT的fine-tuning、在特定任务上对基于Bi-LSTM的小模型的训练(用于评价蒸馏效果)、将BERT模型的知识蒸馏到基于Bi-LSTM的小模型上。
2229

2330
## 环境要求
2431
运行本目录下的范例模型需要安装PaddlePaddle 2.0及以上版本。如果您的 PaddlePaddle 安装版本低于此要求,请按照[安装文档](https://www.paddlepaddle.org.cn/#quick-start)中的说明更新 PaddlePaddle 安装版本。
32+
另外,本项目还依赖paddlenlp,可以使用下面的命令进行安装:
2533

26-
另外,本文下载并在对英文数据集的训练中使用了Google News语料[预训练的Word Embedding](https://code.google.com/archive/p/word2vec/)初始化小模型的Embedding层,并使用gensim包对该Word Embedding文件进行读取。因此,运行本实验还需要安装`gensim`及下载预训练的Word Embedding。
27-
34+
```shell
35+
pip install paddlenlp==2.0.0rc
36+
```
2837

2938
## 数据、预训练模型介绍及获取
3039

@@ -61,11 +70,11 @@ python -u ./run_bert_finetune.py \
6170
--num_train_epochs 3 \
6271
--logging_steps 10 \
6372
--save_steps 10 \
64-
--output_dir ../distill/ditill_lstm/model/$TASK_NAME/ \
73+
--output_dir ../model_compression/distill_lstm/pretrained_modelss/$TASK_NAME/ \
6574
--n_gpu 1 \
6675

6776
```
68-
训练完成之后,可将训练效果最好的模型保存在本项目下的`models/$TASK_NAME/`下。模型目录下有`model_config.json`, `model_state.pdparams`, `tokenizer_config.json``vocab.txt`这几个文件。
77+
训练完成之后,可将训练效果最好的模型保存在本项目下的`pretrained_models/$TASK_NAME/`下。模型目录下有`model_config.json`, `model_state.pdparams`, `tokenizer_config.json``vocab.txt`这几个文件。
6978

7079

7180
### 训练小模型
@@ -83,8 +92,9 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
8392
--optimizer adam \
8493
--lr 3e-4 \
8594
--dropout_prob 0.2 \
86-
--use_pretrained_emb False \
87-
--vocab_path senta_word_dict_subset.txt
95+
--vocab_path senta_word_dict_subset.txt \
96+
--output_dir small_models/senta/
97+
8898
```
8999

90100
```shell
@@ -95,7 +105,9 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
95105
--batch_size 64 \
96106
--lr 1.0 \
97107
--dropout_prob 0.4 \
98-
--use_pretrained_emb True
108+
--output_dir small_models/SST-2 \
109+
--embedding_name w2v.google_news.target.word-word.dim300.en
110+
99111
```
100112

101113
```shell
@@ -106,7 +118,9 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
106118
--batch_size 256 \
107119
--lr 2.0 \
108120
--dropout_prob 0.4 \
109-
--use_pretrained_emb True
121+
--output_dir small_models/QQP \
122+
--embedding_name w2v.google_news.target.word-word.dim300.en
123+
110124
```
111125

112126
### 蒸馏模型
@@ -121,9 +135,10 @@ CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
121135
--dropout_prob 0.1 \
122136
--batch_size 64 \
123137
--model_name bert-wwm-ext-chinese \
124-
--use_pretrained_emb False \
125-
--teacher_path model/senta/best_bert_wwm_ext_model_880/model_state.pdparams \
126-
--vocab_path senta_word_dict_subset.txt
138+
--teacher_path pretrained_models/senta/best_bert_wwm_ext_model_880/model_state.pdparams \
139+
--vocab_path senta_word_dict_subset.txt \
140+
--output_dir distilled_models/senta
141+
127142
```
128143

129144
```shell
@@ -136,8 +151,10 @@ CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
136151
--dropout_prob 0.2 \
137152
--batch_size 128 \
138153
--model_name bert-base-uncased \
139-
--use_pretrained_emb True \
140-
--teacher_path model/SST-2/best_model_610/model_state.pdparams
154+
--embedding_name w2v.google_news.target.word-word.dim300.en \
155+
--output_dir distilled_models/SST-2 \
156+
--teacher_path pretrained_models/SST-2/best_model_610/model_state.pdparams
157+
141158
```
142159

143160
```shell
@@ -149,24 +166,25 @@ CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
149166
--dropout_prob 0.2 \
150167
--batch_size 256 \
151168
--model_name bert-base-uncased \
152-
--use_pretrained_emb True \
169+
--embedding_name w2v.google_news.target.word-word.dim300.en \
153170
--n_iter 10 \
154-
--teacher_path model/QQP/best_model_17000/model_state.pdparams
171+
--output_dir distilled_models/QQP \
172+
--teacher_path pretrained_models/QQP/best_model_17000/model_state.pdparams
173+
155174
```
156175

157176
各参数的具体说明请参阅 `args.py` ,注意在训练不同任务时,需要调整对应的超参数。
158177

159178

160179
## 蒸馏实验结果
161-
本蒸馏实验基于GLUE的SST-2、QQP、中文情感分类ChnSentiCorp数据集。实验效果均使用每个数据集的验证集(dev)进行评价,评价指标是准确率(acc),其中QQP中包含f1值。利用基于BERT的教师模型去蒸馏基于Bi-LSTM的学生模型,对比Bi-LSTM小模型单独训练,在SST-2、QQP、senta(中文情感分类)任务上分别有3.2%、1.8%、1.4%的提升。
162-
163-
| Model | SST-2(dev acc) | QQP(dev acc/f1) | ChnSentiCorp(dev acc) | ChnSentiCorp(dev acc) |
164-
| -------------- | ----------------- | -------------------------- | --------------------- | --------------------- |
165-
| teacher model | bert-base-uncased | bert-base-uncased | bert-base-chinese | bert-wwm-ext-chinese |
166-
| Teacher | 0.930046 | 0.905813(acc)/0.873472(f1) | 0.951667 | 0.955000 |
167-
| Student | 0.853211 | 0.856171(acc)/0.806057(f1) | 0.920833 | 0.920800 |
168-
| Distilled | 0.885321 | 0.874375(acc)/0.829581(f1) | 0.930000 | 0.935000 |
169-
180+
本蒸馏实验基于GLUE的SST-2、QQP、中文情感分类ChnSentiCorp数据集。实验效果均使用每个数据集的验证集(dev)进行评价,评价指标是准确率(acc),其中QQP中包含f1值。利用基于BERT的教师模型去蒸馏基于Bi-LSTM的学生模型,对比Bi-LSTM小模型单独训练,在SST-2、QQP、senta(中文情感分类)任务上分别有3.3%、1.9%、1.4%的提升。
181+
182+
| Model | SST-2(dev acc) | QQP(dev acc/f1) | ChnSentiCorp(dev acc) | ChnSentiCorp(dev acc) |
183+
| ----------------- | ----------------- | -------------------------- | --------------------- | --------------------- |
184+
| Teacher model | bert-base-uncased | bert-base-uncased | bert-base-chinese | bert-wwm-ext-chinese |
185+
| BERT-base | 0.930046 | 0.905813(acc)/0.873472(f1) | 0.951667 | 0.955000 |
186+
| Bi-LSTM | 0.854358 | 0.856616(acc)/0.799682(f1) | 0.920000 | 0.920000 |
187+
| Distilled Bi-LSTM | 0.887615 | 0.875216(acc)/0.831254(f1) | 0.932500 | 0.934167 |
170188

171189
## 参考文献
172190

examples/distill/distill_lstm/args.py renamed to examples/model_compression/distill_lstm/args.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,6 @@ def parse_args():
3333
parser.add_argument(
3434
"--num_layers", type=int, default=1, help="Layers number of LSTM.")
3535

36-
parser.add_argument(
37-
'--use_pretrained_emb',
38-
type=eval,
39-
default=False,
40-
help='Whether to use pre-trained embedding tensor.')
41-
4236
parser.add_argument(
4337
"--emb_dim", type=int, default=300, help="Embedding dim.")
4438

@@ -84,6 +78,12 @@ def parse_args():
8478
default=10,
8579
help="The frequency to print evaluation logs.")
8680

81+
parser.add_argument(
82+
"--save_steps",
83+
type=int,
84+
default=100,
85+
help="The frequency to print evaluation logs.")
86+
8787
parser.add_argument(
8888
"--padding_idx",
8989
type=int,
@@ -106,6 +106,24 @@ def parse_args():
106106
default='/root/.paddlenlp/models/bert-base-uncased/bert-base-uncased-vocab.txt',
107107
help="Student model's vocab path.")
108108

109+
parser.add_argument(
110+
"--output_dir",
111+
type=str,
112+
default='models',
113+
help="Directory to save models .")
114+
115+
parser.add_argument(
116+
"--whole_word_mask",
117+
action="store_true",
118+
help="If True, use whole word masking method in data augmentation in distilling."
119+
)
120+
121+
parser.add_argument(
122+
"--embedding_name",
123+
type=str,
124+
default=None,
125+
help="The name of pretrained word embedding.")
126+
109127
parser.add_argument(
110128
"--vocab_size",
111129
type=int,
@@ -118,5 +136,12 @@ def parse_args():
118136
default=0.0,
119137
help="Weight balance between cross entropy loss and mean square loss.")
120138

139+
parser.add_argument(
140+
"--seed",
141+
type=int,
142+
default=2021,
143+
help="Random seed for model parameter initialization, data augmentation and so on."
144+
)
145+
121146
args = parser.parse_args()
122147
return args

examples/distill/distill_lstm/bert_distill.py renamed to examples/model_compression/distill_lstm/bert_distill.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
import time
1617

1718
import paddle
1819
import paddle.nn as nn
19-
from paddle.metric import Metric, Accuracy, Precision, Recall
20+
from paddle.metric import Accuracy
2021

21-
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
22-
from paddlenlp.transformers.tokenizer_utils import whitespace_tokenize
22+
from paddlenlp.transformers import BertForSequenceClassification
2323
from paddlenlp.metrics import AccuracyAndF1
2424
from paddlenlp.datasets import GlueSST2, GlueQQP, ChnSentiCorp
2525

2626
from args import parse_args
2727
from small import BiLSTM
28-
from data import create_distill_loader, load_embedding
28+
from data import create_distill_loader
2929

3030
TASK_CLASSES = {
3131
"sst-2": (GlueSST2, Accuracy),
@@ -36,7 +36,6 @@
3636

3737
class TeacherModel(object):
3838
def __init__(self, model_name, param_path):
39-
self.tokenizer = BertTokenizer.from_pretrained(model_name)
4039
self.model = BertForSequenceClassification.from_pretrained(model_name)
4140
self.model.set_state_dict(paddle.load(param_path))
4241
self.model.eval()
@@ -78,14 +77,14 @@ def do_train(agrs):
7877
vocab_path=args.vocab_path,
7978
batch_size=args.batch_size,
8079
max_seq_length=args.max_seq_length,
81-
n_iter=args.n_iter)
82-
83-
emb_tensor = load_embedding(
84-
args.vocab_path) if args.use_pretrained_emb else None
80+
n_iter=args.n_iter,
81+
whole_word_mask=args.whole_word_mask,
82+
seed=args.seed)
8583

8684
model = BiLSTM(args.emb_dim, args.hidden_size, args.vocab_size,
87-
args.output_dim, args.padding_idx, args.num_layers,
88-
args.dropout_prob, args.init_scale, emb_tensor)
85+
args.output_dim, args.vocab_path, args.padding_idx,
86+
args.num_layers, args.dropout_prob, args.init_scale,
87+
args.embedding_name)
8988

9089
if args.optimizer == 'adadelta':
9190
optimizer = paddle.optimizer.Adadelta(
@@ -143,12 +142,21 @@ def do_train(agrs):
143142
acc = evaluate(args.task_name, model, metric, dev_data_loader)
144143
print("eval done total : %s s" % (time.time() - tic_eval))
145144
tic_train = time.time()
145+
146+
if i % args.save_steps == 0:
147+
paddle.save(
148+
model.state_dict(),
149+
os.path.join(args.output_dir,
150+
"step_" + str(global_step) + ".pdparams"))
151+
paddle.save(optimizer.state_dict(),
152+
os.path.join(args.output_dir,
153+
"step_" + str(global_step) + ".pdopt"))
154+
146155
global_step += 1
147156

148157

149158
if __name__ == '__main__':
150-
paddle.seed(2021)
151159
args = parse_args()
152160
print(args)
153-
161+
paddle.seed(args.seed)
154162
do_train(args)

0 commit comments

Comments
 (0)