
Commit a9e7a85

Fix attn_bias_add bug. (#37147)
The implementation of fused_attention_op uses bias_add, which was built on top of the kernel primitive API. The WriteData API of the kernel primitives (both its interface and its internal implementation) was later changed, moving the out-of-bounds check into a template parameter, so the call site took the wrong branch and performed out-of-bounds writes that corrupted other GPU memory. Symptom: a single run of test_fused_attention_op_api.py almost never fails, but looping over inputs of different shapes produces intermittently wrong results, so the bug is easy to miss.
1 parent c5ccff7 commit a9e7a85
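
For reference, the math around the GEMM that this fix touches is small: a broadcast bias add in the forward pass and a sum-reduction of the upstream gradient for the bias gradient. A minimal NumPy sketch of that reference semantics (illustrative shapes, not the CUDA kernels themselves):

    # Reference semantics of the fused bias add and its backward reduction.
    # Shapes are made up for illustration; the real op runs CUDA kernels.
    import numpy as np

    bsz_seq, output_size = 8 * 128, 3 * 4 * 64

    x = np.random.rand(bsz_seq, output_size).astype('float32')   # GEMM output
    bias = np.random.rand(output_size).astype('float32')

    # forward: broadcast the bias over every row
    bias_out = x + bias                                           # (bsz_seq, output_size)

    # backward: the bias gradient sums the upstream gradient over the rows
    d_out = np.random.rand(bsz_seq, output_size).astype('float32')
    d_bias = d_out.sum(axis=0)                                    # (output_size,)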

5 files changed: +110 -46 lines changed


paddle/fluid/operators/fused/attn_gemm.h

Lines changed: 53 additions & 19 deletions
@@ -11,10 +11,13 @@ limitations under the License. */

 #pragma once

-#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/float16.h"

+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+
 namespace paddle {
 namespace operators {

@@ -36,8 +39,10 @@ class AttnMatMul {

   ~AttnMatMul() {}

-  void ComputeForward(const T* weight_data, const T* input_data,
-                      const T* bias_data, T* output_data, T* bias_out_data) {
+  void ComputeForward(const framework::Tensor* weight,
+                      const framework::Tensor* input,
+                      const framework::Tensor* bias, framework::Tensor* output,
+                      framework::Tensor* bias_out) {
     // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
     // here: (transa, transb): nt, input * weight.
     CBLAS_TRANSPOSE transA = CblasNoTrans;
@@ -54,16 +59,25 @@ class AttnMatMul {
     // here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
     blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
-              input_data, weight_data, beta, output_data);
+              input->data<T>(), weight->data<T>(), beta, output->data<T>());
     if (compute_bias_) {
       // compute output + bias
-      LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
-                            bias_data, bias_out_data);
+      std::vector<const Tensor*> ins;
+      std::vector<Tensor*> outs;
+      ins.emplace_back(output);
+      ins.emplace_back(bias);
+      outs.emplace_back(bias_out);
+      int elewise_add_axis = -1;
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
     }
   }

-  void ComputeBackward(const T* input, const T* weight, const T* d_output,
-                       T* d_input, T* d_weight, T* d_bias) {
+  void ComputeBackward(const framework::Tensor* input,
+                       const framework::Tensor* weight,
+                       const framework::Tensor* d_output,
+                       framework::Tensor* d_input, framework::Tensor* d_weight,
+                       framework::Tensor* d_bias) {
     T alpha = static_cast<T>(1.0);
     T beta = static_cast<T>(0.0);
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
@@ -81,11 +95,11 @@ class AttnMatMul {

     T* dB_input_1_ptr = nullptr;
     T* dB_input_2_ptr = nullptr;
-    T* dB_output_ptr = d_weight;
+    T* dB_output_ptr = d_weight->data<T>();

     T* dA_input_1_ptr = nullptr;
     T* dA_input_2_ptr = nullptr;
-    T* dA_output_ptr = d_input;
+    T* dA_output_ptr = d_input->data<T>();

     if (!transA_) {
       // fw: gemm-nt
@@ -104,10 +118,10 @@ class AttnMatMul {
        dA_n = input_size_;
        dA_k = output_size_;

-       blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output,
-                 input, beta, dB_output_ptr);
-       blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
-                 weight, beta, dA_output_ptr);
+       blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
+                 d_output->data<T>(), input->data<T>(), beta, dB_output_ptr);
+       blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
+                 d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
      } else {  // fw: gemm-nn
        // bw: gemm-tn, dB = A^t * dC
        dB_transA = CblasTrans;
@@ -123,10 +137,10 @@ class AttnMatMul {
        dA_n = input_size_;
        dA_k = output_size_;

-       blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input,
-                 d_output, beta, dB_output_ptr);
-       blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
-                 weight, beta, dA_output_ptr);
+       blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
+                 input->data<T>(), d_output->data<T>(), beta, dB_output_ptr);
+       blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
+                 d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
      }
    } else if (transB_) {
      PADDLE_THROW(platform::errors::InvalidArgument(
@@ -138,7 +152,27 @@ class AttnMatMul {
          "parameters."));
    }
    if (compute_bias_) {
-     LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
+     // reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2}
+     const auto input_dims = d_output->dims();
+     const auto output_dims = d_bias->dims();
+     bool support_case_1 =
+         (input_dims.size() == 5 && output_dims.size() == 3 &&
+          (input_dims[2] == output_dims[0]) &&
+          (input_dims[3] == output_dims[1]) &&
+          (input_dims[4] == output_dims[2]));
+     bool support_case_2 =
+         (input_dims.size() == 3 && output_dims.size() == 1 &&
+          (input_dims[2] == output_dims[0]));
+     if (support_case_1 || support_case_2) {
+       gpuStream_t stream = dev_ctx_.stream();
+       TensorReduceFunctorImpl<T, T, CustomSum>(*d_output, d_bias, {0, 1},
+                                                stream);
+     } else {
+       PADDLE_THROW(platform::errors::InvalidArgument(
+           "Only support reduce when the input dims are [0,1,2,3,4] and "
+           "output is [2,3,4]"
+           "or input is [0,1,2] and output is [2]."));
+     }
    }
  }
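
With this change the backward bias path no longer calls the hand-written LaunchBiasAddBwKernel; it sums d_output over axes {0, 1} via Paddle's generic TensorReduceFunctorImpl, and only for the two layouts checked by support_case_1 and support_case_2. A rough NumPy equivalent of the two accepted cases (shapes are illustrative only):

    import numpy as np

    # case 1: 5-D d_output reduced to a 3-D d_bias (dims {2, 3, 4} kept)
    d_out_5d = np.random.rand(2, 8, 3, 4, 16).astype('float32')
    d_bias_3d = d_out_5d.sum(axis=(0, 1))      # shape (3, 4, 16)

    # case 2: 3-D d_output reduced to a 1-D d_bias (dim {2} kept)
    d_out_3d = np.random.rand(2, 8, 16).astype('float32')
    d_bias_1d = d_out_3d.sum(axis=(0, 1))      # shape (16,)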

paddle/fluid/operators/fused/fused_attention_op.cu

Lines changed: 14 additions & 14 deletions
@@ -170,11 +170,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {

      layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
                                        ln_out_data, ln_mean_data, ln_var_data);
-     qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
-                                qkv_out_data, qkv_bias_out_data);
+     qkv_compute.ComputeForward(qkv_weight, ln_out, qkv_bias, qkv_out,
+                                qkv_bias_out);
    } else {
-     qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
-                                qkv_out_data, qkv_bias_out_data);
+     qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
+                                qkv_bias_out);
    }
    fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
                                    qk_out, src_mask_out, softmax_out,
@@ -184,8 +184,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
    // fmha_out: [batch_size, seq_len, num_head, head_dim]
    // weight: [embed_dim, embed_dim]
    // out_linear_out: [batch_size, seq_len, embed_dim]
-   out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
-                                     nullptr, out_linear_out_data, nullptr);
+   out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr,
+                                     out_linear_out, nullptr);
    if (pre_layer_norm) {
      // output = (residual + dropout(input + bias))
      fused_dropout_layernorm_helper.ResidualDropoutBias(
@@ -401,9 +401,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
          d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
    }

-   out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data,
-                                      d_out_linear_out_data, d_fmha_out_data,
-                                      d_out_linear_weight_data, nullptr);
+   out_linear_compute.ComputeBackward(fmha_out, out_linear_weight,
+                                      d_out_linear_out, d_fmha_out,
+                                      d_out_linear_weight, nullptr);
+
    fmha_ref_compute.ComputeBackward(
        *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
        *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
@@ -432,15 +433,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
          (d_ln_bias == nullptr ? nullptr
                                : d_ln_bias->mutable_data<U>(ctx.GetPlace()));

-     qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data,
-                                 d_qkv_bias_out_data, d_ln_out_data,
-                                 d_qkv_weight_data, d_qkv_bias_data);
+     qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out,
+                                 d_qkv_weight, d_qkv_bias);
      layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
                                         ln_mean_data, ln_var_data, d_x_data,
                                         d_ln_scale_data, d_ln_bias_data);
    } else {
-     qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data,
-                                 d_x_data, d_qkv_weight_data, d_qkv_bias_data);
+     qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
+                                 d_qkv_weight, d_qkv_bias);
    }
    // gradient accumulation
    std::vector<const Tensor *> ins;
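
The kernels above now hand whole framework::Tensor objects to AttnMatMul instead of raw data pointers; the computation itself is unchanged. As a rough NumPy sketch of the QKV projection branch in the forward kernel (layer_norm_ref is a stand-in reference, and the 2-D weight layout is a simplification of the real qkv_weight parameter):

    import numpy as np

    def layer_norm_ref(x, eps=1e-5):
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(var + eps)

    def qkv_forward_ref(x, qkv_weight, qkv_bias, pre_layer_norm):
        # x: (bsz*seq, embed_dim); qkv_weight: (embed_dim, 3*num_head*head_dim)
        ln_out = layer_norm_ref(x) if pre_layer_norm else x
        qkv_out = ln_out @ qkv_weight        # AttnMatMul GEMM
        qkv_bias_out = qkv_out + qkv_bias    # fused bias add
        return qkv_out, qkv_bias_out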

python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py

Lines changed: 13 additions & 6 deletions
@@ -89,27 +89,32 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
     qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] *
                                     qkv_weight.shape[2] * qkv_weight.shape[3])

+    qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
+                                qkv_bias.shape[2])
     if (pre_layer_norm):
         ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
         qkv = fc(ln_out, qkv_weight)
+        qkv_bias_out = qkv + qkv_bias
         ln_out = ln_out.reshape(batch_size, seq_len, embed_dim)
     else:
         query = query.reshape(batch_size * seq_len, embed_dim)
         qkv = fc(query, qkv_weight)
+        qkv_bias_out = qkv + qkv_bias
         query = query.reshape(batch_size, seq_len, embed_dim)

-    qkv = qkv.reshape(batch_size, seq_len, 3, num_head, head_dim)
+    qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head,
+                                        head_dim)
     # q*k^t
-    qkv = qkv.transpose(
+    qkv_bias_out = qkv_bias_out.transpose(
         (2, 0, 1, 3, 4))  # 3, batch_size, seq_len, num_head, head_dim
-    qkv = qkv.transpose(
+    qkv_bias_out = qkv_bias_out.transpose(
         (0, 1, 3, 2, 4))  # 3, batch_size, num_head, seq_len, head_dim

-    q = qkv[0:1, ::]
+    q = qkv_bias_out[0:1, ::]
     q = q.reshape(batch_size, num_head, seq_len, head_dim)
-    k = qkv[1:2, ::]  #[1, batch_size, num_head, seq_len, head_dim]
+    k = qkv_bias_out[1:2, ::]  #[1, batch_size, num_head, seq_len, head_dim]
     k = k.reshape(batch_size, num_head, seq_len, head_dim)
-    v = qkv[2::]
+    v = qkv_bias_out[2::]
     v = v.reshape(batch_size, num_head, seq_len, head_dim)

     k = k.transpose([0, 1, 3, 2])  #[batch_size, num_head, head_dim, seq_len]
@@ -200,6 +205,8 @@ def run_imperative(self):
             self.embed_dim, self.num_heads, self.dropout_prob,
             self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
             self.need_weight, self.weight_attr, self.bias_attr)
+        qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype('float32')
+        fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
         out = fused_attn(
             paddle.to_tensor(self.query),
             paddle.to_tensor(self.query),
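
Collected into one self-contained snippet, the qkv_bias handling that the hunks above add to the NumPy reference looks roughly like this (variable names mirror the test; the concrete shapes are only examples):

    import numpy as np

    batch_size, seq_len, num_head, head_dim = 2, 4, 3, 8

    qkv = np.random.rand(batch_size * seq_len, 3 * num_head * head_dim)
    qkv_bias = np.random.rand(3, num_head, head_dim)

    qkv_bias = qkv_bias.reshape(3 * num_head * head_dim)
    qkv_bias_out = qkv + qkv_bias                          # bias folded into the reference
    qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head, head_dim)
    qkv_bias_out = qkv_bias_out.transpose(2, 0, 1, 3, 4)   # 3, batch, seq, head, dim
    qkv_bias_out = qkv_bias_out.transpose(0, 1, 3, 2, 4)   # 3, batch, head, seq, dim
    q = qkv_bias_out[0:1].reshape(batch_size, num_head, seq_len, head_dim)
    k = qkv_bias_out[1:2].reshape(batch_size, num_head, seq_len, head_dim)
    v = qkv_bias_out[2:].reshape(batch_size, num_head, seq_len, head_dim)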

python/paddle/incubate/nn/functional/fused_transformer.py

Lines changed: 28 additions & 5 deletions
@@ -76,8 +76,18 @@ def fused_feedforward(x,
         ln1_epsilon (float, optional): Small float of first layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
         ln2_epsilon (float, optional): Small float of second layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
         pre_layer_norm (bool, optional): add layer_norm in the pre-processing stage or post-processing state.
-        training (bool): A flag indicating whether it is in train phrase or not. Default True.
-        mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
+        training (bool, optional): A flag indicating whether it is in train phrase or not. Default True.
+        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
+
+            1. upscale_in_train(default), upscale the output at training time
+
+                - train: out = input * mask / ( 1.0 - p )
+                - inference: out = input
+
+            2. downscale_in_infer, downscale the output at inference
+
+                - train: out = input * mask
+                - inference: out = input * (1.0 - p)
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

     Returns:
@@ -245,7 +255,10 @@ def fused_multi_head_attention(x,
             out = out * v
             out = transpose(out, perm=[0, 2, 1, 3])
             out = out_linear(out)
-            out = layer_norm(x + dropout(linear_bias + out))
+            if pre_layer_norm:
+                out = x + dropout(linear_bias + out)
+            else:
+                out = layer_norm(x + dropout(linear_bias + out))

     Parameters:
         x (Tensor): The input tensor of fused_multi_head_attention. The shape is
@@ -278,8 +291,18 @@ def fused_multi_head_attention(x,
             0 for no dropout. Default 0.5.
         ln_epsilon (float, optional): Small float value added to denominator of layer_norm
             to avoid dividing by zero. Default is 1e-5.
-        training (bool): A flag indicating whether it is in train phrase or not. Default True.
-        mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
+        training (bool, optional): A flag indicating whether it is in train phrase or not. Default True.
+        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
+
+            1. upscale_in_train(default), upscale the output at training time
+
+                - train: out = input * mask / ( 1.0 - p )
+                - inference: out = input
+
+            2. downscale_in_infer, downscale the output at inference
+
+                - train: out = input * mask
+                - inference: out = input * (1.0 - p)
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

     Returns:
python/paddle/incubate/nn/layer/fused_transformer.py

Lines changed: 2 additions & 2 deletions
@@ -39,8 +39,6 @@ class FusedMultiHeadAttention(Layer):
         attn_dropout_rate (float, optional): The dropout probability used on attention
             weights to drop some attention targets for the dropout in attention.
             0 for no dropout. Default 0.5.
-        epsilon (float, optional): he small value added to the variance to prevent
-            division by zero. Default: 1e-05.
         kdim (int, optional): The feature size in key. If None, assumed equal to
             `embed_dim`. Default None.
         vdim (int, optional): The feature size in value. If None, assumed equal to
@@ -56,6 +54,8 @@ class FusedMultiHeadAttention(Layer):
             Default: None, which means the default bias parameter property is used.
             If it is set to False, this layer will not have trainable bias parameter.
             See usage for details in :code:`ParamAttr`.
+        epsilon (float, optional): The small value added to the variance to prevent
+            division by zero. Default: 1e-05.

     Examples:
