@@ -76,8 +76,18 @@ def fused_feedforward(x,
         ln1_epsilon (float, optional): Small float of first layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
         ln2_epsilon (float, optional): Small float of second layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
         pre_layer_norm (bool, optional): add layer_norm in the pre-processing stage or post-processing state.
-        training (bool): A flag indicating whether it is in train phrase or not. Default True.
-        mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
+        training (bool, optional): A flag indicating whether it is in the train phase or not. Default True.
+        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
+
+            1. upscale_in_train(default), upscale the output at training time
+
+               - train: out = input * mask / ( 1.0 - p )
+               - inference: out = input
+
+            2. downscale_in_infer, downscale the output at inference
+
+               - train: out = input * mask
+               - inference: out = input * (1.0 - p)
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

     Returns:
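The two dropout modes added to the docstring above can be sketched in a few lines of NumPy. This is a minimal illustration of the scaling formulas, not the Paddle implementation; `p` and `mask` stand for the drop probability and the sampled keep mask:

```python
import numpy as np

rng = np.random.default_rng(0)
p = 0.5                                        # drop probability
x = rng.standard_normal(8)
mask = (rng.random(8) >= p).astype(x.dtype)    # 1.0 = keep, 0.0 = drop

# upscale_in_train: scale up during training so inference is a no-op
train_up = x * mask / (1.0 - p)
infer_up = x

# downscale_in_infer: plain masking during training, scale down at inference
train_down = x * mask
infer_down = x * (1.0 - p)
```

Both modes keep the expected value of the output equal to the input; they only differ in whether the `(1.0 - p)` correction is paid at training or at inference time.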
@@ -245,7 +255,10 @@ def fused_multi_head_attention(x,
         out = out * v
         out = transpose(out, perm=[0, 2, 1, 3])
         out = out_linear(out)
-        out = layer_norm(x + dropout(linear_bias + out))
+        if pre_layer_norm:
+            out = x + dropout(linear_bias + out)
+        else:
+            out = layer_norm(x + dropout(linear_bias + out))

     Parameters:
         x (Tensor): The input tensor of fused_multi_head_attention. The shape is
@@ -279,7 +292,7 @@ def fused_multi_head_attention(x,
         ln_epsilon (float, optional): Small float value added to denominator of layer_norm
             to avoid dividing by zero. Default is 1e-5.
         training (bool, optional): A flag indicating whether it is in train phrase or not. Default True.
-        mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
+        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

             1. upscale_in_train(default), upscale the output at training time

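The corrected pseudocode above makes the final residual step branch on `pre_layer_norm`: the post-norm path applies `layer_norm` after the residual add, while the pre-norm path skips it (the norm already ran in pre-processing). A small NumPy sketch of just that final step, with a simplified `layer_norm` (no scale/shift) and dropout omitted (`p = 0`) so the example stays deterministic:

```python
import numpy as np

def layer_norm(t, eps=1e-5):
    # simplified layer_norm over the last axis, without learned scale/shift
    mu = t.mean(-1, keepdims=True)
    var = t.var(-1, keepdims=True)
    return (t - mu) / np.sqrt(var + eps)

def residual_out(x, out, linear_bias, pre_layer_norm):
    # dropout(...) with p = 0 is the identity, so it is elided here
    h = x + (linear_bias + out)
    return h if pre_layer_norm else layer_norm(h)
```

With `pre_layer_norm=True` the result is the raw residual sum; with `pre_layer_norm=False` each row is additionally normalized to zero mean.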