1616
1717在模型蒸馏中,较大的模型(在本例中是BERT)通常被称为教师模型,较小的模型(在本例中是Bi-LSTM)通常被成为学生模型。知识的蒸馏通常是通过模型学习蒸馏相关的损失函数实现,在本实验中,损失函数是均方误差损失函数,传入函数的两个参数分别是学生模型的输出和教师模型的输出。
1818
19- 在[ 论文] ( https://arxiv.org/abs/1903.12136 ) 的模型蒸馏阶段,作者为了能让教师模型表达出更多的知识供学生模型学习,对训练数据进行了数据增强。作者使用了三种数据增强方式,分别是:1.Masking,即以一定的概率将原数据中的word token替换成` [MASK] ` ;2. POS—guided word replacement,即以一定的概率将原数据中的词用与其有相同POS tag的词替换;3. n-gram sampling,即以一定的概率,从每条数据中采样n-gram,其中n的范围可通过人工设置。通过数据增强,可以产生更多无标签的训练数据,在训练过程中,学生模型可借助教师模型的“暗知识”,在更大的数据集上进行训练,产生更好的蒸馏效果。需要指出的是,实验只使用了第1和第3种数据增强方式。
19+ 在[ 论文] ( https://arxiv.org/abs/1903.12136 ) 的模型蒸馏阶段,作者为了能让教师模型表达出更多的知识供学生模型学习,对训练数据进行了数据增强。作者使用了三种数据增强方式,分别是:
20+
21+ 1 . Masking,即以一定的概率将原数据中的word token替换成` [MASK] ` ;
22+
23+ 2 . POS—guided word replacement,即以一定的概率将原数据中的词用与其有相同POS tag的词替换;
24+
25+ 3 . n-gram sampling,即以一定的概率,从每条数据中采样n-gram,其中n的范围可通过人工设置。通过数据增强,可以产生更多无标签的训练数据,在训练过程中,学生模型可借助教师模型的“暗知识”,在更大的数据集上进行训练,产生更好的蒸馏效果。需要指出的是,实验只使用了第1和第3种数据增强方式。
26+ 在英文数据集任务上,本文使用了Google News语料[ 预训练的Word Embedding] ( https://code.google.com/archive/p/word2vec/ ) 初始化小模型的Embedding层。
2027
2128本实验分为三个训练过程:在特定任务上对BERT的fine-tuning、在特定任务上对基于Bi-LSTM的小模型的训练(用于评价蒸馏效果)、将BERT模型的知识蒸馏到基于Bi-LSTM的小模型上。
2229
2330## 环境要求
2431运行本目录下的范例模型需要安装PaddlePaddle 2.0及以上版本。如果您的 PaddlePaddle 安装版本低于此要求,请按照[ 安装文档] ( https://www.paddlepaddle.org.cn/#quick-start ) 中的说明更新 PaddlePaddle 安装版本。
32+ 另外,本项目还依赖paddlenlp,可以使用下面的命令进行安装:
2533
26- 另外,本文下载并在对英文数据集的训练中使用了Google News语料[ 预训练的Word Embedding] ( https://code.google.com/archive/p/word2vec/ ) 初始化小模型的Embedding层,并使用gensim包对该Word Embedding文件进行读取。因此,运行本实验还需要安装` gensim ` 及下载预训练的Word Embedding。
27-
34+ ``` shell
35+ pip install paddlenlp==2.0.0rc
36+ ```
2837
2938## 数据、预训练模型介绍及获取
3039
@@ -61,11 +70,11 @@ python -u ./run_bert_finetune.py \
6170 --num_train_epochs 3 \
6271 --logging_steps 10 \
6372 --save_steps 10 \
64- --output_dir ../distill/ditill_lstm/model /$TASK_NAME / \
73+ --output_dir ../model_compression/distill_lstm/pretrained_modelss /$TASK_NAME / \
6574 --n_gpu 1 \
6675
6776```
68- 训练完成之后,可将训练效果最好的模型保存在本项目下的` models /$TASK_NAME/` 下。模型目录下有` model_config.json ` , ` model_state.pdparams ` , ` tokenizer_config.json ` 及` vocab.txt ` 这几个文件。
77+ 训练完成之后,可将训练效果最好的模型保存在本项目下的` pretrained_models /$TASK_NAME/` 下。模型目录下有` model_config.json ` , ` model_state.pdparams ` , ` tokenizer_config.json ` 及` vocab.txt ` 这几个文件。
6978
7079
7180### 训练小模型
@@ -83,8 +92,9 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
8392 --optimizer adam \
8493 --lr 3e-4 \
8594 --dropout_prob 0.2 \
86- --use_pretrained_emb False \
87- --vocab_path senta_word_dict_subset.txt
95+ --vocab_path senta_word_dict_subset.txt \
96+ --output_dir small_models/senta/
97+
8898```
8999
90100``` shell
@@ -95,7 +105,9 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
95105 --batch_size 64 \
96106 --lr 1.0 \
97107 --dropout_prob 0.4 \
98- --use_pretrained_emb True
108+ --output_dir small_models/SST-2 \
109+ --embedding_name w2v.google_news.target.word-word.dim300.en
110+
99111```
100112
101113``` shell
@@ -106,7 +118,9 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
106118 --batch_size 256 \
107119 --lr 2.0 \
108120 --dropout_prob 0.4 \
109- --use_pretrained_emb True
121+ --output_dir small_models/QQP \
122+ --embedding_name w2v.google_news.target.word-word.dim300.en
123+
110124```
111125
112126### 蒸馏模型
@@ -121,9 +135,10 @@ CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
121135 --dropout_prob 0.1 \
122136 --batch_size 64 \
123137 --model_name bert-wwm-ext-chinese \
124- --use_pretrained_emb False \
125- --teacher_path model/senta/best_bert_wwm_ext_model_880/model_state.pdparams \
126- --vocab_path senta_word_dict_subset.txt
138+ --teacher_path pretrained_models/senta/best_bert_wwm_ext_model_880/model_state.pdparams \
139+ --vocab_path senta_word_dict_subset.txt \
140+ --output_dir distilled_models/senta
141+
127142```
128143
129144``` shell
@@ -136,8 +151,10 @@ CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
136151 --dropout_prob 0.2 \
137152 --batch_size 128 \
138153 --model_name bert-base-uncased \
139- --use_pretrained_emb True \
140- --teacher_path model/SST-2/best_model_610/model_state.pdparams
154+ --embedding_name w2v.google_news.target.word-word.dim300.en \
155+ --output_dir distilled_models/SST-2 \
156+ --teacher_path pretrained_models/SST-2/best_model_610/model_state.pdparams
157+
141158```
142159
143160``` shell
@@ -149,24 +166,25 @@ CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
149166 --dropout_prob 0.2 \
150167 --batch_size 256 \
151168 --model_name bert-base-uncased \
152- --use_pretrained_emb True \
169+ --embedding_name w2v.google_news.target.word-word.dim300.en \
153170 --n_iter 10 \
154- --teacher_path model/QQP/best_model_17000/model_state.pdparams
171+ --output_dir distilled_models/QQP \
172+ --teacher_path pretrained_models/QQP/best_model_17000/model_state.pdparams
173+
155174```
156175
157176各参数的具体说明请参阅 ` args.py ` ,注意在训练不同任务时,需要调整对应的超参数。
158177
159178
160179## 蒸馏实验结果
161- 本蒸馏实验基于GLUE的SST-2、QQP、中文情感分类ChnSentiCorp数据集。实验效果均使用每个数据集的验证集(dev)进行评价,评价指标是准确率(acc),其中QQP中包含f1值。利用基于BERT的教师模型去蒸馏基于Bi-LSTM的学生模型,对比Bi-LSTM小模型单独训练,在SST-2、QQP、senta(中文情感分类)任务上分别有3.2%、1.8%、1.4%的提升。
162-
163- | Model | SST-2(dev acc) | QQP(dev acc/f1) | ChnSentiCorp(dev acc) | ChnSentiCorp(dev acc) |
164- | -------------- | ----------------- | -------------------------- | --------------------- | --------------------- |
165- | teacher model | bert-base-uncased | bert-base-uncased | bert-base-chinese | bert-wwm-ext-chinese |
166- | Teacher | 0.930046 | 0.905813(acc)/0.873472(f1) | 0.951667 | 0.955000 |
167- | Student | 0.853211 | 0.856171(acc)/0.806057(f1) | 0.920833 | 0.920800 |
168- | Distilled | 0.885321 | 0.874375(acc)/0.829581(f1) | 0.930000 | 0.935000 |
169-
180+ 本蒸馏实验基于GLUE的SST-2、QQP、中文情感分类ChnSentiCorp数据集。实验效果均使用每个数据集的验证集(dev)进行评价,评价指标是准确率(acc),其中QQP中包含f1值。利用基于BERT的教师模型去蒸馏基于Bi-LSTM的学生模型,对比Bi-LSTM小模型单独训练,在SST-2、QQP、senta(中文情感分类)任务上分别有3.3%、1.9%、1.4%的提升。
181+
182+ | Model | SST-2(dev acc) | QQP(dev acc/f1) | ChnSentiCorp(dev acc) | ChnSentiCorp(dev acc) |
183+ | ----------------- | ----------------- | -------------------------- | --------------------- | --------------------- |
184+ | Teacher model | bert-base-uncased | bert-base-uncased | bert-base-chinese | bert-wwm-ext-chinese |
185+ | BERT-base | 0.930046 | 0.905813(acc)/0.873472(f1) | 0.951667 | 0.955000 |
186+ | Bi-LSTM | 0.854358 | 0.856616(acc)/0.799682(f1) | 0.920000 | 0.920000 |
187+ | Distilled Bi-LSTM | 0.887615 | 0.875216(acc)/0.831254(f1) | 0.932500 | 0.934167 |
170188
171189## 参考文献
172190
0 commit comments