@@ -11,10 +11,14 @@ limitations under the License. */

 #pragma once

-#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
+// #include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/float16.h"

+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+
 namespace paddle {
 namespace operators {

@@ -36,8 +40,10 @@ class AttnMatMul {

   ~AttnMatMul() {}

-  void ComputeForward(const T* weight_data, const T* input_data,
-                      const T* bias_data, T* output_data, T* bias_out_data) {
+  void ComputeForward(const framework::Tensor* weight,
+                      const framework::Tensor* input,
+                      const framework::Tensor* bias, framework::Tensor* output,
+                      framework::Tensor* bias_out) {
     // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
     // here: (transa, transb): nt, input * weight.
     CBLAS_TRANSPOSE transA = CblasNoTrans;
@@ -54,16 +60,27 @@ class AttnMatMul {
     // here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
     blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
-              input_data, weight_data, beta, output_data);
+              input->data<T>(), weight->data<T>(), beta, output->data<T>());
     if (compute_bias_) {
       // compute output + bias
-      LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
-                            bias_data, bias_out_data);
+      std::vector<const Tensor*> ins;
+      std::vector<Tensor*> outs;
+      ins.emplace_back(output);
+      ins.emplace_back(bias);
+      outs.emplace_back(bias_out);
+      int elewise_add_axis = -1;
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
     }
   }

-  void ComputeBackward(const T* input, const T* weight, const T* d_output,
-                       T* d_input, T* d_weight, T* d_bias) {
+  // void ComputeBackward(const T* input, const T* weight, const T* d_output,
+  //                      T* d_input, T* d_weight, T* d_bias) {
+  void ComputeBackward(const framework::Tensor* input,
+                       const framework::Tensor* weight,
+                       const framework::Tensor* d_output,
+                       framework::Tensor* d_input, framework::Tensor* d_weight,
+                       framework::Tensor* d_bias) {
     T alpha = static_cast<T>(1.0);
     T beta = static_cast<T>(0.0);
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
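With this change ComputeForward takes framework::Tensor pointers instead of raw device pointers, so a caller only hands over allocated tensors whose shapes match the nt-GEMM above. A minimal caller-side sketch under stated assumptions: the AttnMatMul constructor argument order is not part of this diff and the sizes and variable names are illustrative only.

  // Hypothetical caller; only ComputeForward's signature comes from this diff.
  int bsz_seq = 8 * 128, input_size = 768, output_size = 3 * 768;
  framework::Tensor input, weight, bias, out, bias_out;
  input.Resize(framework::make_ddim({bsz_seq, input_size}));
  weight.Resize(framework::make_ddim({output_size, input_size}));  // nt: [n, k]
  bias.Resize(framework::make_ddim({output_size}));
  out.Resize(framework::make_ddim({bsz_seq, output_size}));
  bias_out.Resize(framework::make_ddim({bsz_seq, output_size}));
  for (auto* t : {&input, &weight, &bias, &out, &bias_out}) {
    t->mutable_data<float>(dev_ctx.GetPlace());
  }

  // Assumed constructor order: (dev_ctx, transA, transB, bsz_seq, output_size,
  // input_size, compute_bias).
  AttnMatMul<float> attn_matmul(dev_ctx, /*transA=*/false, /*transB=*/true,
                                bsz_seq, output_size, input_size,
                                /*compute_bias=*/true);
  attn_matmul.ComputeForward(&weight, &input, &bias, &out, &bias_out);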
@@ -81,11 +98,11 @@ class AttnMatMul {

     T* dB_input_1_ptr = nullptr;
     T* dB_input_2_ptr = nullptr;
-    T* dB_output_ptr = d_weight;
+    T* dB_output_ptr = d_weight->data<T>();

     T* dA_input_1_ptr = nullptr;
     T* dA_input_2_ptr = nullptr;
-    T* dA_output_ptr = d_input;
+    T* dA_output_ptr = d_input->data<T>();

     if (!transA_) {
       // fw: gemm-nt
@@ -104,10 +121,10 @@ class AttnMatMul {
         dA_n = input_size_;
         dA_k = output_size_;

-        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output,
-                  input, beta, dB_output_ptr);
-        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
-                  weight, beta, dA_output_ptr);
+        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
+                  d_output->data<T>(), input->data<T>(), beta, dB_output_ptr);
+        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
+                  d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
       } else {  // fw: gemm-nn
         // bw: gemm-tn, dB = A^t * dC
         dB_transA = CblasTrans;
@@ -123,10 +140,10 @@ class AttnMatMul {
         dA_n = input_size_;
         dA_k = output_size_;

-        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input,
-                  d_output, beta, dB_output_ptr);
-        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
-                  weight, beta, dA_output_ptr);
+        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
+                  input->data<T>(), d_output->data<T>(), beta, dB_output_ptr);
+        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
+                  d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
       }
     } else if (transB_) {
       PADDLE_THROW(platform::errors::InvalidArgument(
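Both branches follow the standard row-major matmul gradient identities; as a reference note in the same notation as the code comments (A = input, B = weight, C = output), not part of the change itself:

  fw (nt): C = A * B^t   =>   dB = dC^t * A,   dA = dC * B
  fw (nn): C = A * B     =>   dB = A^t * dC,   dA = dC * B^t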
@@ -138,7 +155,27 @@ class AttnMatMul {
           "parameters."));
     }
     if (compute_bias_) {
-      LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
+      // reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2}
+      const auto input_dims = d_output->dims();
+      const auto output_dims = d_bias->dims();
+      bool support_case_1 =
+          (input_dims.size() == 5 && output_dims.size() == 3 &&
+           (input_dims[2] == output_dims[0]) &&
+           (input_dims[3] == output_dims[1]) &&
+           (input_dims[4] == output_dims[2]));
+      bool support_case_2 =
+          (input_dims.size() == 3 && output_dims.size() == 1 &&
+           (input_dims[2] == output_dims[0]));
+      if (support_case_1 || support_case_2) {
+        gpuStream_t stream = dev_ctx_.stream();
+        TensorReduceFunctorImpl<T, T, CustomSum>(*d_output, d_bias, {0, 1},
+                                                 stream);
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Only support reduce when the input dims are [0,1,2,3,4] and "
+            "output is [2,3,4]"
+            " or input is [0,1,2] and output is [2]."));
+      }
     }
   }

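The bias gradient now goes through a generic sum-reduction over the two leading axes instead of the dedicated LaunchBiasAddBwKernel. A minimal CPU reference of what the {0, 1} reduction computes for the 3-D case, as a sketch of the semantics only (not the kernel actually launched):

  // d_output: [dim0, dim1, dim2], row-major; d_bias: [dim2].
  for (int64_t k = 0; k < dim2; ++k) {
    float sum = 0.0f;
    for (int64_t i = 0; i < dim0; ++i) {
      for (int64_t j = 0; j < dim1; ++j) {
        sum += d_output[(i * dim1 + j) * dim2 + k];
      }
    }
    d_bias[k] = sum;
  }

The 5-D case is the same idea with dims 2, 3, 4 kept and everything else summed away.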