Commit c9e3439

Fix bias_add bug.
1 parent 6c183a8 commit c9e3439

File tree

3 files changed (+72 lines, -34 lines)


paddle/fluid/operators/fused/attn_bias_add.cu.h

Lines changed: 2 additions & 1 deletion
@@ -87,7 +87,8 @@ __global__ void BroadcastKernelBinary(
   kernel_primitives::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
       result, arg0, arg1, func);
   // store
-  kernel_primitives::WriteData<OutT, VecSize, 1, 1>(out + fix, result, num);
+  kernel_primitives::WriteData<OutT, VecSize, 1, 1, true>(out + fix, result,
+                                                          num);
 }
 
 // bias add forward impl for "[m, n] + [n] = [m, n]"
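
The only functional change in this file is the extra boolean template argument on WriteData. Read together with `num`, the flag plausibly enables a boundary check so the last, partially filled vector of a row (when the row length is not a multiple of VecSize) is not stored past the end of the output. A minimal sketch of that store pattern, with hypothetical names rather than the real kernel_primitives API:

// Hypothetical sketch only -- not the kernel_primitives::WriteData API.
// A vectorized store whose IsBoundary flag guards the tail of a row:
// with IsBoundary == false all VecSize elements are written, otherwise
// only the first `num` elements are.
template <typename T, int VecSize, bool IsBoundary>
__device__ void StoreVectorized(T* dst, const T* src, int num) {
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    if (!IsBoundary || i < num) {
      dst[i] = src[i];
    }
  }
}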

paddle/fluid/operators/fused/attn_gemm.h

Lines changed: 56 additions & 19 deletions
@@ -11,10 +11,14 @@ limitations under the License. */
 
 #pragma once
 
-#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
+// #include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/float16.h"
 
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+
 namespace paddle {
 namespace operators {
 
@@ -36,8 +40,10 @@ class AttnMatMul {
 
   ~AttnMatMul() {}
 
-  void ComputeForward(const T* weight_data, const T* input_data,
-                      const T* bias_data, T* output_data, T* bias_out_data) {
+  void ComputeForward(const framework::Tensor* weight,
+                      const framework::Tensor* input,
+                      const framework::Tensor* bias, framework::Tensor* output,
+                      framework::Tensor* bias_out) {
     // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
     // here: (transa, transb): nt, input * weight.
     CBLAS_TRANSPOSE transA = CblasNoTrans;
@@ -54,16 +60,27 @@ class AttnMatMul {
     // here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
     blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
-              input_data, weight_data, beta, output_data);
+              input->data<T>(), weight->data<T>(), beta, output->data<T>());
     if (compute_bias_) {
       // compute output + bias
-      LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
-                            bias_data, bias_out_data);
+      std::vector<const Tensor*> ins;
+      std::vector<Tensor*> outs;
+      ins.emplace_back(output);
+      ins.emplace_back(bias);
+      outs.emplace_back(bias_out);
+      int elewise_add_axis = -1;
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
     }
   }
 
-  void ComputeBackward(const T* input, const T* weight, const T* d_output,
-                       T* d_input, T* d_weight, T* d_bias) {
+  // void ComputeBackward(const T* input, const T* weight, const T* d_output,
+  //                      T* d_input, T* d_weight, T* d_bias) {
+  void ComputeBackward(const framework::Tensor* input,
+                       const framework::Tensor* weight,
+                       const framework::Tensor* d_output,
+                       framework::Tensor* d_input, framework::Tensor* d_weight,
+                       framework::Tensor* d_bias) {
     T alpha = static_cast<T>(1.0);
     T beta = static_cast<T>(0.0);
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
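
In the forward pass the handwritten LaunchBiasAddFwKernel is dropped in favour of Paddle's generic broadcast elementwise launcher: `output` and `bias` go in as inputs, `bias_out` comes out, and axis -1 broadcasts the 1-D bias across the rows of the [bsz_seq, output_size] matrix. The computation is still the "[m, n] + [n] = [m, n]" bias add described in attn_bias_add.cu.h; a reference CUDA sketch of that semantics (illustrative only, not the LaunchElementwiseCudaKernel implementation):

// Reference semantics of the broadcast bias add "[m, n] + [n] = [m, n]".
// One thread per output element; the real kernel vectorizes and tiles.
template <typename T>
__global__ void BiasAddRef(const T* x, const T* bias, T* out, int m, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < m * n) {
    int col = idx % n;  // bias index, broadcast along the row dimension
    out[idx] = x[idx] + bias[col];
  }
}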
@@ -81,11 +98,11 @@ class AttnMatMul {
 
     T* dB_input_1_ptr = nullptr;
     T* dB_input_2_ptr = nullptr;
-    T* dB_output_ptr = d_weight;
+    T* dB_output_ptr = d_weight->data<T>();
 
     T* dA_input_1_ptr = nullptr;
     T* dA_input_2_ptr = nullptr;
-    T* dA_output_ptr = d_input;
+    T* dA_output_ptr = d_input->data<T>();
 
     if (!transA_) {
       // fw: gemm-nt
@@ -104,10 +121,10 @@ class AttnMatMul {
         dA_n = input_size_;
         dA_k = output_size_;
 
-        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output,
-                  input, beta, dB_output_ptr);
-        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
-                  weight, beta, dA_output_ptr);
+        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
+                  d_output->data<T>(), input->data<T>(), beta, dB_output_ptr);
+        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
+                  d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
       } else {  // fw: gemm-nn
         // bw: gemm-tn, dB = A^t * dC
         dB_transA = CblasTrans;
@@ -123,10 +140,10 @@ class AttnMatMul {
         dA_n = input_size_;
         dA_k = output_size_;
 
-        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input,
-                  d_output, beta, dB_output_ptr);
-        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
-                  weight, beta, dA_output_ptr);
+        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
+                  input->data<T>(), d_output->data<T>(), beta, dB_output_ptr);
+        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
+                  d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
       }
     } else if (transB_) {
       PADDLE_THROW(platform::errors::InvalidArgument(
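
The two backward branches above only re-express the operands as Tensor::data<T>() calls; the math itself is untouched. For reference, with A the input ([bsz_seq, input_size]), B the weight, C the output ([bsz_seq, output_size]) and dC the upstream gradient, the gradients the comments refer to are the standard ones:

  forward (gemm-nt): C = A * B^T    backward: dA = dC * B,    dB = dC^T * A
  forward (gemm-nn): C = A * B      backward: dA = dC * B^T,  dB = A^T * dC   (the "dB = A^t * dC" in the comment)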
@@ -138,7 +155,27 @@ class AttnMatMul {
           "parameters."));
     }
     if (compute_bias_) {
-      LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
+      // reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2}
+      const auto input_dims = d_output->dims();
+      const auto output_dims = d_bias->dims();
+      bool support_case_1 =
+          (input_dims.size() == 5 && output_dims.size() == 3 &&
+           (input_dims[2] == output_dims[0]) &&
+           (input_dims[3] == output_dims[1]) &&
+           (input_dims[4] == output_dims[2]));
+      bool support_case_2 =
+          (input_dims.size() == 3 && output_dims.size() == 1 &&
+           (input_dims[2] == output_dims[0]));
+      if (support_case_1 || support_case_2) {
+        gpuStream_t stream = dev_ctx_.stream();
+        TensorReduceFunctorImpl<T, T, CustomSum>(*d_output, d_bias, {0, 1},
+                                                 stream);
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Only support reduce when the input dims are [0,1,2,3,4] and "
+            "output is [2,3,4], "
+            "or input is [0,1,2] and output is [2]."));
+      }
     }
   }
 
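For the bias gradient, LaunchBiasAddBwKernel is replaced by a generic sum-reduction over the two leading (batch, sequence) axes, guarded by an explicit shape check so only the layouts used by this operator are accepted. The quantity computed is simply d_bias[j] = sum_i d_output[i, j], with i running over the flattened leading axes; a reference CUDA sketch of that reduction (illustrative only, not TensorReduceFunctorImpl):

// Reference semantics of the bias gradient: d_bias[j] = sum_i d_out[i][j],
// where i runs over the flattened leading axes {0, 1} and j over the
// trailing bias axes. One thread per bias element; real kernels tree-reduce.
template <typename T>
__global__ void BiasGradRef(const T* d_out, T* d_bias, int rows, int cols) {
  int j = blockIdx.x * blockDim.x + threadIdx.x;
  if (j < cols) {
    T sum = static_cast<T>(0);
    for (int i = 0; i < rows; ++i) {
      sum += d_out[i * cols + j];
    }
    d_bias[j] = sum;
  }
}
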
paddle/fluid/operators/fused/fused_attention_op.cu

Lines changed: 14 additions & 14 deletions
@@ -170,11 +170,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
 
       layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
                                         ln_out_data, ln_mean_data, ln_var_data);
-      qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
-                                 qkv_out_data, qkv_bias_out_data);
+      qkv_compute.ComputeForward(qkv_weight, ln_out, qkv_bias, qkv_out,
+                                 qkv_bias_out);
     } else {
-      qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
-                                 qkv_out_data, qkv_bias_out_data);
+      qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
+                                 qkv_bias_out);
     }
     fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
                                     qk_out, src_mask_out, softmax_out,
@@ -184,8 +184,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     // fmha_out: [batch_size, seq_len, num_head, head_dim]
     // weight: [embed_dim, embed_dim]
     // out_linear_out: [batch_size, seq_len, embed_dim]
-    out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
-                                      nullptr, out_linear_out_data, nullptr);
+    out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr,
+                                      out_linear_out, nullptr);
     if (pre_layer_norm) {
       // output = (residual + dropout(input + bias))
       fused_dropout_layernorm_helper.ResidualDropoutBias(
@@ -401,9 +401,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
           d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
     }
 
-    out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data,
-                                       d_out_linear_out_data, d_fmha_out_data,
-                                       d_out_linear_weight_data, nullptr);
+    out_linear_compute.ComputeBackward(fmha_out, out_linear_weight,
+                                       d_out_linear_out, d_fmha_out,
+                                       d_out_linear_weight, nullptr);
+
     fmha_ref_compute.ComputeBackward(
         *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
         *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
@@ -432,15 +433,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
           (d_ln_bias == nullptr ? nullptr
                                 : d_ln_bias->mutable_data<U>(ctx.GetPlace()));
 
-      qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data,
-                                  d_qkv_bias_out_data, d_ln_out_data,
-                                  d_qkv_weight_data, d_qkv_bias_data);
+      qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out,
+                                  d_qkv_weight, d_qkv_bias);
       layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
                                          ln_mean_data, ln_var_data, d_x_data,
                                          d_ln_scale_data, d_ln_bias_data);
     } else {
-      qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data,
-                                  d_x_data, d_qkv_weight_data, d_qkv_bias_data);
+      qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
+                                  d_qkv_weight, d_qkv_bias);
     }
     // gradient accumulation
     std::vector<const Tensor *> ins;
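The call-site changes in this kernel are all of the same mechanical kind: instead of unpacking every tensor to a raw T* up front, the forward and backward kernels now hand AttnMatMul the framework::Tensor objects themselves, and the helper calls data<T>() (and, for the new reduce path, dims()) internally. Schematically, using the qkv_compute call above:

// Before: the caller extracted raw pointers, so AttnMatMul saw only T*.
//   qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data,
//                               d_x_data, d_qkv_weight_data, d_qkv_bias_data);
// After: the caller passes Tensor pointers, so AttnMatMul can inspect shapes.
//   qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
//                               d_qkv_weight, d_qkv_bias);
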
0 commit comments