-
Notifications
You must be signed in to change notification settings - Fork 876
新增fused transformer encoder 中文文档 #4022
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
TCChenlong
merged 18 commits into
PaddlePaddle:develop
from
zkh2016:fused_transformer_layer_doc
Oct 28, 2021
Merged
Changes from 14 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
3b062c7
修改错别字
zkh2016 0606c9b
备注CPU不支持float16
zkh2016 5726d08
update example
zkh2016 f3c462d
Merge branch 'develop' of github.com:PaddlePaddle/docs into develop
8ef9a10
add fused_feedforward
795cd8a
add fused_feedforward
a9f69a7
add fused_feedforward
528a4dc
opt the description
02d8c0b
update docs
eba59ea
update docs
a028800
update doc
2187f4d
move fused_feedforward docs position
26ed660
add incubate/nn/layer
e1ec8bb
move fused transfomrer doc to incubate/nn/
c9b4e78
modify the doc
94991ba
Merge branch 'develop' into fused_transformer_layer_doc
6d1f1f5
modify fused transformer encoder layer doc
9b9f1bb
Merge branch 'fused_transformer_layer_doc' of github.com:zkh2016/docs…
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| .. _cn_api_incubate_nn_FusedFeedForward: | ||
|
|
||
| FusedFeedForward | ||
| ------------------------------- | ||
| .. py:class:: paddle.incubate.nn.FusedFeedForward(d_model, dim_feedforward, dropout_rate=0.1, activation='relu', act_dropout_rate=None, normalize_before=False, weight_attr=None, bias_attr=None) | ||
|
|
||
| 这是一个调用融合算子fused_feedforward(参考:ref:`cn_api_incubate_nn_functional_fused_feedforward` )。 | ||
|
|
||
|
|
||
| 参数 | ||
| ::::::::: | ||
| - **d_model** (int) - 输入输出的维度。 | ||
| - **dim_feedforward** (int) - 前馈神经网络中隐藏层的大小。 | ||
| - **dropout_rate** (float,可选) - 对本层的输出进行处理的dropout值, 置零的概率。默认值:0.1。 | ||
| - **activation** (str,可选) - 激活函数。默认值:``relu``。 | ||
| - **act_dropout_rate** (float,可选) - 激活函数后的dropout置零的概率。如果为 `None` 则 `act_dropout_rate = dropout_rate` 。默认值: `None` 。 | ||
|
||
| - **normalize_before** (bool, 可选) - 设置对输入输出的处理。如果为 `True` ,则对输入进行层标准化(Layer Normalization),否则(即为 `False` ),则对输入不进行处理,而是在输出前进行标准化。默认值: `False` 。 | ||
| - **weight_attr** (ParamAttr,可选) - 指定权重参数属性的对象。默认值: `None` ,表示使用默认的权重参数属性,即使用0进行初始化。具体用法请参见 :ref:`cn_api_fluid_ParamAttr` 。 | ||
| - **bias_attr** (ParamAttr|bool,可选)- 指定偏置参数属性的对象。如果该参数值是 `ParamAttr` ,则使用 `ParamAttr` 。如果该参数为 `bool` 类型,只支持为 `False` ,表示没有偏置参数。默认值为None,表示使用默认的偏置参数属性。具体用法请参见 :ref:`cn_api_fluid_ParamAttr` 。 | ||
|
|
||
| 返回 | ||
| ::::::::: | ||
| - Tensor, 输出Tensor,数据类型与`x`一样。 | ||
|
|
||
| 代码示例 | ||
| :::::::::: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| # required: gpu | ||
| import paddle | ||
| from paddle.incubate.nn import FusedFeedForward | ||
|
|
||
| fused_feedforward_layer = FusedFeedForward(8, 8) | ||
| x = paddle.rand((1, 8, 8)) | ||
| out = fused_feedforward_layer(x) | ||
| print(out.numpy().shape) | ||
| # (1, 8, 8) | ||
43 changes: 43 additions & 0 deletions
43
docs/api/paddle/incubate/nn/FusedTransformerEncoderLayer_cn.rst
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| .. _cn_api_incubate_nn_FusedTransformerEncoderLayer: | ||
|
|
||
| FusedTransformerEncoderLayer | ||
| ------------------------------- | ||
| .. py:class:: paddle.incubate.nn.FusedTransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout_rate=0.1, activation='relu', attn_dropout_rate=None, act_dropout_rate=None, normalize_before=False, weight_attr=None, bias_attr=None) | ||
|
|
||
|
|
||
| FusedTransformer编码器层由两个子层组成:多头自注意力机制和前馈神经网络。如果 `normalize_before` 为 `True` ,则对每个子层的输入进行层标准化(Layer Normalization),对每个子层的输出进行dropout和残差连接(residual connection)。否则(即 `normalize_before` 为 `False` ),则对每个子层的输入不进行处理,只对每个子层的输出进行dropout、残差连接(residual connection)和层标准化(Layer Normalization)。 | ||
|
|
||
|
|
||
| 参数 | ||
| ::::::::: | ||
| - **d_model** (int) - 输入输出的维度。 | ||
| - **nhead** (int) - multi-head attention(MHA)的Head数量。 | ||
| - **dim_feedforward** (int) - 前馈神经网络中隐藏层的大小。 | ||
| - **dropout_rate** (float,可选) - 对两个子层的输出进行处理的dropout值, 置零的概率。默认值:0.1。 | ||
| - **activation** (str,可选) - 前馈神经网络的激活函数。默认值:``relu``。 | ||
| - **attn_dropout_rate** (float,可选) - MHA中对注意力目标的随机失活率。如果为 `None` 则 `attn_dropout = dropout` 。默认值: `None` 。 | ||
| - **act_dropout_rate** (float,可选) - 前馈神经网络的激活函数后的dropout置零的概率。如果为 `None` 则 `act_dropout_rate = dropout_rate` 。默认值: `None` 。 | ||
| - **normalize_before** (bool, 可选) - 设置对每个子层的输入输出的处理。如果为 `True` ,则对每个子层的输入进行层标准化(Layer Normalization),否则(即为 `False` ),则对每个子层的输入不进行处理,而是在子层的输出前进行标准化。默认值: `False` 。 | ||
| - **weight_attr** (ParamAttr|tuple,可选) - 指定权重参数属性的对象。如果是 `tuple` ,MHA的权重参数属性使用 `weight_attr[0]` ,前馈神经网络的权重参数属性使用 `weight_attr[1]` 。如果参数值是 `ParamAttr` ,则MHA和前馈神经网络的权重参数属性都使用 `ParamAttr` 。默认值: `None` ,表示使用默认的权重参数属性。具体用法请参见 :ref:`cn_api_fluid_ParamAttr` 。 | ||
| - **bias_attr** (ParamAttr|tuple|bool,可选)- 指定偏置参数属性的对象。如果是 `tuple` ,MHA的偏置参数属性使用 `bias_attr[0]` ,前馈神经网络的偏置参数属性使用 `bias_attr[1]` 。如果该参数值是 `ParamAttr` ,则MHA和前馈神经网络的偏置参数属性都使用 `ParamAttr` 。如果该参数为 `bool` 类型,只支持为 `False` ,表示没有偏置参数。默认值为None,表示使用默认的偏置参数属性。具体用法请参见 :ref:`cn_api_fluid_ParamAttr` 。 | ||
|
|
||
|
|
||
| 返回 | ||
| ::::::::: | ||
| - Tensor, 输出Tensor,数据类型与`x`一样。 | ||
|
|
||
| 代码示例 | ||
| :::::::::: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| # required: gpu | ||
| import paddle | ||
| from paddle.incubate.nn import FusedTransformerEncoderLayer | ||
|
|
||
| # encoder input: [batch_size, src_len, d_model] | ||
| enc_input = paddle.rand((2, 4, 128)) | ||
| # self attention mask: [batch_size, n_head, src_len, src_len] | ||
| attn_mask = paddle.rand((2, 2, 4, 4)) | ||
| encoder_layer = FusedTransformerEncoderLayer(128, 2, 512) | ||
| enc_output = encoder_layer(enc_input, attn_mask) # [2, 4, 128] |
59 changes: 59 additions & 0 deletions
59
docs/api/paddle/incubate/nn/functional/fused_feedforward_cn.rst
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| .. _cn_api_incubate_nn_functional_fused_feedforward: | ||
|
|
||
| fused_feedforward | ||
| ------------------------------- | ||
|
|
||
| .. py:function:: paddle.incubate.nn.functional.fused_feedforward(x, linear1_weight, linear2_weight, linear1_bias=None, linear2_bias=None, ln1_scale=None, ln1_bias=None, ln2_scale=None, ln2_bias=None, dropout1_rate=0.5, dropout2_rate=0.5,activation="relu", ln1_epsilon=1e-5, ln2_epsilon=1e-5, pre_layer_norm=False, name=None): | ||
|
|
||
| 这是一个融合算子,该算子是对transformer模型中feed forward层的多个算子进行融合,该算子只支持在GPU下运行,该算子与如下伪代码表达一样的功能: | ||
|
|
||
| .. code-block:: ipython | ||
|
|
||
| residual = src; | ||
| if pre_layer_norm: | ||
| src = layer_norm(src) | ||
| src = linear(dropout(activation(dropout(linear(src))))) | ||
| if not pre_layer_norm: | ||
| src = layer_norm(out) | ||
|
|
||
| 参数 | ||
| ::::::::: | ||
| - **x** (Tensor) - 输入Tensor,数据类型支持float16, float32 和float64, 输入的形状是`[batch_size, sequence_length, d_model]`。 | ||
| - **linear1_weight** (Tensor) - 第一个linear算子的权重数据,数据类型与`x`一样,形状是`[d_model, dim_feedforward]`。 | ||
| - **linear2_weight** (Tensor) - 第二个linear算子的权重数据,数据类型与`x`一样,形状是`[dim_feedforward, d_model]`。 | ||
| - **linear1_bias** (Tensor, 可选) - 第一个linear算子的偏置数据,数据类型与`x`一样,形状是`[dim_feedforward]`。默认值为None。 | ||
| - **linear2_bias** (Tensor, 可选) - 第二个linear算子的偏置数据,数据类型与`x`一样,形状是`[d_model]`。默认值为None。 | ||
| - **ln1_scale** (Tensor, 可选) - 第一个layer_norm算子的权重数据,数据类型可以是float32或者float64,形状和`x`一样。默认值为None。 | ||
| - **ln1_bias** (Tensor, 可选) - 第一个layer_norm算子的偏置数据,数据类型和`ln1_scale`一样, 形状是`[d_model]`。默认值为None。 | ||
| - **ln2_scale** (Tensor, 可选) - 第二个layer_norm算子的权重数据,数据类型可以是float32或者float64,形状和`x`一样。默认值为None。 | ||
| - **ln2_bias** (Tensor, 可选) - 第二个layer_norm算子的偏置数据,数据类型和`ln2_scale`一样, 形状是`[d\_model]`。默认值为None。 | ||
| - **dropout1_rate** (float, 可选) - 第一个dropout算子置零的概率。默认是0.5。 | ||
| - **dropout2_rate** (float, 可选) - 第二个dropout算子置零的概率。默认是0.5。 | ||
| - **activation** (string, 可选) - 激活函数。默认值是relu。 | ||
| - **ln1_epsilon** (float, 可选) - 一个很小的浮点数,被第一个layer_norm算子加到分母,避免出现除零的情况。默认值是1e-5。 | ||
| - **ln2_epsilon** (float, 可选) - 一个很小的浮点数,被第二个layer_norm算子加到分母,避免出现除零的情况。默认值是1e-5。 | ||
| - **pre_layer_norm** (bool, 可选) - 在预处理阶段加上layer_norm,或者在后处理阶段加上layer_norm。默认值是False。 | ||
| - **name** (string, 可选) – fused_feedforward的名称, 默认值为None。更多信息请参见 :ref:`api_guide_Name` 。 | ||
|
|
||
| 返回 | ||
| ::::::::: | ||
| - Tensor, 输出Tensor,数据类型与`x`一样。 | ||
|
|
||
| 代码示例 | ||
| :::::::::: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| # required: gpu | ||
| import paddle | ||
| import numpy as np | ||
| x_data = np.random.random((1, 8, 8)).astype("float32") | ||
| linear1_weight_data = np.random.random((8, 8)).astype("float32") | ||
| linear2_weight_data = np.random.random((8, 8)).astype("float32") | ||
| x = paddle.to_tensor(x_data) | ||
| linear1_weight = paddle.to_tensor(linear1_weight_data) | ||
| linear2_weight = paddle.to_tensor(linear2_weight_data) | ||
| out = paddle.incubate.nn.functional.fused_feedforward(x, linear1_weight, linear2_weight) | ||
| print(out.numpy().shape) | ||
| # (1, 8, 8) | ||
|
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里链接格式有问题。应该在参考后面加一个空格
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done