
Commit 82e8848

Merge pull request #29 from wwhu/ss-dev

Add scheduled sampling.

2 parents 0a307ac + 4a18101

File tree

5 files changed: +583 −1 lines changed

scheduled_sampling/README.md

Lines changed: 213 additions & 1 deletion
# Scheduled Sampling

## Overview

In sequence generation tasks, the training objective is to maximize the probability of the target sequence given the source input. During training, the model feeds the ground-truth elements of the target sequence to the decoder at each step and maximizes the probability of the next element. At generation time, however, the element decoded at the previous step is used as the current input to generate the next element. The probability distributions of the decoder's input data in the training and generation phases are therefore inconsistent.

Scheduled Sampling\[[1](#references)\] is a method for resolving this mismatch between the input data distributions at training and generation time. Early in training, it mainly uses the ground-truth elements of the target sequence as decoder inputs, which quickly guides the model from its randomly initialized state to a reasonable one. As training proceeds, it gradually uses more generated elements as decoder inputs, addressing the distribution mismatch.

In a standard sequence-to-sequence model, if an incorrect element is generated early in the sequence, the subsequent input states are affected, and the error keeps accumulating as generation proceeds. Scheduled Sampling feeds generated elements to the decoder with a certain probability during training, so even when earlier steps go wrong, the training objective is still to maximize the probability of the true target sequence, and the model is trained in the right direction. This makes the model more tolerant of its own errors.

## Algorithm

Scheduled Sampling is applied only in the training phase of a sequence-to-sequence model; the generation phase does not need it.

When the decoder maximizes the probability of the $t$-th element during training, the standard sequence-to-sequence model uses the ground-truth element of the previous time step, $y_{t-1}$, as input. Let $g_{t-1}$ denote the element generated at the previous time step; Scheduled Sampling uses $g_{t-1}$ as the decoder input with a certain probability.

Suppose training has reached the $i$-th mini-batch. Scheduled Sampling defines a probability $\epsilon_i$ that controls the decoder input. $\epsilon_i$ is a variable that decays as $i$ grows; common definitions include:

- Linear decay: $\epsilon_i=\max(\epsilon, k-c \cdot i)$, where $\epsilon$ bounds the minimum value of $\epsilon_i$, and $k$ and $c$ control the magnitude of the linear decay.

- Exponential decay: $\epsilon_i=k^i$, where $0<k<1$ and $k$ controls the magnitude of the exponential decay.

- Inverse sigmoid decay: $\epsilon_i=k/(k+\exp(i/k))$, where $k>1$ and $k$ likewise controls the magnitude of the decay.

Figure 1 shows the decay curves of these three schedules.

<p align="center">
<img src="img/decay.jpg" width="50%" align="center"><br>
Figure 1. Decay curves of linear decay, exponential decay, and inverse sigmoid decay
</p>
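
To make the schedules concrete, here is a small standalone sketch (not from the repository; the function names and parameter values are arbitrary, chosen only for illustration) that computes $\epsilon_i$ under each schedule:

```python
import math


def linear_decay(i, eps=0.1, k=0.9, c=0.01):
    # epsilon bounds the minimum value; k and c set the slope.
    return max(eps, k - c * i)


def exponential_decay(i, k=0.9):
    # 0 < k < 1 controls how fast the rate shrinks.
    return k ** i


def inverse_sigmoid_decay(i, k=10.0):
    # k > 1 controls how fast the rate shrinks.
    return k / (k + math.exp(i / k))


for i in [0, 10, 50, 100]:
    print(i, linear_decay(i), exponential_decay(i), inverse_sigmoid_decay(i))
```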

As shown in Figure 2, at decoding step $t$, Scheduled Sampling uses the ground-truth element $y_{t-1}$ of the previous time step as the decoder input with probability $\epsilon_i$, and uses the generated element $g_{t-1}$ with probability $1-\epsilon_i$. As Figure 1 shows, $\epsilon_i$ decreases as $i$ grows, so the decoder increasingly tends to use generated elements as inputs, and the data distributions of the training and generation phases become more and more consistent.
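
Written out, the decoder input $x_t$ at step $t$ follows this rule (a restatement of the description above):

$$
x_t =
\begin{cases}
y_{t-1} & \text{with probability } \epsilon_i \\
g_{t-1} & \text{with probability } 1-\epsilon_i
\end{cases}
$$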

<p align="center">
<img src="img/Scheduled_Sampling.jpg" width="50%" align="center"><br>
Figure 2. Schematic of Scheduled Sampling choosing different elements as the decoder input
</p>

## Implementation

Since Scheduled Sampling is an improvement on the sequence-to-sequence model, its overall implementation framework is quite similar to that of a sequence-to-sequence model. To keep the focus of this document, only the parts related to Scheduled Sampling are described here; see `scheduled_sampling.py` for the complete code.

First, import the required packages and define the class `RandomScheduleGenerator`, which controls the decay probability, as follows:
```python
import numpy as np
import math


class RandomScheduleGenerator:
    """
    The random sampling rate for the scheduled sampling algorithm, which uses a
    decayed sampling rate.
    """
    ...
```

The three methods of `RandomScheduleGenerator`, namely `__init__`, `getScheduleRate`, and `processBatch`, are defined below.

The `__init__` method initializes the class. Its `schedule_type` argument specifies which decay schedule to use; the available options are `constant`, `linear`, `exponential`, and `inverse_sigmoid`. `constant` uses a fixed $\epsilon_i$ for all mini-batches, `linear` means linear decay, `exponential` means exponential decay, and `inverse_sigmoid` means inverse sigmoid decay. The arguments `a` and `b` are parameters of the decay schedule and need to be tuned on the validation set. `self.schedule_computers` maps each decay schedule to a function that computes $\epsilon_i$. The last line assigns the selected decay function to `self.schedule_computer` according to `schedule_type`.

```python
def __init__(self, schedule_type, a, b):
    """
    schedule_type: the type of the decay. It supports constant, linear,
    exponential, and inverse_sigmoid right now.
    a: parameter of the decay (MUST BE DOUBLE)
    b: parameter of the decay (MUST BE DOUBLE)
    """
    self.schedule_type = schedule_type
    self.a = a
    self.b = b
    self.data_processed_ = 0
    self.schedule_computers = {
        "constant": lambda a, b, d: a,
        "linear": lambda a, b, d: max(a, 1 - d / b),
        "exponential": lambda a, b, d: pow(a, d / b),
        "inverse_sigmoid": lambda a, b, d: b / (b + math.exp(d * a / b)),
    }
    assert (self.schedule_type in self.schedule_computers)
    self.schedule_computer = self.schedule_computers[self.schedule_type]
```

`getScheduleRate` computes $\epsilon_i$ from the decay function and the amount of data processed so far.

```python
def getScheduleRate(self):
    """
    Get the scheduled sampling rate. Usually users do not need to call
    this method directly.
    """
    return self.schedule_computer(self.a, self.b, self.data_processed_)
```

The `processBatch` method samples `indexes` according to the probability $\epsilon_i$. Each element of `indexes` takes the value `0` with probability $\epsilon_i$ and the value `1` with probability $1-\epsilon_i$. `indexes` determines whether the decoder input is the ground-truth element or the generated element: a value of `0` means the ground-truth element is used, and a value of `1` means the generated element is used.

```python
def processBatch(self, batch_size):
    """
    Get a batch_size of sampled indexes. These indexes can be passed to a
    MultiplexLayer to select from the ground truth and generated samples
    from the last time step.
    """
    rate = self.getScheduleRate()
    numbers = np.random.rand(batch_size)
    indexes = (numbers >= rate).astype('int32').tolist()
    self.data_processed_ += batch_size
    return indexes
```
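
For intuition, here is a minimal sketch of how the generator might be used on its own (the outputs marked "e.g." are random samples, shown only for illustration):

```python
gen = RandomScheduleGenerator("linear", 0.75, 1000000.0)

# Early in training the sampling rate is 1.0, so every flag is 0
# (always use the ground-truth element).
print(gen.processBatch(5))  # [0, 0, 0, 0, 0]

# After enough data has been processed, the rate decays to its floor of
# a = 0.75, so roughly a quarter of the flags become 1 (use the
# generated element).
gen.data_processed_ = 2000000
print(gen.processBatch(5))  # e.g. [0, 1, 0, 0, 0]
```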

Scheduled Sampling adds one extra input, `true_token_flag`, on top of the sequence-to-sequence model to control the decoder input.

```python
true_token_flags = paddle.layer.data(
    name='true_token_flag',
    type=paddle.data_type.integer_value_sequence(2))
```

The original reader also needs to be wrapped with a data generator for `true_token_flag`. Using linear decay as an example, the following code shows how to call the `RandomScheduleGenerator` defined above to produce the input data for `true_token_flag`.

```python
schedule_generator = RandomScheduleGenerator("linear", 0.75, 1000000.0)


def gen_schedule_data(reader):
    """
    Creates a data reader for scheduled sampling.

    Output from the iterator created by the original reader will be
    appended with "true_token_flag" to indicate whether to use the true token.

    :param reader: the original reader.
    :type reader: callable

    :return: the new reader with the field "true_token_flag".
    :rtype: callable
    """

    def data_reader():
        for src_ids, trg_ids, trg_ids_next in reader():
            yield src_ids, trg_ids, trg_ids_next, \
                [0] + schedule_generator.processBatch(len(trg_ids) - 1)

    return data_reader
```

This code appends the data controlling the decoder input after the original input data (the source sequence elements `src_ids`, the target sequence elements `trg_ids`, and the next elements of the target sequence `trg_ids_next`). Since the first element of the decoder input is the sequence start token, the first element of the appended data is set to `0`, meaning the first decoder step always uses the first element of the true target sequence (i.e., the sequence start token).
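
As a quick sanity check, one could wrap a toy reader like this (the sample sequences below are made up for illustration):

```python
def toy_reader():
    # One made-up sample: source ids, target ids, and the target ids
    # shifted by one position.
    yield [2, 15, 8, 3], [2, 27, 11, 3], [27, 11, 3, 0]


for sample in gen_schedule_data(toy_reader)():
    print(sample)
    # e.g. ([2, 15, 8, 3], [2, 27, 11, 3], [27, 11, 3, 0], [0, 0, 1, 0])
    # The first flag is always 0; the remaining flags are random and
    # contain more 1s once the sampling rate has decayed below 1.0.
```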

The decoder function called by `recurrent_group` at each training step is as follows:

```python
def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word,
                                     true_token_flag):
    """
    The decoder step for training.
    :param enc_vec: the encoder vector for attention
    :type enc_vec: LayerOutput
    :param enc_proj: the encoder projection for attention
    :type enc_proj: LayerOutput
    :param true_word: the ground-truth target word
    :type true_word: LayerOutput
    :param true_token_flag: the flag of using the ground-truth target word
    :type true_token_flag: LayerOutput
    :return: the softmax output layer
    :rtype: LayerOutput
    """

    decoder_mem = paddle.layer.memory(
        name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)

    context = paddle.networks.simple_attention(
        encoded_sequence=enc_vec,
        encoded_proj=enc_proj,
        decoder_state=decoder_mem)

    gru_out_memory = paddle.layer.memory(
        name='gru_out', size=target_dict_dim)

    generated_word = paddle.layer.max_id(input=gru_out_memory)

    generated_word_emb = paddle.layer.embedding(
        input=generated_word,
        size=word_vector_dim,
        param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))

    current_word = paddle.layer.multiplex(
        input=[true_token_flag, true_word, generated_word_emb])

    with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs:
        decoder_inputs += paddle.layer.full_matrix_projection(input=context)
        decoder_inputs += paddle.layer.full_matrix_projection(
            input=current_word)

    gru_step = paddle.layer.gru_step(
        name='gru_decoder',
        input=decoder_inputs,
        output_mem=decoder_mem,
        size=decoder_size)

    with paddle.layer.mixed(
            name='gru_out',
            size=target_dict_dim,
            bias_attr=True,
            act=paddle.activation.Softmax()) as out:
        out += paddle.layer.full_matrix_projection(input=gru_step)

    return out
```

This function uses the memory layer `gru_out_memory` to remember the element generated at the previous time step and selects the word with the highest probability in `gru_out_memory` as the generated word `generated_word`. The `multiplex` layer chooses between the ground-truth element `true_word` and the generated element `generated_word`, and the selection result is used as the decoder input. The `multiplex` layer takes three inputs: `true_token_flag`, `true_word`, and `generated_word_emb`. For each element of these inputs, if the value in `true_token_flag` is `0`, the `multiplex` layer outputs the corresponding element in `true_word`; if the value is `1`, it outputs the corresponding element in `generated_word_emb`.
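
The selection logic of the `multiplex` layer can be mimicked on toy arrays as follows (a plain NumPy sketch of the semantics, not the Paddle layer itself; all values are made up):

```python
import numpy as np

# One row per sequence position; columns are embedding dimensions.
true_word_emb = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
generated_word_emb = np.array([[9.0, 9.0], [8.0, 8.0], [7.0, 7.0]])
flags = np.array([0, 1, 0])  # 0 -> ground truth, 1 -> generated

inputs = [true_word_emb, generated_word_emb]
selected = np.stack([inputs[f][i] for i, f in enumerate(flags)])
print(selected)
# [[1. 1.]
#  [8. 8.]
#  [3. 3.]]
```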

## References

[1] Bengio S, Vinyals O, Jaitly N, et al. [Scheduled sampling for sequence prediction with recurrent neural networks](http://papers.nips.cc/paper/5956-scheduled-sampling-for-sequence-prediction-with-recurrent-neural-networks). Advances in Neural Information Processing Systems. 2015: 1171-1179.

scheduled_sampling/img/Scheduled_Sampling.jpg (59.2 KB)

scheduled_sampling/img/decay.jpg (44.6 KB)

Lines changed: 47 additions & 0 deletions

```python
import numpy as np
import math


class RandomScheduleGenerator:
    """
    The random sampling rate for the scheduled sampling algorithm, which uses a
    decayed sampling rate.
    """

    def __init__(self, schedule_type, a, b):
        """
        schedule_type: the type of the decay. It supports constant, linear,
        exponential, and inverse_sigmoid right now.
        a: parameter of the decay (MUST BE DOUBLE)
        b: parameter of the decay (MUST BE DOUBLE)
        """
        self.schedule_type = schedule_type
        self.a = a
        self.b = b
        self.data_processed_ = 0
        self.schedule_computers = {
            "constant": lambda a, b, d: a,
            "linear": lambda a, b, d: max(a, 1 - d / b),
            "exponential": lambda a, b, d: pow(a, d / b),
            "inverse_sigmoid": lambda a, b, d: b / (b + math.exp(d * a / b)),
        }
        assert (self.schedule_type in self.schedule_computers)
        self.schedule_computer = self.schedule_computers[self.schedule_type]

    def getScheduleRate(self):
        """
        Get the scheduled sampling rate. Usually users do not need to call
        this method directly.
        """
        return self.schedule_computer(self.a, self.b, self.data_processed_)

    def processBatch(self, batch_size):
        """
        Get a batch_size of sampled indexes. These indexes can be passed to a
        MultiplexLayer to select from the ground truth and generated samples
        from the last time step.
        """
        rate = self.getScheduleRate()
        numbers = np.random.rand(batch_size)
        indexes = (numbers >= rate).astype('int32').tolist()
        self.data_processed_ += batch_size
        return indexes
```
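
A minimal sketch of how this class can be exercised directly, showing the sampling rate decaying as data is processed (the linear parameters below are illustrative; `b` is passed as a float, as the docstring requires):

```python
generator = RandomScheduleGenerator("linear", 0.5, 1000000.0)

# The rate starts at 1.0 (always feed ground truth) and decays linearly
# toward the floor a = 0.5 as more data is processed.
for _ in range(4):
    print(generator.getScheduleRate())
    generator.processBatch(250000)
# 1.0, 0.75, 0.5, 0.5
```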
