165 changes: 164 additions & 1 deletion scheduled_sampling/README.md
@@ -1 +1,164 @@
-TBD
# Scheduled Sampling
> **Contributor:** Suggest changing the title to Chinese; the same applies to all later occurrences of "Scheduled Sampling".

> **Contributor (author):** There does not seem to be a standard Chinese translation for this yet.

> **Collaborator:** This example does not need a standard Chinese translation; English is fine. I have not yet come across a widely accepted Chinese translation either.


## Overview
The training objective of a sequence generation task is to maximize the probability of the target sequence given the source input. During training, the model takes the ground-truth elements of the target sequence as the decoder input at every step and maximizes the probability of the next element. At generation time, by contrast, the element decoded at the previous step serves as the current input for generating the next element. The decoder inputs therefore follow different probability distributions in the training and generation phases. Moreover, if a wrong element is generated early in the sequence, the subsequent input states are affected, and the error keeps accumulating as generation proceeds.
> **Contributor:**
> 1. If a training objective is stated, a generation objective should be stated too. Alternatively, drop both and discuss only the mismatch between the data distributions at training and generation time.
> 2. Is "if a wrong element is generated early in the sequence, the subsequent input states are affected, and the error keeps accumulating" the reason for introducing Scheduled Sampling? If not, it can be removed.

Scheduled Sampling is a method for resolving this mismatch between the input data distributions at training and generation time. Early in training it mainly feeds the ground-truth elements of the target sequence to the decoder, which quickly steers the model from its randomly initialized state to a reasonable one. As training proceeds, it gradually feeds more generated elements to the decoder, resolving the distribution mismatch.
> **Contributor:**
> 1. "Early in training, the method mainly uses true elements as decoder inputs": the true elements should be those *of the target sequence*.
> 2. "以将" should read "可以将".
> 3. "随着训练的进行,该方法XXX" — insert the comma (mind sentence segmentation throughout the document).

> **Collaborator:** "逐渐更多的" should be "逐渐更多地" (的 → 地).

> **Contributor:** Add a blank line between lines 4 and 5; otherwise the paragraphs all run together.


## The Algorithm
> **Contributor:** This section really needs a figure; the current prose-only description will leave novice readers confused.

Scheduled Sampling is used mainly when training sequence-to-sequence models; it is not needed at generation time.
> **Contributor:**
> 1. Suggest: "it is applied mainly in the training phase of sequence-to-sequence models; the generation phase does not need it."
> 2. Change "Sequence to Sequence" to "序列到序列" (sequence-to-sequence), here and below.

When generating the `t`-th element during decoding, a standard sequence-to-sequence model uses the ground-truth element `y(t-1)` from the previous time step as input. Let `g(t-1)` denote the element generated at the previous step; the Scheduled Sampling algorithm instead uses `g(t-1)` as the decoder input with a certain probability.
> **Collaborator:** Replace `t`, `y(t-1)`, etc. in this sentence with LaTeX formulas: $t$, $y(t-1)$, and so on.

> **Collaborator:** The previous sentence says Scheduled Sampling applies mainly to seq2seq training, while this one starts with "when the decoding stage generates the `t`-th element"; this juxtaposition is easy to misread.

Suppose training has reached the `i`-th mini-batch. At time step `t`, Scheduled Sampling uses the ground-truth element `y(t-1)` from the previous step as the decoder input with probability `epsilon_i`, and the previously generated element `g(t-1)` with probability `1-epsilon_i`.
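For intuition, the per-step choice amounts to the following (a schematic in plain Python, not the Paddle implementation below, which precomputes the flags and selects with a `multiplex` layer):

```python
import random

def choose_decoder_input(true_prev, generated_prev, epsilon_i):
    """Feed the ground-truth previous element with probability epsilon_i,
    otherwise feed the element generated at the previous step."""
    return true_prev if random.random() < epsilon_i else generated_prev
```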
> **Collaborator:** Use LaTeX for the mathematical variables and formulas.

> **Collaborator:** Where does `epsilon_i` come from? How is it tied to the number of samples processed?

As `i` grows, `epsilon_i` keeps decreasing, so the decoder increasingly tends to take generated elements as input, and the data distributions of the training and generation phases become more and more consistent.
> **Collaborator:** This explanation does not show how epsilon decays with the number of samples processed; ordinary readers will struggle to understand it from here.

`epsilon_i` can be decayed in several ways; common schedules are listed below, followed by a short illustrative sketch:

- Linear decay: `epsilon_i = max(epsilon, k - c*i)`, where `epsilon` bounds the minimum value of `epsilon_i`, and `k` and `c` control the slope of the linear decay.
- Exponential decay: `epsilon_i = k^i`, where `0 < k < 1` and `k` controls the decay rate.
- Inverse sigmoid decay: `epsilon_i = k / (k + exp(i/k))`, where `k > 1` and `k` likewise controls the decay rate.
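For illustration, the three schedules can be written out directly (a sketch only; the symbols match the list above, and suitable values of `epsilon_min`, `k`, and `c` are left to the user to tune):

```python
import math

def linear_decay(i, epsilon_min, k, c):
    # epsilon_i = max(epsilon, k - c*i)
    return max(epsilon_min, k - c * i)

def exponential_decay(i, k):
    # epsilon_i = k^i, with 0 < k < 1
    return k ** i

def inverse_sigmoid_decay(i, k):
    # epsilon_i = k / (k + exp(i/k)), with k > 1
    return k / (k + math.exp(i / k))
```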

## Model Implementation
Since Scheduled Sampling is an improvement on the sequence-to-sequence model, its overall implementation framework closely resembles that of a sequence-to-sequence model. To keep this document focused, only the parts related to Scheduled Sampling are described here; see `scheduled_sampling.py` for the complete code.
> **Collaborator:** The parts related to scheduled sampling include:
> 1. how the sampling probability decays;
> 2. how the multiplex layer is used.
>
> Both need to be explained. Also, what is the principle for setting the hyperparameters of these sampling-rate functions?


First, define the class `RandomScheduleGenerator`, which controls the decay of the sampling probability, as follows:
```python
import math

import numpy as np


class RandomScheduleGenerator:
    """
    The random sampling rate for the scheduled sampling algorithm, which uses
    a decayed sampling rate.
    """

    def __init__(self, schedule_type, a, b):
        """
        schedule_type: the type of the decay. It currently supports constant,
            linear, exponential, and inverse_sigmoid.
        a: parameter of the decay (MUST BE DOUBLE)
        b: parameter of the decay (MUST BE DOUBLE)
        """
        self.schedule_type = schedule_type
        self.a = a
        self.b = b
        self.data_processed_ = 0
        self.schedule_computers = {
            "constant": lambda a, b, d: a,
            "linear": lambda a, b, d: max(a, 1 - d / b),
            "exponential": lambda a, b, d: pow(a, d / b),
            "inverse_sigmoid": lambda a, b, d: b / (b + math.exp(d * a / b)),
        }
        assert self.schedule_type in self.schedule_computers
        self.schedule_computer = self.schedule_computers[self.schedule_type]

    def getScheduleRate(self):
        """
        Get the schedule sampling rate. Usually does not need to be called
        by users.
        """
        return self.schedule_computer(self.a, self.b, self.data_processed_)

    def processBatch(self, batch_size):
        """
        Get a batch_size of sampled indexes. These indexes can be passed to a
        MultiplexLayer to select between the ground truth and the samples
        generated at the last time step.
        """
        rate = self.getScheduleRate()
        numbers = np.random.rand(batch_size)
        indexes = (numbers >= rate).astype('int32').tolist()
        self.data_processed_ += batch_size
        return indexes
```
> **Collaborator:** Pasting a block of code like this is no better than reading the source directly; please explain how it is used and how the initial parameters should be set.

The `__init__` method defines several decay schedules for the sampling rate; `processBatch` draws 0/1 flags according to the current rate, which determines, at each decoding step, whether the ground-truth element or the generated element is used.
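As a usage sketch (the constructor arguments anticipate the snippet below; they are illustrative, not tuned values):

```python
g = RandomScheduleGenerator("linear", 0.75, 1000000)
# Linear schedule: rate = max(0.75, 1 - d / 1000000), where d is the number
# of samples processed so far. The rate starts at 1.0 (all flags 0, i.e.
# always feed the truth) and reaches its floor of 0.75 after 250,000 samples.
flags = g.processBatch(4)
# flag 0 (random number <  rate): use the ground-truth token at this step
# flag 1 (random number >= rate): use the token generated at the previous step
```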
> **Collaborator:** Following a code dump with one sentence like this is not an effective explanation; it is no different from reading the source, and it leaves the reader full of questions.
> 1. "The `__init__` method defines several..." — several what? How does one choose among them, and how are the parameters set? Please connect this properly to the introduction above; the reference is unclear.
> 2. "`processBatch` samples according to that probability" — does "that probability" refer to what `__init__` defines? `__init__` takes hyperparameters; how does the sampling probability actually change?
> 3. "Ultimately determines whether the true element or the generated element is used" — determined how?



The data reader is then wrapped so that `true_token_flag`, sampled from `RandomScheduleGenerator`, is appended to each sample as an additional input field controlling which element the decoder consumes:
> **Collaborator:**
> 1. "The data reader is wrapped here" — please expand this by a sentence or two. Why is the reader wrapped? Don't make the reader guess.
> 2. "controls which element the decoder uses" — no "decoding" is involved here; "decoding" usually refers to generating the whole sequence.


```python
schedule_generator = RandomScheduleGenerator("linear", 0.75, 1000000)


def gen_schedule_data(reader):
    """
    Creates a data reader for scheduled sampling.
    Output from the iterator created by the original reader will be
    appended with "true_token_flag" to indicate whether to use the true token.

    :param reader: the original reader.
    :type reader: callable
    :return: the new reader with the field "true_token_flag".
    :rtype: callable
    """

    def data_reader():
        for src_ids, trg_ids, trg_ids_next in reader():
            # The first flag is fixed to 0: at the first decoder step no
            # token has been generated yet, so the true token must be used.
            yield src_ids, trg_ids, trg_ids_next, \
                [0] + schedule_generator.processBatch(len(trg_ids) - 1)

    return data_reader
```

> **Collaborator:** (on `RandomScheduleGenerator("linear", 0.75, 1000000)`) How were the values 0.75 and 1000000 chosen? Please explain in the README; otherwise users cannot tell where these settings come from.

> **Contributor (author):** The text above notes that the hyperparameters need to be tuned by the user. These two values will be replaced once tuning is done, with a note that they are the tuned result.
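For completeness, the wrapped reader would be plugged into training roughly like this (a hedged sketch; the WMT14 reader, shuffle buffer, and batch size are our assumptions for illustration, not taken from this PR):

```python
# Hypothetical wiring: wrap the original reader so that every sample
# carries its per-position true_token_flag field.
wrapped = gen_schedule_data(paddle.dataset.wmt14.train(source_dict_dim))
train_reader = paddle.batch(
    paddle.reader.shuffle(wrapped, buf_size=8192), batch_size=64)
```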

During training, the decode-step function called by `recurrent_group` at every time step is as follows:

```python
def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word,
                                     true_token_flag):
    """
    The decoder step for training.
    :param enc_vec: the encoder vector for attention
    :type enc_vec: Layer
    :param enc_proj: the encoder projection for attention
    :type enc_proj: Layer
    :param true_word: the ground-truth target word
    :type true_word: Layer
    :param true_token_flag: the flag of using the ground-truth target word
    :type true_token_flag: Layer
    :return: the softmax output layer
    :rtype: Layer
    """

    decoder_mem = paddle.layer.memory(
        name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)

    context = paddle.networks.simple_attention(
        encoded_sequence=enc_vec,
        encoded_proj=enc_proj,
        decoder_state=decoder_mem)

    # Remembers the softmax output of the previous step; the previously
    # generated word is recovered from it via max_id.
    gru_out_memory = paddle.layer.memory(
        name='gru_out', size=target_dict_dim)

    generated_word = paddle.layer.max_id(input=gru_out_memory)

    generated_word_emb = paddle.layer.embedding(
        input=generated_word,
        size=word_vector_dim,
        param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))

    # true_token_flag indexes the remaining inputs: 0 selects true_word,
    # 1 selects generated_word_emb.
    current_word = paddle.layer.multiplex(
        input=[true_token_flag, true_word, generated_word_emb])

    with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs:
        decoder_inputs += paddle.layer.full_matrix_projection(input=context)
        decoder_inputs += paddle.layer.full_matrix_projection(
            input=current_word)

    gru_step = paddle.layer.gru_step(
        name='gru_decoder',
        input=decoder_inputs,
        output_mem=decoder_mem,
        size=decoder_size)

    with paddle.layer.mixed(
            name='gru_out',
            size=target_dict_dim,
            bias_attr=True,
            act=paddle.activation.Softmax()) as out:
        out += paddle.layer.full_matrix_projection(input=gru_step)

    return out
```

This function uses the `memory` layer `gru_out_memory` to remember the elements generated at earlier time steps, and a `multiplex` layer to choose whether the generated element is used as the decoder input.
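The selection done by `multiplex` can be pictured with plain lists (a schematic of the semantics only, not Paddle code):

```python
# The first input holds a per-instance index; it selects among the rest.
flags = [0, 1, 0]            # 0 -> ground-truth word, 1 -> generated word
true_word = ['A', 'B', 'C']  # stand-ins for embedding rows
generated = ['a', 'b', 'c']

current = [(true_word, generated)[f][i] for i, f in enumerate(flags)]
# current == ['A', 'b', 'C']
```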

### Training results (to be added once hyperparameter tuning is complete)
> **Collaborator:** Remove this section for now; it can be added in a follow-up PR after merging.

39 changes: 36 additions & 3 deletions scheduled_sampling/scheduled_sampling.py
@@ -28,6 +28,17 @@ def data_reader():


> **Collaborator:** Comment on the parameters like in random_schedule_generator.py.

```python
def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
    """
    The definition of the sequence-to-sequence model.

    :param source_dict_dim: the dictionary size of the source language
    :type source_dict_dim: int
    :param target_dict_dim: the dictionary size of the target language
    :type target_dict_dim: int
    :param is_generating: whether in generating mode
    :type is_generating: Bool
    :return: the last layer of the network
    :rtype: Layer
    """
    ### Network Architecture
    word_vector_dim = 512  # dimension of word vector
    decoder_size = 512  # dimension of hidden unit in GRU Decoder network
```
@@ -41,9 +52,7 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):

```diff
         name='source_language_word',
         type=paddle.data_type.integer_value_sequence(source_dict_dim))
     src_embedding = paddle.layer.embedding(
-        input=src_word_id,
-        size=word_vector_dim,
-        param_attr=paddle.attr.ParamAttr(name='_source_language_embedding'))
+        input=src_word_id, size=word_vector_dim)
     src_forward = paddle.networks.simple_gru(
         input=src_embedding, size=encoder_size)
     src_backward = paddle.networks.simple_gru(
```
@@ -64,6 +73,19 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):

> **Collaborator:** Comment on the parameters like in random_schedule_generator.py.

> **Collaborator:** Layer --> LayerOutput. (This comment was left on each of the `:type` and `:rtype` lines below.)

```python
    def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word,
                                         true_token_flag):
        """
        The decoder step for training.
        :param enc_vec: the encoder vector for attention
        :type enc_vec: Layer
        :param enc_proj: the encoder projection for attention
        :type enc_proj: Layer
        :param true_word: the ground-truth target word
        :type true_word: Layer
        :param true_token_flag: the flag of using the ground-truth target word
        :type true_token_flag: Layer
        :return: the softmax output layer
        :rtype: Layer
        """

        decoder_mem = paddle.layer.memory(
            name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)
```
@@ -107,6 +129,17 @@ def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word,
```python
        return out
```

> **Collaborator:** (on `gru_decoder_with_attention_test`) Comment on the parameters like in random_schedule_generator.py.

```python
    def gru_decoder_with_attention_test(enc_vec, enc_proj, current_word):
        """
        The decoder step for generating.
        :param enc_vec: the encoder vector for attention
        :type enc_vec: Layer
        :param enc_proj: the encoder projection for attention
        :type enc_proj: Layer
        :param current_word: the previously generated word
        :type current_word: Layer
        :return: the softmax output layer
        :rtype: Layer
        """

        decoder_mem = paddle.layer.memory(
            name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)
```