From f5eee9f9a95f20ce57776036b994575112b0da32 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 22 Sep 2021 05:01:39 +0000 Subject: [PATCH 01/29] Add fused_attention_op: add impl wrappers. --- .../elementwise/elementwise_op_impl.cu.h | 3 +- .../operators/fused/attention_layer_norm.h | 2 +- .../fluid/operators/fused/attn_bias_add.cu.h | 6 +- paddle/fluid/operators/fused/attn_gemm.h | 159 +++++++++ paddle/fluid/operators/fused/fmha_ref.h | 324 ++++++++++++++++++ paddle/fluid/operators/layer_norm_kernel.cu.h | 1 - 6 files changed, 487 insertions(+), 8 deletions(-) create mode 100644 paddle/fluid/operators/fused/attn_gemm.h create mode 100644 paddle/fluid/operators/fused/fmha_ref.h diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h index 81dff9473074f6..e4074cc7d7d600 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h @@ -108,7 +108,8 @@ struct ElementwisePrimitiveCaller { template struct ElementwisePrimitiveCaller { - __device__ inline void operator()(Functor func, InT **args, OutT *result) { + __device__ inline void operator()(Functor func, InT (*args)[VecSize], + OutT *result) { kps::ElementwiseTernary( result, args[0], args[1], args[2], func); } diff --git a/paddle/fluid/operators/fused/attention_layer_norm.h b/paddle/fluid/operators/fused/attention_layer_norm.h index d234a0f08531f5..43491a9faf18cf 100644 --- a/paddle/fluid/operators/fused/attention_layer_norm.h +++ b/paddle/fluid/operators/fused/attention_layer_norm.h @@ -50,7 +50,7 @@ class AttnLayerNorm { } } - void ComputeBackward(const T* x_data, const T* y_data, + void ComputeBackward(const T* x_data, const T* d_y_data, const LayerNormParamType* scale_data, const LayerNormParamType* mean_data, const LayerNormParamType* var_data, T* d_x_data, diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h index a8bd35a1b7309a..fa3eb19b29995a 100644 --- a/paddle/fluid/operators/fused/attn_bias_add.cu.h +++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h @@ -34,6 +34,7 @@ namespace cub = hipcub; #define LAUNCH_BOUNDS(BlockDim) #endif +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" @@ -51,11 +52,6 @@ using CudnnDataType = platform::CudnnDataType; template using ReduceParamType = typename CudnnDataType::BatchNormParamType; -template -struct AddFunctor { - inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; } -}; - template __global__ void BroadcastKernelBinary( diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h new file mode 100644 index 00000000000000..a2001d0a814922 --- /dev/null +++ b/paddle/fluid/operators/fused/attn_gemm.h @@ -0,0 +1,159 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/fused/attn_bias_add.cu.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +// support gemm-nt and gemm-nn, which is used in fused_attention_op. +template +class AttnMatMul { + public: + // (m, n, k) = bsz_seq, output_size, input_size + AttnMatMul(const platform::CUDADeviceContext& dev_ctx, bool transA, + bool transB, int bsz_seq, int output_size, int input_size, + bool compute_bias) + : dev_ctx_(dev_ctx), + transA_(transA), + transB_(transB), + bsz_seq_(bsz_seq), + output_size_(output_size), + input_size_(input_size), + compute_bias_(compute_bias) {} + + ~AttnMatMul() {} + + void ComputeForward(const T* weight_data, const T* input_data, + const T* bias_data, T* output_data, T* bias_out_data) { + // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. + // here: (transa, transb): nt, input * weight. + CBLAS_TRANSPOSE transA = CblasNoTrans; + CBLAS_TRANSPOSE transB = CblasNoTrans; + if (transA_) { + transA = CblasTrans; + } + if (transB_) { + transB = CblasTrans; + } + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + + // here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) + auto blas = math::GetBlas(dev_ctx_); + blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha, + input_data, weight_data, beta, output_data); + if (compute_bias_) { + // compute output + bias + LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data, + bias_data, bias_out_data); + } + } + + void ComputeBackward(const T* input, const T* weight, const T* d_output, + T* d_input, T* d_weight, T* d_bias) { + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + auto blas = math::GetBlas(dev_ctx_); + + CBLAS_TRANSPOSE dB_transA = CblasNoTrans; + CBLAS_TRANSPOSE dB_transB = CblasNoTrans; + CBLAS_TRANSPOSE dA_transA = CblasNoTrans; + CBLAS_TRANSPOSE dA_transB = CblasNoTrans; + int dB_m = 1; + int dB_n = 1; + int dB_k = 1; + int dA_m = 1; + int dA_n = 1; + int dA_k = 1; + + T* dB_input_1_ptr = nullptr; + T* dB_input_2_ptr = nullptr; + T* dB_output_ptr = d_weight; + + T* dA_input_1_ptr = nullptr; + T* dA_input_2_ptr = nullptr; + T* dA_output_ptr = d_input; + + if (!transA_) { + // fw: gemm-nt + if (transB_) { + // bw: gemm-tn, dB = (dC)^t * A + dB_transA = CblasTrans; + dB_transB = CblasNoTrans; + dB_m = output_size_; + dB_n = input_size_; + dB_k = bsz_seq_; + + // bw: gemm-nn, dA = dC * B + dA_transA = CblasNoTrans; + dA_transB = CblasNoTrans; + dA_m = bsz_seq_; + dA_n = input_size_; + dA_k = output_size_; + + blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output, + input, beta, dB_output_ptr); + blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output, + weight, beta, dA_output_ptr); + } else { // fw: gemm-nn + // bw: gemm-tn, dB = A^t * dC + dB_transA = CblasTrans; + dB_transB = CblasNoTrans; + dB_m = input_size_; + dB_n = output_size_; + dB_k = bsz_seq_; + + // bw: gemm-nt, dA = dC * B^t + dA_transA = CblasNoTrans; + dA_transB = CblasTrans; + dA_m = bsz_seq_; + dA_n = input_size_; + dA_k = output_size_; + + blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input, + d_output, beta, dB_output_ptr); + blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output, + weight, beta, dA_output_ptr); + } + } else if (transB_) { + PADDLE_THROW(platform::errors::InvalidArgument( + "AttnMatMul wrapper do not support (transA=T, transB=T)" + "parameters.")); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "AttnMatMul wrapper do not support (transA=T, transB=N)" + "parameters.")); + } + if (compute_bias_) { + LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias); + } + } + + private: + const platform::CUDADeviceContext& dev_ctx_; + + bool transA_; + bool transB_; + + int bsz_seq_; + int output_size_; + int input_size_; + + int compute_bias_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h new file mode 100644 index 00000000000000..bef0052a00d6b2 --- /dev/null +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -0,0 +1,324 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/dropout_impl.cu.h" +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/softmax_cudnn_op.cu.h" +#include "paddle/fluid/operators/transpose_op.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class AttnDropoutParam { + public: + AttnDropoutParam() { + is_test_ = false; + dropout_implementation_ = "downgrade_in_infer"; + dropout_prob_ = 0.5; + is_upscale_in_train_ = false; + is_fix_seed_ = false; + seed_val_ = 0; + seed_ = nullptr; + } + AttnDropoutParam(bool is_test, const std::string dropout_implementation, + float dropout_prob, bool is_upscale_in_train, + bool is_fix_seed, int seed_val, const Tensor* seed) { + is_test_ = is_test; + dropout_implementation_ = dropout_implementation; + dropout_prob_ = dropout_prob; + is_upscale_in_train_ = is_upscale_in_train; + is_fix_seed_ = is_fix_seed; + seed_val_ = seed_val; + seed_ = seed; + } + bool is_test_; + std::string dropout_implementation_; + float dropout_prob_; + bool is_upscale_in_train_; + bool is_fix_seed_; + int seed_val_; + const Tensor* seed_; +}; + +template +class FMHARef { + public: + FMHARef(const platform::CUDADeviceContext& dev_ctx, int64_t batch_size, + int64_t seq_len, int64_t num_head, int64_t head_dim, + AttnDropoutParam param) + : dev_ctx_(dev_ctx), + batch_size_(batch_size), + seq_len_(seq_len), + num_head_(num_head), + head_dim_(head_dim), + dropout_param_(param) {} + + ~FMHARef() {} + + void ComputeForward(const Tensor& qkv_input_tensor, + const Tensor& src_mask_tensor, + Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor, + Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor, + Tensor* dropout_mask_out_tensor, + Tensor* dropout_out_tensor, Tensor* qktv_out_tensor, + Tensor* fmha_out_tensor) { + // input shape: [bs, seq_len, 3, num_head, head_dim] + // transpose with perm [2, 0, 1, 3, 4], + // output_shape: [3, bs, num_head, seq_len, head_dim] + int ndims = 5; + std::vector perm_1 = {2, 0, 3, 1, 4}; + TransposeGPUKernelDriver(dev_ctx_, ndims, qkv_input_tensor, perm_1, + transpose_2_out_tensor); + + T* qkv_data = transpose_2_out_tensor->data(); + T* qk_out_data = qk_out_tensor->data(); + T* qktv_out_data = qktv_out_tensor->data(); + T* softmax_out_data = softmax_out_tensor->data(); + T* dropout_out_data = dropout_out_tensor->data(); + T* fmha_out_data = fmha_out_tensor->data(); + + int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; + int k_size = q_size; + T* q_ptr = qkv_data; + T* k_ptr = q_ptr + q_size; + T* v_ptr = k_ptr + k_size; + + // q*k^t, batched_gemm + CBLAS_TRANSPOSE transA = CblasNoTrans; + CBLAS_TRANSPOSE transB = CblasTrans; + auto blas = math::GetBlas(dev_ctx_); + int gemm_batch_size = batch_size_ * num_head_; + int gemm_m = seq_len_; + int gemm_n = seq_len_; + int gemm_k = head_dim_; + T alpha = static_cast(1.0 / sqrt(head_dim_)); + T beta = static_cast(0.0); + int64_t stride_a = gemm_m * gemm_k; + int64_t stride_b = gemm_k * gemm_n; + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, q_ptr, + k_ptr, beta, qk_out_data, gemm_batch_size, stride_a, + stride_b); + + std::vector ins; + std::vector outs; + ins.emplace_back(qk_out_tensor); + ins.emplace_back(&src_mask_tensor); + outs.emplace_back(src_mask_out_tensor); + int elewise_add_axis = -1; + int softmax_axis = -1; + if (&src_mask_tensor != nullptr) { + LaunchElementwiseCudaKernel( + dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor()); + SoftmaxForwardCUDAKernelDriver(dev_ctx_, *src_mask_out_tensor, + softmax_axis, softmax_out_tensor); + } else { + SoftmaxForwardCUDAKernelDriver(dev_ctx_, *qk_out_tensor, softmax_axis, + softmax_out_tensor); + } + + transB = CblasNoTrans; + gemm_m = seq_len_; + gemm_n = head_dim_; + gemm_k = seq_len_; + alpha = static_cast(1.0); + stride_a = gemm_m * gemm_k; + stride_b = gemm_k * gemm_n; + + if (dropout_param_.dropout_prob_) { + DropoutFwGPUKernelDriver( + dev_ctx_, dropout_param_.is_test_, + static_cast( + dropout_param_.dropout_implementation_), + dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_, + dropout_param_.is_fix_seed_, dropout_param_.seed_val_, + static_cast(*softmax_out_tensor), dropout_param_.seed_, + dropout_mask_out_tensor, dropout_out_tensor); + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + dropout_out_data, v_ptr, beta, qktv_out_data, + gemm_batch_size, stride_a, stride_b); + } else { + // softmax_out * v, batched_gemm + // output shape: [batch_size, num_heads, seq_len, head_dim] + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + softmax_out_data, v_ptr, beta, qktv_out_data, + gemm_batch_size, stride_a, stride_b); + } + // transpose: [0, 2, 1, 3] + // output shape: [batch_size, seq_len, num_heads, head_dim] + std::vector perm_3 = {0, 2, 1, 3}; + ndims = 4; + TransposeGPUKernelDriver(dev_ctx_, ndims, *qktv_out_tensor, perm_3, + fmha_out_tensor); + } + + void ComputeBackward( + const Tensor& transpose_2_out_tensor, const Tensor& src_mask_tensor, + const Tensor& softmax_out_tensor, const Tensor& dropout_mask_out_tensor, + const Tensor& dropout_out_tensor, const Tensor& qk_out_tensor, + const Tensor& src_mask_out_tensor, const Tensor& fmha_out_grad_tensor, + Tensor* qktv_out_grad_tensor, Tensor* dropout_out_grad_tensor, + Tensor* softmax_out_grad_tensor, Tensor* src_mask_out_grad_tensor, + Tensor* qk_out_grad_tensor, Tensor* transpose_2_out_grad_tensor, + Tensor* src_mask_grad_tensor, Tensor* qkv_input_grad_tensor) { + auto blas = math::GetBlas(dev_ctx_); + int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; + int k_size = q_size; + int softmax_axis = -1; + + T* qkv_grad_data = transpose_2_out_grad_tensor->data(); + T* q_grad_ptr = qkv_grad_data; + T* k_grad_ptr = q_grad_ptr + q_size; + T* v_grad_ptr = k_grad_ptr + k_size; + const T* qkv_data = transpose_2_out_tensor.data(); + const T* q_ptr = qkv_data; + const T* k_ptr = q_ptr + q_size; + const T* v_ptr = k_ptr + k_size; + + const T* softmax_out_data = softmax_out_tensor.data(); + T* softmax_out_grad_data = softmax_out_grad_tensor->data(); + const T* dropout_out_data = dropout_out_tensor.data(); + T* dropout_out_grad_data = dropout_out_grad_tensor->data(); + T* qktv_out_grad_data = qktv_out_grad_tensor->data(); + + // transpose bw + int ndims = 4; + std::vector perm_3 = {0, 2, 1, 3}; + TransposeGPUKernelDriver(dev_ctx_, ndims, fmha_out_grad_tensor, perm_3, + qktv_out_grad_tensor); + + // recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) = + // qktv_out_data(out) + CBLAS_TRANSPOSE transA = CblasTrans; + CBLAS_TRANSPOSE transB = CblasNoTrans; + int gemm_batch_size = batch_size_ * num_head_; + int gemm_m = seq_len_; + int gemm_n = head_dim_; + int gemm_k = seq_len_; + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + int64_t stride_a = gemm_m * gemm_k; + int64_t stride_b = gemm_k * gemm_n; + // bw: dy = x^t * dout + if (dropout_param_.dropout_prob_) { + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + dropout_out_data, qktv_out_grad_data, beta, v_grad_ptr, + gemm_batch_size, stride_a, stride_b); + } else { + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + softmax_out_data, qktv_out_grad_data, beta, v_grad_ptr, + gemm_batch_size, stride_a, stride_b); + } + // bw: dx = dout * y^t + transA = CblasNoTrans; + transB = CblasTrans; + gemm_m = seq_len_; + gemm_n = seq_len_; + gemm_k = head_dim_; + stride_a = gemm_m * gemm_k; + stride_b = gemm_k * gemm_n; + if (dropout_param_.dropout_prob_) { + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + qktv_out_grad_data, v_ptr, beta, dropout_out_grad_data, + gemm_batch_size, stride_a, stride_b); + } else { + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + qktv_out_grad_data, v_ptr, beta, softmax_out_grad_data, + gemm_batch_size, stride_a, stride_b); + } + // dropout bw + if (dropout_param_.dropout_prob_) { + DropoutGradGPUKernelDriver( + dev_ctx_, static_cast( + dropout_param_.dropout_implementation_), + dropout_param_.dropout_prob_, + static_cast(*dropout_out_grad_tensor), + dropout_mask_out_tensor, softmax_out_grad_tensor->numel(), + softmax_out_grad_tensor); + } + + if (&src_mask_tensor != nullptr) { + SoftmaxBackwardCUDAKernelDriver(dev_ctx_, softmax_out_tensor, + *softmax_out_grad_tensor, softmax_axis, + src_mask_out_grad_tensor); + + // recall LaunchElementwiseCudaKernel fw: src_mask_out = qk_out + + // src_mask + // Special case when dy is not needed and dx doesn't reduce + if (qk_out_grad_tensor != nullptr && src_mask_grad_tensor == nullptr && + qk_out_tensor.dims() == src_mask_out_tensor.dims()) { + VLOG(4) << "Special case when dy is not needed and dx doesn't " + "reduce"; + framework::TensorCopy(*src_mask_out_grad_tensor, dev_ctx_.GetPlace(), + dev_ctx_, qk_out_grad_tensor); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only used for the backward elementwise_add op when" + "dy is not needed and dx is not reduce")); + return; + } + + } else { + SoftmaxBackwardCUDAKernelDriver(dev_ctx_, softmax_out_tensor, + *softmax_out_grad_tensor, softmax_axis, + qk_out_grad_tensor); + } + + T* qk_out_grad_data = qk_out_grad_tensor->data(); + alpha = static_cast(1.0 / sqrt(head_dim_)); + // recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out + // bw: dy (seq_len * head_dim) = (dout)^t * x + transA = CblasTrans; + transB = CblasNoTrans; + gemm_m = seq_len_; + gemm_n = head_dim_; + gemm_k = seq_len_; + stride_a = gemm_m * gemm_k; + stride_b = gemm_k * gemm_n; + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + qk_out_grad_data, q_ptr, beta, k_grad_ptr, gemm_batch_size, + stride_a, stride_b); + // dx (seq_len * head_dim) = dout * y + transA = CblasNoTrans; + transB = CblasNoTrans; + gemm_m = seq_len_; + gemm_n = head_dim_; + gemm_k = seq_len_; + stride_a = gemm_m * gemm_k; + stride_b = gemm_k * gemm_n; + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + qk_out_grad_data, k_ptr, beta, q_grad_ptr, gemm_batch_size, + stride_a, stride_b); + + // transpose bw + ndims = 5; + std::vector perm_1 = {1, 3, 0, 2, 4}; + TransposeGPUKernelDriver(dev_ctx_, ndims, *transpose_2_out_grad_tensor, + perm_1, qkv_input_grad_tensor); + } + + private: + const platform::CUDADeviceContext& dev_ctx_; + + int64_t batch_size_; + int64_t seq_len_; + int64_t num_head_; + int64_t head_dim_; + + AttnDropoutParam dropout_param_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h index 06c1eaf881626c..4280c86ca99ab8 100644 --- a/paddle/fluid/operators/layer_norm_kernel.cu.h +++ b/paddle/fluid/operators/layer_norm_kernel.cu.h @@ -35,7 +35,6 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; -using DataLayout = framework::DataLayout; template using CudnnDataType = platform::CudnnDataType; template From e16e3b38e2d7258461fbb06a0c97aa2430097b05 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 22 Sep 2021 07:04:48 +0000 Subject: [PATCH 02/29] Add fused_attention_op: forward. --- .../operators/fused/fused_attention_op.cc | 620 ++++++++++++++++++ .../operators/fused/fused_attention_op.cu | 524 +++++++++++++++ .../operators/fused/fused_dropout_helper.h | 298 +++++++++ .../unittests/test_fused_attention_op.py | 303 +++++++++ 4 files changed, 1745 insertions(+) create mode 100644 paddle/fluid/operators/fused/fused_attention_op.cc create mode 100644 paddle/fluid/operators/fused/fused_attention_op.cu create mode 100644 paddle/fluid/operators/fused/fused_dropout_helper.h create mode 100644 python/paddle/fluid/tests/unittests/test_fused_attention_op.py diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc new file mode 100644 index 00000000000000..1c3db42a6177b3 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -0,0 +1,620 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fused_attention_op.h" + +#include +#include + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class FusedAttentionOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + // std::cout << "i am in op infershape\n"; + + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", + "FusedAttentionOp"); + + // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] + OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output", + "AttnDropoutMaskOut", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutOut"), "Output", "AttnDropoutOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("FMHAOut"), "Output", "FMHAOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut", + "FusedAttentionOp"); +#if 1 + OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output", + "BiasDropoutResidualOut", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut", + "FusedAttentionOp"); +#endif + OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp"); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputDim("QKVW"); + // auto qkv_bias_dim = ctx->GetInputDim("QKVBias"); + // auto src_mask_dim = ctx->GetInputDim("SrcMask"); + // std::cout << "x_dim = " << x_dim << std::endl; + // std::cout << "qkv_weight_dim = " << y_dim << std::endl; + // std::cout << "qkv_bias_dim = " << qkv_bias_dim << std::endl; + // // src_mask_dim = 32, 16, 128, 128 + // std::cout << "src_mask_dim = " << src_mask_dim << std::endl; + + PADDLE_ENFORCE_EQ(x_dim.size(), 3, + platform::errors::InvalidArgument( + "The dimensions of QKV_input must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); + + PADDLE_ENFORCE_EQ(y_dim.size(), 4, + platform::errors::InvalidArgument( + "The dimensions of QKV_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + + // limin-todo: polish the expression. + PADDLE_ENFORCE_EQ(x_dim[2], y_dim[3], + platform::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3]" + "must be equal. But received: the shape " + "of input X = [%s], and the shape of " + "input Y = [%s]", + x_dim, y_dim)); + + ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnOut", ctx->GetInputDim("X")); + // [batch_size, seq_len, 3, num_head, head_size] + ctx->SetOutputDim("QKVOut", + {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); + ctx->SetOutputDim("QKVBiasOut", + {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); + // limin-todo: [3, batch_size, seq_len, num_head, head_size] + // check shape: [3, batch_size, num_head, seq_len, head_size] + ctx->SetOutputDim("TransposeOut2", + {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); + // check shape: batch, num_head, seq_len, seq_len + ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + // the same as QKOut's shape. + ctx->SetOutputDim("AttnDropoutOut", + {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + if (ctx->Attrs().Get("is_test1") == false) { + ctx->SetOutputDim("AttnDropoutMaskOut", + {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + } + ctx->SetOutputDim("SoftmaxOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + // check shape [batch_size, num_heads, seq_len, head_dim] + ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); + // check shape, [batch_size, seq_len, number of heads*head size] + ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]}); + ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X")); + +#if 1 + ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); + if (ctx->Attrs().Get("is_test") == false) { + ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X")); + } + ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X")); +#endif + ctx->SetOutputDim("Y", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input = ctx.Input("X"); + auto input_data_type = input->type(); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddInput("LnScale", + "(optional) Scale is a 1-dimensional tensor of size " + "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." + "It is applied to the output.") + .AsDispensable(); + AddInput("LnBias", + "(optional) Bias is a 1-dimensional tensor of size " + "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." + "It is applied to the output.") + .AsDispensable(); + AddInput("QKVW", "The qkv weight tensor."); + AddInput("QKVBias", "The qkv bias tensor."); + AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") + .AsDispensable(); + AddInput("OutLinearW", "The out_linear weight tensor."); + AddInput("OutLinearBias", "The out_linear bias tensor."); +#if 1 + AddInput("Ln2Scale", + "(optional) Scale is a 1-dimensional tensor of size " + "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." + "It is applied to the output.") + .AsDispensable(); + AddInput("Ln2Bias", + "(optional) Bias is a 1-dimensional tensor of size " + "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." + "It is applied to the output.") + .AsDispensable(); +#endif +#if 1 +// todo: +// AddInput("Seed", +// "The seed of dropout op, it has higher priority than the attr " +// "fix_seed and seed") +// .AsDispensable(); +#endif + AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("LnVariance", "Variance of the current mini batch.") + .AsIntermediate(); + AddOutput("LnOut", "The output of pre layer_norm.").AsIntermediate(); + + AddOutput("QKVOut", "Result after qkv.").AsIntermediate(); + AddOutput("QKVBiasOut", "Result after qkv and bias op.").AsIntermediate(); + + // fma + AddOutput("TransposeOut2", "Result in fmha.").AsIntermediate(); + AddOutput("QKOut", "Result in fmha.").AsIntermediate(); + AddOutput("QKTVOut", "Result in fmha.").AsIntermediate(); + AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate(); + AddOutput("AttnDropoutMaskOut", "Result in fmha.").AsIntermediate(); + AddOutput("AttnDropoutOut", "Result in fmha.").AsIntermediate(); + AddOutput("SrcMaskOut", "Result in fmha.").AsIntermediate(); + AddOutput("FMHAOut", "Result after fmha.").AsIntermediate(); + + AddOutput("OutLinearOut", "Result after out_linear.").AsIntermediate(); + +#if 1 + AddOutput("DropoutMaskOut", "The random sampled dropout mask.") + .AsIntermediate(); + AddOutput("Ln2Mean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("Ln2Variance", "Variance of the current mini batch.") + .AsIntermediate(); + AddOutput("BiasDropoutResidualOut", + "Result of residual + dropout(src + bias).") + .AsIntermediate(); +#endif + + AddOutput("Y", "Result after attention."); + + AddAttr("pre_layer_norm", + "if true, the attention op uses pre_layer_norm architecure, " + "else, uses post_layer_norm architecuture. " + "[default false].") + .SetDefault(false); + AddAttr("epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true, + platform::errors::InvalidArgument( + "'epsilon' in Op(LayerNorm) should be between" + "0.0 and 0.001, But received [%s].", + epsilon)); + }); + // AddAttr("begin_norm_axis", + // "the axis of `begin_norm_axis ... Rank(X) - 1` will be " + // "normalized. `begin_norm_axis` splits the tensor(`X`) to a " + // "matrix [N,H]. [default 1].") + // .SetDefault(1) + // .AddCustomChecker([](const int &begin_norm_axis) { + // PADDLE_ENFORCE_GT(begin_norm_axis, 0, + // platform::errors::InvalidArgument( + // "'begin_norm_axis' in Op(LayerNorm) should + // be" + // "greater than zero. But received [%d].", + // begin_norm_axis)); + // }); + + // for dropout in fmha. + AddAttr("attn_dropout_prob", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ( + drop_p >= 0.0f && drop_p <= 1.0f, true, + platform::errors::InvalidArgument( + "'attn_dropout_prob' must be between 0.0 and 1.0.")); + }); + AddAttr("is_test1", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("fix_seed1", + "A flag indicating whether to use a fixed seed to generate " + "random mask. NOTE: DO NOT set this flag to true in " + "training. Setting this flag to true is only useful in " + "unittest or for debug that always the same output units " + "will be dropped.") + .SetDefault(true); + AddAttr("seed1", "Dropout random seed.").SetDefault(0); + AddAttr( + "dropout_implementation1", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "There are two kinds of ways to implement dropout" + "(the mask below is a tensor have the same shape with input" + "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)" + "1. downgrade_in_infer(default), downgrade the outcome at inference " + "time" + " train: out = input * mask" + " inference: out = input * (1.0 - dropout_prob)" + "2. upscale_in_train, upscale the outcome at training time, do nothing " + "in inference" + " train: out = input * mask / ( 1.0 - dropout_prob )" + " inference: out = input" + " dropout op can be removed from the program. the program will be " + "efficient") + .SetDefault("upscale_in_train") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + +#if 1 + AddAttr("dropout_prob", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true, + platform::errors::InvalidArgument( + "'dropout_prob' must be between 0.0 and 1.0.")); + }); + + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("fix_seed", + "A flag indicating whether to use a fixed seed to generate " + "random mask. NOTE: DO NOT set this flag to true in " + "training. Setting this flag to true is only useful in " + "unittest or for debug that always the same output units " + "will be dropped.") + .SetDefault(true); + AddAttr("seed", "Dropout random seed.").SetDefault(0); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "There are two kinds of ways to implement dropout" + "(the mask below is a tensor have the same shape with input" + "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)" + "1. downgrade_in_infer(default), downgrade the outcome at inference " + "time" + " train: out = input * mask" + " inference: out = input * (1.0 - dropout_prob)" + "2. upscale_in_train, upscale the outcome at training time, do nothing " + "in inference" + " train: out = input * mask / ( 1.0 - dropout_prob )" + " inference: out = input" + " dropout op can be removed from the program. the program will be " + "efficient") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + AddAttr("ln2epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &ln2epsilon) { + PADDLE_ENFORCE_EQ(ln2epsilon >= 0.0f && ln2epsilon <= 0.001f, true, + platform::errors::InvalidArgument( + "'epsilon' of the second LayerNorm in Fused " + "attention op should be between" + "0.0 and 0.001, But received [%s].", + ln2epsilon)); + }); +#endif + + AddComment(R"DOC( +Fused attention: +if (pre_layernorm) + layer_norm; +qkv+bias_add; +fmha; +out_linear; +bias_add + dropout + residual + layer_norm; +)DOC"); + } +}; + +class FusedAttentionGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { +// auto x_dim = ctx->GetInputDim("X"); +// auto y_dim = ctx->GetInputDim("QKVW"); +// std::cout << "x_dim = " << x_dim << std::endl; +// std::cout << "y_dim = " << y_dim << std::endl; +// int batch_size = x_dim[0]; +// int seq_len = x_dim[1]; +// int embed_dim = x_dim[2]; +// std::cout << "batch_size, seq_len, embed_dim= " << batch_size << ", " << +// seq_len << ", " << embed_dim << std::endl; + +#if 1 + PADDLE_ENFORCE_EQ(ctx->Attrs().Get("is_test"), false, + platform::errors::InvalidArgument( + "GradOp is only callable when is_test is false")); + + OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance", + "FusedAttentionGrad"); + if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) { + ctx->SetOutputDim(framework::GradVarName("Ln2Scale"), + ctx->GetInputDim("Ln2Scale")); + } + if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Ln2Bias"), + ctx->GetInputDim("Ln2Bias")); + } +#endif + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance", + "FusedAttentionGrad"); + if (ctx->Attrs().Get("pre_layer_norm") == true) { + OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut", + "FusedAttentionGrad"); + } + OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", + "FusedAttentionGrad"); + + if (ctx->HasOutput(framework::GradVarName("LnScale"))) { + ctx->SetOutputDim(framework::GradVarName("LnScale"), + ctx->GetInputDim("LnScale")); + } + if (ctx->HasOutput(framework::GradVarName("LnBias"))) { + ctx->SetOutputDim(framework::GradVarName("LnBias"), + ctx->GetInputDim("LnBias")); + } + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), + ctx->GetInputDim("OutLinearBias")); + ctx->SetOutputDim(framework::GradVarName("OutLinearW"), + ctx->GetInputDim("OutLinearW")); + ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW")); + ctx->SetOutputDim(framework::GradVarName("QKVBias"), + ctx->GetInputDim("QKVBias")); + + ctx->SetOutputDim(framework::GradVarName("LnOut"), + ctx->GetInputDim("LnOut")); + ctx->SetOutputDim(framework::GradVarName("FMHAOut"), + ctx->GetInputDim("FMHAOut")); + ctx->SetOutputDim(framework::GradVarName("QKTVOut"), + ctx->GetInputDim("QKTVOut")); + ctx->SetOutputDim(framework::GradVarName("TransposeOut2"), + ctx->GetInputDim("TransposeOut2")); + ctx->SetOutputDim(framework::GradVarName("QKOut"), + ctx->GetInputDim("QKOut")); + ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"), + ctx->GetInputDim("SoftmaxOut")); + ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"), + ctx->GetInputDim("AttnDropoutOut")); + ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"), + ctx->GetInputDim("SrcMaskOut")); + ctx->SetOutputDim(framework::GradVarName("QKVOut"), + ctx->GetInputDim("QKVOut")); + ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"), + ctx->GetInputDim("QKVBiasOut")); +#if 1 + ctx->SetOutputDim(framework::GradVarName("OutLinearOut"), + ctx->GetInputDim("OutLinearOut")); + // ctx->SetOutputDim(framework::GradVarName("DropoutMaskOut"), + // ctx->GetInputDim("DropoutMaskOut")); + ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"), + ctx->GetInputDim("BiasDropoutResidualOut")); +#endif + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input = ctx.Input("X"); + auto input_data_type = input->type(); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +template +class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("fused_attention_grad"); + op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + + // inputs x, parameters and their grad. + op->SetInput("X", this->Input("X")); + op->SetInput("QKVW", this->Input("QKVW")); + op->SetInput("QKVBias", this->Input("QKVBias")); + op->SetInput("SrcMask", this->Input("SrcMask")); + op->SetInput("OutLinearW", this->Input("OutLinearW")); + op->SetInput("OutLinearBias", this->Input("OutLinearBias")); + if (this->HasInput("LnScale")) { + op->SetInput("LnScale", this->Input("LnScale")); + op->SetOutput(framework::GradVarName("LnScale"), + this->InputGrad("LnScale")); + } + if (this->HasInput("LnBias")) { + op->SetInput("LnBias", this->Input("LnBias")); + op->SetOutput(framework::GradVarName("LnBias"), + this->InputGrad("LnBias")); + } +#if 1 + if (this->HasInput("Ln2Scale")) { + op->SetInput("Ln2Scale", this->Input("Ln2Scale")); + op->SetOutput(framework::GradVarName("Ln2Scale"), + this->InputGrad("Ln2Scale")); + } + if (this->HasInput("Ln2Bias")) { + op->SetInput("Ln2Bias", this->Input("Ln2Bias")); + op->SetOutput(framework::GradVarName("Ln2Bias"), + this->InputGrad("Ln2Bias")); + } +#endif + + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW")); + op->SetOutput(framework::GradVarName("QKVBias"), + this->InputGrad("QKVBias")); + op->SetOutput(framework::GradVarName("OutLinearBias"), + this->InputGrad("OutLinearBias")); + op->SetOutput(framework::GradVarName("OutLinearW"), + this->InputGrad("OutLinearW")); + + // use forward's output as bw's input. + op->SetInput("LnOut", this->Output("LnOut")); + op->SetInput("LnMean", this->Output("LnMean")); + op->SetInput("LnVariance", this->Output("LnVariance")); + op->SetInput("QKVOut", this->Output("QKVOut")); + op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); + op->SetInput("TransposeOut2", this->Output("TransposeOut2")); + op->SetInput("QKOut", this->Output("QKOut")); + op->SetInput("QKTVOut", this->Output("QKTVOut")); + op->SetInput("SoftmaxOut", this->Output("SoftmaxOut")); + op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut")); + op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut")); + op->SetInput("SrcMaskOut", this->Output("SrcMaskOut")); + op->SetInput("FMHAOut", this->Output("FMHAOut")); + op->SetInput("OutLinearOut", this->Output("OutLinearOut")); + +#if 1 + op->SetInput("Ln2Mean", this->Output("Ln2Mean")); + op->SetInput("Ln2Variance", this->Output("Ln2Variance")); + op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut")); + op->SetInput("BiasDropoutResidualOut", + this->Output("BiasDropoutResidualOut")); +#endif + // op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); + op->SetInput("QKVOut", this->Output("QKVOut")); + + // bw's output: dinput + op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut")); + op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut")); + op->SetOutput(framework::GradVarName("QKVBiasOut"), + this->OutputGrad("QKVBiasOut")); + op->SetOutput(framework::GradVarName("QKTVOut"), + this->OutputGrad("QKTVOut")); + op->SetOutput(framework::GradVarName("TransposeOut2"), + this->OutputGrad("TransposeOut2")); + op->SetOutput(framework::GradVarName("QKOut"), this->OutputGrad("QKOut")); + op->SetOutput(framework::GradVarName("SoftmaxOut"), + this->OutputGrad("SoftmaxOut")); + op->SetOutput(framework::GradVarName("AttnDropoutOut"), + this->OutputGrad("AttnDropoutOut")); + op->SetOutput(framework::GradVarName("SrcMaskOut"), + this->OutputGrad("SrcMaskOut")); + op->SetOutput(framework::GradVarName("FMHAOut"), + this->OutputGrad("FMHAOut")); +#if 1 + // op->SetOutput(framework::GradVarName("DropoutMaskOut"), + // this->OutputGrad("DropoutMaskOut")); + op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"), + this->OutputGrad("BiasDropoutResidualOut")); +#endif + op->SetOutput(framework::GradVarName("OutLinearOut"), + this->OutputGrad("OutLinearOut")); + // op->SetOutput(framework::GradVarName("OutLinearBiasOut"), + // this->OutputGrad("OutLinearBiasOut")); + + op->SetAttrMap(this->Attrs()); + } +}; + +// DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseAddLayerNormGradNoNeedBufferVarInferer, +// "Bias"); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp, + ops::FusedAttentionOpMaker, + ops::FusedAttentionGradOpMaker, + ops::FusedAttentionGradOpMaker); + +REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp); +// REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp, +// ops::FusedAttentionGradNoNeedBufferVarInferer); diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu new file mode 100644 index 00000000000000..83a9287bf2e232 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -0,0 +1,524 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef __NVCC__ +#include +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/cuda_device_function.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#endif + +#include +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/math/math_function.h" + +#include "paddle/fluid/operators/fused/fused_attention_op.h" + +#include "paddle/fluid/operators/fused/attention_layer_norm.h" +#include "paddle/fluid/operators/fused/attn_gemm.h" +#include "paddle/fluid/operators/fused/fmha_ref.h" +#include "paddle/fluid/operators/fused/fused_dropout_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class FusedAttentionOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto *input_x = ctx.Input("X"); + + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + const float epsilon = ctx.Attr("epsilon"); + auto *ln_scale = ctx.Input("LnScale"); + auto *ln_bias = ctx.Input("LnBias"); + auto *ln_mean = ctx.Output("LnMean"); + auto *ln_var = ctx.Output("LnVariance"); + auto *ln_out = ctx.Output("LnOut"); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto *qkv_weight = ctx.Input("QKVW"); + auto *qkv_bias = ctx.Input("QKVBias"); + auto *qkv_out = ctx.Output("QKVOut"); + auto *qkv_bias_out = ctx.Output("QKVBiasOut"); + + // FMHA-ref: + auto *src_mask = ctx.Input("SrcMask"); + auto *transpose_out_2 = ctx.Output("TransposeOut2"); + auto *qk_out = ctx.Output("QKOut"); + auto *qktv_out = ctx.Output("QKTVOut"); + auto *softmax_out = ctx.Output("SoftmaxOut"); + auto *attn_dropout_mask_out = ctx.Output("AttnDropoutMaskOut"); + auto *attn_dropout_out = ctx.Output("AttnDropoutOut"); + auto *src_mask_out = ctx.Output("SrcMaskOut"); + auto *fmha_out = ctx.Output("FMHAOut"); + + // out_linear + auto *out_linear_weight = ctx.Input("OutLinearW"); + auto *out_linear_bias = ctx.Input("OutLinearBias"); + auto *out_linear_out = ctx.Output("OutLinearOut"); + +// bias+dropout+residual+layernorm +#if 1 + auto *ln_scale_2 = ctx.Input("Ln2Scale"); + auto *ln_bias_2 = ctx.Input("Ln2Bias"); + auto *dropout_mask_out = ctx.Output("DropoutMaskOut"); + auto *bias_dropout_residual_out = + ctx.Output("BiasDropoutResidualOut"); + auto *ln_mean_2 = ctx.Output("Ln2Mean"); + auto *ln_var_2 = ctx.Output("Ln2Variance"); + const float ln2epsilon = ctx.Attr("ln2epsilon"); +#endif + +#if 1 + float attn_dropout_prob = ctx.Attr("attn_dropout_prob"); + std::cout << "limin: attn_dropout_prob = " << attn_dropout_prob + << std::endl; + bool is_test_1 = ctx.Attr("is_test1"); + auto &dropout_implementation_1 = + ctx.Attr("dropout_implementation1"); + bool is_upscale_in_train_1 = + (dropout_implementation_1 == "upscale_in_train"); + auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; + bool is_fix_seed_1 = ctx.Attr("fix_seed1"); + int seed_val_1 = ctx.Attr("seed1"); +#endif + + // final output. + auto *out = ctx.Output("Y"); + + // get data ptr for qkv part. + const auto input_x_dims = input_x->dims(); + const auto qkv_w_dims = qkv_weight->dims(); + + auto *x_data = input_x->data(); + auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); + auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); + auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace()); + auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace()); + auto *ln_out_data = ln_out->mutable_data(ctx.GetPlace()); + + auto *qkv_weight_data = qkv_weight->data(); + auto *qkv_bias_data = qkv_bias->data(); + auto *qkv_out_data = qkv_out->mutable_data(ctx.GetPlace()); + auto *qkv_bias_out_data = qkv_bias_out->mutable_data(ctx.GetPlace()); + + // get data ptr for FMHA. + auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data()); + auto *transpose_out_2_data = + transpose_out_2->mutable_data(ctx.GetPlace()); + auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace()); + auto *qktv_out_data = qktv_out->mutable_data(ctx.GetPlace()); + auto *src_mask_out_data = src_mask_out->mutable_data(ctx.GetPlace()); + auto *softmax_out_data = softmax_out->mutable_data(ctx.GetPlace()); + auto *attn_dropout_mask_out_data = + attn_dropout_mask_out->mutable_data(ctx.GetPlace()); + auto *attn_dropout_out_data = + attn_dropout_out->mutable_data(ctx.GetPlace()); + auto *fmha_out_data = fmha_out->mutable_data(ctx.GetPlace()); + + // get data ptr for out_linear. + auto *out_linear_weight_data = out_linear_weight->data(); + auto *out_linear_bias_data = out_linear_bias->data(); + auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace()); + +// get data ptr for bias+dropout+residual+layernorm +#if 1 + auto *ln_scale_2_data = + (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data()); + auto *ln_bias_2_data = + (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data()); + auto *dropout_mask_out_data = + dropout_mask_out->mutable_data(ctx.GetPlace()); + auto *bias_dropout_residual_out_data = + bias_dropout_residual_out->mutable_data(ctx.GetPlace()); + auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace()); + auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace()); +#endif + auto *final_out_data = out->mutable_data(ctx.GetPlace()); + + int batch_size = input_x_dims[0]; + int max_seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + + int num_head = qkv_w_dims[1]; + int dim_head = qkv_w_dims[2]; + + int bsz_seq = batch_size * max_seq_len; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + bool transA = false; + bool transB = true; + bool compute_bias = true; + auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), + epsilon, bsz_seq, dim_embed); + auto qkv_compute = + AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, + output_size, input_size, compute_bias); + + // AttnDropoutParam(bool is_test, const std::string dropout_implementation, + // float dropout_prob, bool is_upscale_in_train, + // bool is_fix_seed, int seed_val, const Tensor* seed) { + AttnDropoutParam attn_dropout_param( + is_test_1, dropout_implementation_1, attn_dropout_prob, + is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); + auto fmha_ref_compute = + FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head, + dim_head, attn_dropout_param); + // out_linear + output_size = hidden_size; + transA = false; + transB = false; + compute_bias = false; + auto out_linear_compute = + AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, + output_size, input_size, compute_bias); +#if 1 + DropoutParam dropout_param2(ctx, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, + ln2epsilon); +#endif + + // compute + if (pre_layer_norm) { + layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, + ln_out_data, ln_mean_data, ln_var_data); + qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data, + qkv_out_data, qkv_bias_out_data); + } else { + qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data, + qkv_out_data, qkv_bias_out_data); + } + // compute FMHA + fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2, + qk_out, src_mask_out, softmax_out, + attn_dropout_mask_out, attn_dropout_out, + qktv_out, fmha_out); +// fmha_out: [batch_size, seq_len, num_head, head_dim] +// weight: [1024, 1024], [embed_dim, embed_dim] +// out_linear_out: [batch_size, seq_len, embed_dim] +#if 1 + out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data, + nullptr, out_linear_out_data, nullptr); +#endif +#if 1 + // out = layernorm(residual + dropout(src + bias)) + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + ctx.cuda_device_context(), out_linear_out_data, x_data, + out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, + bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, + ln_mean_2_data, ln_var_2_data); +#endif + } +}; + +template +class FusedAttentionGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + const float epsilon = ctx.Attr("epsilon"); +#if 1 + const float ln2epsilon = ctx.Attr("ln2epsilon"); +#endif + +#if 1 + float attn_dropout_prob = ctx.Attr("attn_dropout_prob"); + bool is_test_1 = ctx.Attr("is_test1"); + auto &dropout_implementation_1 = + ctx.Attr("dropout_implementation1"); + bool is_upscale_in_train_1 = + (dropout_implementation_1 == "upscale_in_train"); + auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; + bool is_fix_seed_1 = ctx.Attr("fix_seed1"); + int seed_val_1 = ctx.Attr("seed1"); +#endif + + // get inputs. + auto *d_y = ctx.Input(framework::GradVarName("Y")); + auto *d_y_data = d_y->data(); + + // fw input + auto *input_x = ctx.Input("X"); + auto *ln_scale = ctx.Input("LnScale"); +#if 1 + auto *ln_2_scale = ctx.Input("Ln2Scale"); +#endif + auto *x_data = input_x->data(); + auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); +#if 1 + auto *ln_2_scale_data = + (ln_2_scale == nullptr ? nullptr : ln_2_scale->data()); +#endif + // fw parameters. + auto *src_mask = ctx.Input("SrcMask"); + auto *qkv_weight = ctx.Input("QKVW"); + auto *qkv_bias = ctx.Input("QKVBias"); + auto *out_linear_weight = ctx.Input("OutLinearW"); + auto *out_linear_bias = ctx.Input("OutLinearBias"); + auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data()); + auto *qkv_weight_data = qkv_weight->data(); + auto *qkv_bias_data = qkv_bias->data(); + auto *out_linear_weight_data = out_linear_weight->data(); + auto *out_linear_bias_data = out_linear_bias->data(); + + // fw output + auto *ln_mean = ctx.Input("LnMean"); + auto *ln_var = ctx.Input("LnVariance"); + auto *ln_out = ctx.Input("LnOut"); + auto *fmha_out = ctx.Input("FMHAOut"); + auto *transpose_out_2 = ctx.Input("TransposeOut2"); + auto *qk_out = ctx.Input("QKOut"); + auto *qktv_out = ctx.Input("QKTVOut"); + auto *softmax_out = ctx.Input("SoftmaxOut"); + auto *attn_dropout_mask_out = ctx.Input("AttnDropoutMaskOut"); + auto *attn_dropout_out = ctx.Input("AttnDropoutOut"); + auto *src_mask_out = ctx.Input("SrcMaskOut"); + auto *out_linear_out = ctx.Input("OutLinearOut"); +#if 1 + auto *ln_2_mean = ctx.Input("Ln2Mean"); + auto *ln_2_var = ctx.Input("Ln2Variance"); + auto *dropout_mask_out = ctx.Input("DropoutMaskOut"); + auto *bias_dropout_residual_out = + ctx.Input("BiasDropoutResidualOut"); +#endif + auto *ln_mean_data = ln_mean->data(); + auto *ln_var_data = ln_var->data(); + auto *ln_out_data = ln_out->data(); + auto *fmha_out_data = fmha_out->data(); + auto *transpose_out_2_data = transpose_out_2->data(); + auto *qk_out_data = qk_out->data(); + auto *qktv_out_data = qktv_out->data(); + auto *softmax_out_data = softmax_out->data(); + auto *src_mask_out_data = src_mask_out->data(); + auto *out_linear_out_data = out_linear_out->data(); +#if 1 + auto *ln_2_mean_data = ln_2_mean->data(); + auto *ln_2_var_data = ln_2_var->data(); + auto *dropout_mask_out_data = dropout_mask_out->data(); + auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data(); +#endif + + // bw output's grad + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_ln_out = ctx.Output(framework::GradVarName("LnOut")); + auto *d_qkv_out = ctx.Output(framework::GradVarName("QKVOut")); + auto *d_qkv_bias_out = + ctx.Output(framework::GradVarName("QKVBiasOut")); + auto *d_qktv_out = ctx.Output(framework::GradVarName("QKTVOut")); + auto *d_transpose_out_2 = + ctx.Output(framework::GradVarName("TransposeOut2")); + auto *d_qk_out = ctx.Output(framework::GradVarName("QKOut")); + auto *d_softmax_out = + ctx.Output(framework::GradVarName("SoftmaxOut")); + auto *d_attn_dropout_out = + ctx.Output(framework::GradVarName("AttnDropoutOut")); + auto *d_src_mask_out = + ctx.Output(framework::GradVarName("SrcMaskOut")); + auto *d_fmha_out = ctx.Output(framework::GradVarName("FMHAOut")); + auto *d_out_linear_out = + ctx.Output(framework::GradVarName("OutLinearOut")); +#if 1 + // auto *d_dropout_mask_out = + // ctx.Output(framework::GradVarName("DropoutMaskOut")); + auto *d_bias_dropout_residual_out = + ctx.Output(framework::GradVarName("BiasDropoutResidualOut")); +#endif + auto *d_x_data = d_x->mutable_data(ctx.GetPlace()); + auto *d_ln_out_data = d_ln_out->mutable_data(ctx.GetPlace()); + auto *d_qkv_out_data = d_qkv_out->mutable_data(ctx.GetPlace()); + auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data(ctx.GetPlace()); + auto *d_qktv_out_data = d_qktv_out->mutable_data(ctx.GetPlace()); + auto *d_transpose_out_2_data = + d_transpose_out_2->mutable_data(ctx.GetPlace()); + auto *d_qk_out_data = d_qk_out->mutable_data(ctx.GetPlace()); + auto *d_softmax_out_data = d_softmax_out->mutable_data(ctx.GetPlace()); + auto *d_attn_dropout_out_data = + d_attn_dropout_out->mutable_data(ctx.GetPlace()); + auto *d_src_mask_out_data = d_src_mask_out->mutable_data(ctx.GetPlace()); + auto *d_fmha_out_data = d_fmha_out->mutable_data(ctx.GetPlace()); + auto *d_out_linear_out_data = + d_out_linear_out->mutable_data(ctx.GetPlace()); +#if 1 + // auto *d_dropout_mask_out_data = + // d_dropout_mask_out->mutable_data(ctx.GetPlace()); + auto *d_bias_dropout_residual_out_data = + d_bias_dropout_residual_out->mutable_data(ctx.GetPlace()); +#endif + + // bw parameter's grad + auto *d_ln_scale = ctx.Output(framework::GradVarName("LnScale")); + auto *d_ln_bias = ctx.Output(framework::GradVarName("LnBias")); + auto *d_qkv_weight = ctx.Output(framework::GradVarName("QKVW")); + auto *d_qkv_bias = ctx.Output(framework::GradVarName("QKVBias")); + auto *d_out_linear_weight = + ctx.Output(framework::GradVarName("OutLinearW")); + auto *d_out_linear_bias = + ctx.Output(framework::GradVarName("OutLinearBias")); +#if 1 + auto *d_ln_2_scale = ctx.Output(framework::GradVarName("Ln2Scale")); + auto *d_ln_2_bias = ctx.Output(framework::GradVarName("Ln2Bias")); +#endif + auto *d_ln_scale_data = + (d_ln_scale == nullptr ? nullptr + : d_ln_scale->mutable_data(ctx.GetPlace())); + auto *d_ln_bias_data = + (d_ln_bias == nullptr ? nullptr + : d_ln_bias->mutable_data(ctx.GetPlace())); + auto *d_qkv_weight_data = d_qkv_weight->mutable_data(ctx.GetPlace()); + auto *d_qkv_bias_data = d_qkv_bias->mutable_data(ctx.GetPlace()); + auto *d_out_linear_weight_data = + d_out_linear_weight->mutable_data(ctx.GetPlace()); + auto *d_out_linear_bias_data = + d_out_linear_bias->mutable_data(ctx.GetPlace()); +#if 1 + auto *d_ln_2_scale_data = + (d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data( + ctx.GetPlace())); + auto *d_ln_2_bias_data = + (d_ln_2_bias == nullptr ? nullptr + : d_ln_2_bias->mutable_data(ctx.GetPlace())); +#endif + + // get data ptr for qkv part. + const auto input_x_dims = input_x->dims(); + const auto qkv_w_dims = qkv_weight->dims(); + + int batch_size = input_x_dims[0]; + int max_seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int num_head = qkv_w_dims[1]; + int dim_head = qkv_w_dims[2]; + + int bsz_seq = batch_size * max_seq_len; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + Tensor d_residual; + d_residual.Resize(input_x_dims); + T *d_residual_data = d_residual.mutable_data(ctx.GetPlace()); + + bool transA = false; + bool transB = true; + bool compute_bias = true; + auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), + epsilon, bsz_seq, dim_embed); + auto qkv_compute = + AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, + output_size, input_size, compute_bias); + // fmha + AttnDropoutParam attn_dropout_param( + is_test_1, dropout_implementation_1, attn_dropout_prob, + is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); + auto fmha_ref_compute = + FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head, + dim_head, attn_dropout_param); + // out_linear + output_size = hidden_size; + transA = false; + transB = false; + compute_bias = false; + auto out_linear_compute = + AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, + output_size, input_size, compute_bias); +#if 1 + // bias + dropout + residual + layernorm. + DropoutParam dropout_param2(ctx, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, + ln2epsilon); +#endif +#if 1 + // dout -> dlayernorm_dsrc, dscale, layernorm_dbias + // dlayernorm_dsrc -> dsrc, dbias, dresidual + fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( + ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data, + dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data, + d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data, + d_out_linear_out_data, d_out_linear_bias_data, d_residual_data); +#endif +#if 1 + out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data, + d_out_linear_out_data, d_fmha_out_data, + d_out_linear_weight_data, nullptr); +#endif +#if 1 + fmha_ref_compute.ComputeBackward( + *transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out, + *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out, + d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out, + d_transpose_out_2, nullptr, d_qkv_bias_out); + // d_qkv_bias_out->d_qkv_out + // batch_size, seq_len, 3, num_head, head_size + cudaMemcpyAsync(d_qkv_out_data, d_qkv_bias_out_data, + bsz_seq * 3 * num_head * dim_head * sizeof(T), + cudaMemcpyDeviceToDevice); +#endif +#if 1 + // get qkv + if (pre_layer_norm) { + qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data, + d_qkv_bias_out_data, d_ln_out_data, + d_qkv_weight_data, d_qkv_bias_data); + layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data, + ln_mean_data, ln_var_data, d_x_data, + d_ln_scale_data, d_ln_bias_data); + } else { + qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data, + d_x_data, d_qkv_weight_data, d_qkv_bias_data); + } + // gradient accumulation: d_x[] + d_residual[] = d_x[] + std::vector ins; + std::vector outs; + ins.emplace_back(&d_residual); + ins.emplace_back(d_x); + outs.emplace_back(d_x); + int elewise_add_axis = -1; + LaunchElementwiseCudaKernel( + ctx.cuda_device_context(), ins, &outs, elewise_add_axis, + AddFunctor()); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel, + ops::FusedAttentionOpKernel, + ops::FusedAttentionOpKernel); +REGISTER_OP_CUDA_KERNEL(fused_attention_grad, + ops::FusedAttentionGradKernel, + ops::FusedAttentionGradKernel, + ops::FusedAttentionGradKernel); diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h new file mode 100644 index 00000000000000..826a06e9bfe00b --- /dev/null +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -0,0 +1,298 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/functors.h" +#include "paddle/fluid/operators/math/math_function.h" +#ifdef __NVCC__ +#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h" +#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h" +#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h" +#endif + +namespace paddle { +namespace operators { + +struct DropoutParam { + uint64_t seed; + float dropout_prob; + bool is_upscale_in_train; + bool is_test; + bool fix_seed; + int increment; + bool has_increment; + + DropoutParam() { + fix_seed = false; + seed = 0; + is_test = false; + is_upscale_in_train = false; + has_increment = false; + dropout_prob = 0.5; + } + + /** + * dropout_index: the index of dropout, such as FFN has two dropout, + * so the dropout_index will 1 or 2. + * the dropout param will defined as param1 or param2 + */ + DropoutParam(const framework::ExecutionContext& context, + const int dropout_index) { + std::string str_index = std::to_string(dropout_index); + if (dropout_index == 0) { + str_index = ""; + } + dropout_prob = context.Attr("dropout_prob" + str_index); + auto& dropout_implementation = + context.Attr("dropout_implementation" + str_index); + is_upscale_in_train = (dropout_implementation == "upscale_in_train"); + is_test = context.Attr("is_test" + str_index); + fix_seed = context.Attr("fix_seed" + str_index); + has_increment = false; + + std::string str_seed = "Seed" + str_index; + auto* tensor_seed = + context.HasInput(str_seed) ? context.Input(str_seed) : nullptr; + int device_id = + BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId(); + auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); + if (tensor_seed && platform::is_gpu_place(tensor_seed->place())) { + framework::Tensor seed_cpu_tensor; + TensorCopySync(*tensor_seed, platform::CPUPlace(), &seed_cpu_tensor); + seed = static_cast(seed_cpu_tensor.data()[0]); + } else if (gen_cuda->GetIsInitPy() && !fix_seed) { + has_increment = true; + } else { + if (tensor_seed) { + seed = *(tensor_seed->data()); + } else { + std::random_device rnd; + seed = fix_seed ? context.Attr("seed" + str_index) : rnd(); + } + } + } + int UpdateSeedAndIncrement(const platform::CUDADeviceContext& ctx, + const int offset) { + int device_id = + BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId(); + auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); + auto seed_offset = gen_cuda->IncrementOffset(offset); + seed = seed_offset.first; + increment = static_cast(seed_offset.second); + return increment; + } +}; + +template +class FusedDropoutHelper { + private: + int GetIncrement(const platform::CUDADeviceContext& ctx) { + const int VecSize = MAX_CACHE_BYTES / sizeof(T); + const int real_vec_size = cols_ % VecSize == 0 ? VecSize : 1; + auto config = + Get1DBlocksAnd2DGrids(ctx, static_cast(rows_), + static_cast(cols_), real_vec_size); + int increment = ((cols_ - 1) / (config.thread_per_block.x * + config.block_per_grid.x * real_vec_size) + + 1) * + real_vec_size; + if (dropout_param_.has_increment) { + increment = dropout_param_.UpdateSeedAndIncrement(ctx, increment); + } + return increment; + } + + public: + FusedDropoutHelper() {} + FusedDropoutHelper(const platform::CUDADeviceContext& ctx, const int rows, + const int cols, const DropoutParam& dropout_param) { + rows_ = rows; + cols_ = cols; + dropout_param_ = dropout_param; + } + + // out = residual + dropout( src + bias ) + void ResidualDropoutBias(const platform::CUDADeviceContext& ctx, const T* src, + const T* residual, const T* bias, T* out, + MaskType* mask) { + auto increment = GetIncrement(ctx); + LaunchResidualDropoutBias( + rows_, cols_, increment, dropout_param_.seed, + dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train, + dropout_param_.is_test, src, residual, bias, mask, out, ctx); + } + + void ResidualDropoutBiasGrad(const platform::CUDADeviceContext& ctx, + const T* dout, const MaskType* mask, T* dsrc, + T* dresidual, T* dbias) { + LaunchResidualDropoutBiasGrad( + dout, mask, dropout_param_.dropout_prob, + dropout_param_.is_upscale_in_train, rows_, cols_, dsrc, dbias, ctx); + cudaMemcpyAsync(dresidual, dout, rows_ * cols_ * sizeof(T), + cudaMemcpyDeviceToDevice); + } + + // out = dropout(activation(src + bias)) + void DropoutActBias(const platform::CUDADeviceContext& ctx, const T* src, + const T* bias, const std::string& act_method, T* out, + MaskType* mask) { + auto increment = GetIncrement(ctx); + if (act_method == "gelu") { + GeluFunctor gelu; + LaunchDropoutActBias>( + gelu, dropout_param_.seed, rows_, cols_, dropout_param_.increment, + dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train, + dropout_param_.is_test, src, bias, out, mask, ctx); + } else if (act_method == "relu") { + math::ReluFunctor relu; + LaunchDropoutActBias>( + relu, dropout_param_.seed, rows_, cols_, increment, + dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train, + dropout_param_.is_test, src, bias, out, mask, ctx); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "the activation only support gelu or relu!")); + } + } + + void DropoutActBiasGrad(const platform::CUDADeviceContext& ctx, const T* dout, + const T* src, const T* bias, const MaskType* mask, + T* dsrc, T* dbias, const std::string& act_method) { + if (act_method == "gelu") { + GeluGradFunctor gelu_grad; + LaunchDropoutActBiasGrad>( + gelu_grad, dout, mask, src, bias, dropout_param_.dropout_prob, + dropout_param_.is_upscale_in_train, rows_, cols_, dsrc, dbias, ctx); + } else if (act_method == "relu") { + math::ReluGradFunctor relu_grad; + LaunchDropoutActBiasGrad>( + relu_grad, dout, mask, src, bias, dropout_param_.dropout_prob, + dropout_param_.is_upscale_in_train, rows_, cols_, dsrc, dbias, ctx); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "the activation only support gelu or relu!")); + } + } + + protected: + int rows_; + int cols_; + DropoutParam dropout_param_; +}; + +template +class FusedDropoutLayerNormHelper : public FusedDropoutHelper { + public: + FusedDropoutLayerNormHelper() {} + FusedDropoutLayerNormHelper(const int rows, const int cols, + const float epsilon) { + using U = LayerNormParamType; + this->rows_ = rows; + this->cols_ = cols; + epsilon_ = epsilon; + } + + FusedDropoutLayerNormHelper(const platform::CUDADeviceContext& ctx, + const int rows, const int cols, + const DropoutParam& dropout_param, + const float epsilon) + : FusedDropoutHelper(ctx, rows, cols, dropout_param) { + using U = LayerNormParamType; + epsilon_ = epsilon; + } + + // call layer_norm + void LayerNorm(const platform::CUDADeviceContext& ctx, const T* src, + const LayerNormParamType* gamma, + const LayerNormParamType* beta, T* out, + LayerNormParamType* mean, LayerNormParamType* variance) { +#ifdef __NVCC__ + using U = LayerNormParamType; + switch (GetDesiredBlockDim(this->cols_)) { + FIXED_BLOCK_DIM_CASE( + LayerNormForward< + T, U, kBlockDim><<rows_, kBlockDim, 0, ctx.stream()>>>( + src, gamma, beta, out, mean, variance, epsilon_, this->cols_)); + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "Product from begin_norm_axis to end must be larger than 1")); + break; + } +#endif + } + + void LayerNormGrad(const platform::CUDADeviceContext& ctx, const T* dout, + const T* src, const LayerNormParamType* gamma, + const LayerNormParamType* mean, + const LayerNormParamType* variance, T* dsrc, + LayerNormParamType* dscale, + LayerNormParamType* dbias) { + using U = LayerNormParamType; + LayerNormBackward(src, dout, gamma, mean, variance, dsrc, dscale, + dbias, epsilon_, this->rows_, this->cols_, ctx); + } + + // out = layernorm(residual + dropout(src + bias)) + void LayernormResidualDropoutBias( + const platform::CUDADeviceContext& ctx, const T* src, const T* residual, + const T* bias, const LayerNormParamType* gamma, + const LayerNormParamType* beta, T* dropout_out, MaskType* mask, T* out, + LayerNormParamType* mean, LayerNormParamType* variance) { +#ifdef __NVCC__ + using U = LayerNormParamType; + int VecSize = MAX_CACHE_BYTES / sizeof(T); + if (this->cols_ % VecSize != 0) { + VecSize = 1; + } + int threads = GetDesiredBlockDim(this->cols_ / VecSize); + + int increment = ((this->cols_ - 1) / (threads * VecSize) + 1) * VecSize; + if (this->dropout_param_.has_increment) { + increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment); + } + + LaunchLayernormResidualDropoutBias( + this->rows_, this->cols_, increment, this->dropout_param_.seed, + this->dropout_param_.dropout_prob, epsilon_, + this->dropout_param_.is_upscale_in_train, this->dropout_param_.is_test, + src, residual, bias, gamma, beta, mask, dropout_out, out, mean, + variance, ctx); +#endif + } + + void LayernormResidualDropoutBiasGrad( + const platform::CUDADeviceContext& ctx, const T* dout, const T* src, + const MaskType* mask, const LayerNormParamType* gamma, + const LayerNormParamType* mean, const LayerNormParamType* variance, + T* layernorm_dsrc, LayerNormParamType* dscale, + LayerNormParamType* layernorm_dbias, T* dsrc, T* dbias, T* dresidual) { +#ifdef __NVCC__ + using U = LayerNormParamType; + LayerNormBackward(src, dout, gamma, mean, variance, layernorm_dsrc, + dscale, layernorm_dbias, epsilon_, this->rows_, + this->cols_, ctx); + this->ResidualDropoutBiasGrad(ctx, layernorm_dsrc, mask, dsrc, dresidual, + dbias); +#endif + } + + protected: + float epsilon_; +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py new file mode 100644 index 00000000000000..7510ea2a5fe263 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -0,0 +1,303 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.fluid.core as core +import paddle.nn.functional as F +from paddle.nn.layer.norm import LayerNorm +from paddle.nn.layer.common import Linear, Dropout +from paddle.fluid.data_feeder import convert_dtype +from paddle import tensor +from paddle.fluid import layers + +import unittest + +place = paddle.CUDAPlace(0) + + +def _convert_attention_mask(attn_mask, dtype): + """ + Convert the attention mask to the target dtype we expect. + Parameters: + attn_mask (Tensor, optional): A tensor used in multi-head attention + to prevents attention to some unwanted positions, usually the + paddings or the subsequent positions. It is a tensor with shape + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. + When the data type is bool, the unwanted positions have `False` + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. + dtype (VarType): The target type of `attn_mask` we expect. + Returns: + Tensor: A Tensor with shape same as input `attn_mask`, with data type `dtype`. + """ + if attn_mask is not None and attn_mask.dtype != dtype: + attn_mask_dtype = convert_dtype(attn_mask.dtype) + if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype: + attn_mask = (paddle.cast(attn_mask, dtype) - 1.0) * 1e9 + else: + attn_mask = paddle.cast(attn_mask, dtype) + return attn_mask + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "Paddle core is not compiled with CUDA") +class TestFusedAttentionOp(unittest.TestCase): + def setUp(self): + self.config() + self.generate_input_data() + paddle.set_default_dtype(self.x_type) + self.q_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.k_proj = Linear( + self.kdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.v_proj = Linear( + self.vdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.out_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + paddle.set_default_dtype(np.float32) + self.norm1 = LayerNorm(self.embed_dim) + self.norm2 = LayerNorm(self.embed_dim) + paddle.set_default_dtype(self.x_type) + self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train") + + def config(self): + self.x_type = np.float32 + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.training = True + + self.batch_size = 8 + self.query_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def generate_input_data(self): + self.query = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.x_type) + self.attn_mask = np.ones( + (self.batch_size, self.num_heads, self.query_length, + self.key_length), + dtype=self.attn_mask_type) + if self.attn_mask_type == np.int64: + self.attn_mask = np.tril(self.attn_mask) + elif self.attn_mask_type == np.float64: + self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 + else: + raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") + + self.key, self.value = self.query, self.query + + self.dout = np.random.random((self.batch_size, self.query_length, + self.embed_dim)).astype(self.x_type) + + def GetBaselineOut(self): + tensor_query = paddle.to_tensor(self.query, stop_gradient=False) + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + #residual = paddle.to_tensor(self.query) + residual = tensor_query + + for i in range(1): + ln1_out = tensor_query + if self.pre_layer_norm: + ln1_out = self.norm1(tensor_query) + + # get q, k, v + q = self.q_proj(ln1_out) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + k = self.k_proj(ln1_out) + v = self.v_proj(ln1_out) + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + # q_out * k^t + qk_out = layers.matmul( + x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5) + + if attn_mask is not None: + # Support bool or int mask + attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) + attn_mask_out = qk_out + attn_mask + softmax_out = F.softmax(attn_mask_out) + else: + softmax_out = F.softmax(qk_out) + + if self.dropout_prob: + dropout_out = F.dropout( + softmax_out, + self.dropout_prob, + training=self.training, + mode="upscale_in_train") + qktv_out = tensor.matmul(dropout_out, v_out) + else: + qktv_out = tensor.matmul(softmax_out, v_out) + + # combine heads + fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) + + out_linear_in = tensor.reshape( + x=fmha_out, + shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) + # project to output + out = self.out_proj(out_linear_in) + + residual_out = residual + self.dropout(out) + if not self.pre_layer_norm: + final_out = self.norm1(residual_out) + if self.pre_layer_norm: + final_out = self.norm2(residual_out) + + paddle.autograd.backward( + [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) + return final_out, tensor_query.grad + + def GetFusedAttentionOut(self): + q_proj_weight = paddle.to_tensor( + self.q_proj.weight, stop_gradient=False) + q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False) + k_proj_weight = paddle.to_tensor( + self.k_proj.weight, stop_gradient=False) + k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False) + v_proj_weight = paddle.to_tensor( + self.v_proj.weight, stop_gradient=False) + v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False) + out_linear_weight = paddle.to_tensor( + self.out_proj.weight, stop_gradient=False) + out_linear_bias = paddle.to_tensor( + self.out_proj.bias, stop_gradient=False) + + ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) + ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) + ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False) + ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False) + + q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) + k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) + v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) + qkv_weight = np.concatenate((q_proj_weight, k_proj_weight)) + qkv_weight = np.concatenate((qkv_weight, v_proj_weight)) + qkv_weight = qkv_weight.reshape( + (3, self.num_heads, self.head_dim, self.embed_dim)) + + qkv_bias = np.concatenate((q_proj_bias.numpy(), k_proj_bias.numpy())) + qkv_bias = np.concatenate((qkv_bias, v_proj_bias.numpy())) + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + + x = paddle.to_tensor(self.query, stop_gradient=False) + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) + qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) + epsilon = 1e-05 + ln2_epsilon = 1e-05 + + if attn_mask is not None: + # Support bool or int mask + attn_mask = _convert_attention_mask(attn_mask, x.dtype) + + for i in range(1): + final_out = F.fused_multihead_attention( + x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, + ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, + qkv_bias_tensor, out_linear_bias, attn_mask, self.dropout_prob, + self.attn_dropout_prob, ln2_epsilon) + paddle.autograd.backward( + [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) + + return final_out, x.grad + + def test_fused_attention_op(self): + print( + "self.batch_size, self.query_length, self.embed_dim, self.num_heads, self.head_dim = " + ) + print(self.batch_size, self.query_length, self.embed_dim, + self.num_heads, self.head_dim) + final_out_ref, x_grad_ref = self.GetBaselineOut() + final_out, x_grad = self.GetFusedAttentionOut() + + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose( + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "Paddle core is not compiled with CUDA") +class TestFusedAttentionOpFp16(TestFusedAttentionOp): + def config(self): + self.x_type = np.float16 + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.training = True + + self.batch_size = 8 + self.query_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + + self.weight_attr = None + self.bias_attr = None + + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def test_fused_attention_op(self): + print( + "self.batch_size, self.query_length, self.embed_dim, self.num_heads, self.head_dim = " + ) + print(self.batch_size, self.query_length, self.embed_dim, + self.num_heads, self.head_dim) + final_out_ref, x_grad_ref = self.GetBaselineOut() + final_out, x_grad = self.GetFusedAttentionOut() + + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) + np.testing.assert_allclose( + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1) + + +if __name__ == "__main__": + unittest.main() From 42f03726c7323c5182a969b81298d96db69ca049 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 22 Sep 2021 07:11:37 +0000 Subject: [PATCH 03/29] Add fused_attention_op: forward impl. --- cmake/operators.cmake | 2 +- paddle/fluid/operators/fused/CMakeLists.txt | 4 + .../operators/fused/fused_attention_op.h | 18 ++ paddle/fluid/pybind/op_function_generator.cc | 8 + python/paddle/nn/functional/__init__.py | 1 + python/paddle/nn/functional/common.py | 168 ++++++++++++++++++ 6 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/fused/fused_attention_op.h diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 2c010a1e6297f0..570cb142eeb8f5 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -216,7 +216,7 @@ function(op_library TARGET) "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op" -"fused_bn_add_activation_op") +"fused_bn_add_activation_op" "fused_attention_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index e3dcff949f43c3..d16033e9e16153 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -16,6 +16,7 @@ register_operators(EXCLUDES fusion_gru_op fusion_lstm_op fused_bn_add_activation_op + fused_attention_op fused_transformer_op) # fusion_gru_op does not have CUDA kernel @@ -59,6 +60,9 @@ if (WITH_GPU OR WITH_ROCM) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(skip_layernorm);\n") op_library(fused_embedding_eltwise_layernorm_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_embedding_eltwise_layernorm);\n") + # fused_attention_op + op_library(fused_attention_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n") # fusion_group if(NOT APPLE AND NOT WIN32) op_library(fusion_group_op DEPS device_code) diff --git a/paddle/fluid/operators/fused/fused_attention_op.h b/paddle/fluid/operators/fused/fused_attention_op.h new file mode 100644 index 00000000000000..44d532960b7a5c --- /dev/null +++ b/paddle/fluid/operators/fused/fused_attention_op.h @@ -0,0 +1,18 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace operators { +// todo: +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index f9d11e8154f43f..5a3d99239750bd 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -40,6 +40,9 @@ // need to manually specify them in this map. std::map> op_ins_map = { {"layer_norm", {"X", "Scale", "Bias"}}, + {"fused_attention", + {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW", + "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, {"instance_norm", {"X", "Scale", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"label_smooth", {"X", "PriorDist"}}, @@ -87,6 +90,11 @@ std::map> op_outs_map = { {"batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, + {"fused_attention", + {"LnMean", "LnVariance", "LnOut", "QKVOut", "QKVBiasOut", "TransposeOut2", + "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", "AttnDropoutOut", + "SrcMaskOut", "FMHAOut", "OutLinearOut", "DropoutMaskOut", "Ln2Mean", + "Ln2Variance", "BiasDropoutResidualOut", "Y"}}, {"sync_batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 7965b362b9c55a..d8bec647f2c54c 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -60,6 +60,7 @@ from .conv import conv1d # noqa: F401 from .conv import conv1d_transpose # noqa: F401 from .common import linear # noqa: F401 +from .common import fused_multihead_attention # noqa: F401 from .conv import conv2d # noqa: F401 from .conv import conv2d_transpose # noqa: F401 from .conv import conv3d # noqa: F401 diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index fcfbea438d7cca..717a9361d7736d 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1502,6 +1502,174 @@ def linear(x, weight, bias=None, name=None): return res +def fused_multihead_attention(x, + qkv_weight, + out_linear_weight, + pre_layer_norm=False, + ln_scale=None, + ln_bias=None, + ln_2_scale=None, + ln_2_bias=None, + epsilon=1e-05, + qkv_bias=None, + out_linear_bias=None, + src_mask=None, + dropout=0., + attn_dropout=0., + ln2_epsilon=1e-05, + name=None): + r""" + Fused multihead_attention operator. For each input :math:`X` , + the equation is: + .. math:: + if (pre_layer_norm) + ln_out = layer_norm(X) + qkv_out = get_qkv(ln_out, qkv_weight, qkv_bias) + else + qkv_out = get_qkv(X, qkv_weight, qkv_bias) + fmha_out = fmha(qkv_out, xxx) + out_linear_out = out_linear(fmha_out, out_linear_weight) + out = bias_dropout_residual_layer_norm(out_linear_out, out_linear_bias) + # where :math:`W` is the weight and :math:`b` is the bias. + # If the weight is a 2-D tensor of shape :math:`[in\_features, out\_features]` , + # input should be a multi-dimensional tensor of shape + # :math:`[batch\_size, *, in\_features]` , where :math:`*` means any number of + # additional dimensions. The linear operator multiplies input tensor with + # weight and produces an output tensor of shape :math:`[batch\_size, *, out\_features]` , + # If :math:`bias` is not None, the bias should be a 1-D tensor of shape + # :math:`[out\_features]` and will be added to the output. + Parameters: + x (Tensor): Input tensor with shape [batch\_size, seq\_len, dim_embed]. + The data type should be float16, float32 or float64. + qkv_weight (Tensor): QKV Weight tensor with shape [3 * num_head * dim_head, dim_embed]. + The data type should be float16, float32 or float64. + ## tood: check shape! + out_linear_weight (Tensor): Out_linear Weight tensor with shape [xx, dim_embed]. + The data type should be float16, float32 or float64. + qkv_bias (Tensor, optional): QKV Bias tensor with shape [3 * num_head * dim_head]. + The data type should be float16, float32 or float64. + If it is set to None, no bias will be added to the output units. + ## tood: check shape! + out_linear_bias (Tensor, optional): Out_linear Bias tensor with shape []. + The data type should be float16, float32 or float64. + If it is set to None, no bias will be added to the output units. + name (str, optional): Normally there is no need for user to set this parameter. + For detailed information, please refer to :ref:`api_guide_Name` . + Returns: + Tensor, the shape is :math:`[batch\_size, seq\_len, dim_embed]` and the + data type is the same with input :math:`x` . + Examples: + .. code-block:: python + + import paddle + """ + if in_dygraph_mode(): + ## before integrate kaihuo's code + #ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, src_mask_out, fmha_out, out_linear_out, final_out = _C_ops.fused_attention(x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, out_linear_weight, out_linear_bias, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, 'dropout_prob', dropout) + + ## finally code + ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( + x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, + out_linear_weight, out_linear_bias, ln_2_scale, ln_2_bias, + 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, + 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) + #return ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, src_mask_out, fmha_out, out_linear_out, bias_dropout_residual_out, final_out + return final_out + else: + helper = LayerHelper('fused_multihead_attention', **locals()) + dtype = x.dtype + # check dtypes + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'fused_multihead_attention') + check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], + 'fused_multihead_attention') + + # set inputs + inputs = dict() + inputs['X'] = [x] + if ln_scale: + inputs['LnScale'] = [ln_scale] + if ln_bias: + inputs['LnBias'] = [ln_bias] + inputs['QKVW'] = [qkv_weight] + inputs['QKVBias'] = [qkv_bias] + inputs['SrcMask'] = src_mask + inputs['OutLinearW'] = [out_linear_weight] + inputs['OutLinearBias'] = [out_linear_bias] + if ln_2_scale: + inputs['Ln2Scale'] = [ln_2_scale] + if ln_2_bias: + inputs['Ln2Bias'] = [ln_2_bias] + + # set attrs + attrs = { + 'pre_layer_norm': pre_layer_norm, + 'epsilon': epsilon, + 'ln2_epsilon': ln2_epsilon, + 'dropout_prob': dropout, + 'attn_dropout_prob': attn_dropout + } + + # set outputs + ln_mean_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + ln_variance_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + ln_out = helper.create_variable_for_type_inference(dtype=dtype) + + qkv_out = helper.create_variable_for_type_inference(dtype=dtype) + qkv_bias_out = helper.create_variable_for_type_inference(dtype=dtype) + + transpose_out_2 = helper.create_variable_for_type_inference(dtype=dtype) + qk_out = helper.create_variable_for_type_inference(dtype=dtype) + qktv_out = helper.create_variable_for_type_inference(dtype=dtype) + softmax_out = helper.create_variable_for_type_inference(dtype=dtype) + attn_dropout_mask_out = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + attn_dropout_out = helper.create_variable_for_type_inference( + dtype=dtype) + # todo: stop_gradient? + src_mask_out = helper.create_variable_for_type_inference(dtype=dtype) + fmha_out = helper.create_variable_for_type_inference(dtype=dtype) + out_linear_out = helper.create_variable_for_type_inference(dtype=dtype) + dropout_mask_out = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + ln_2_mean_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + ln_2_variance_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + bias_dropout_residual_out = helper.create_variable_for_type_inference( + dtype=dtype) + final_out = helper.create_variable_for_type_inference(dtype=dtype) + + helper.append_op( + type='fused_attention', + inputs=inputs, + outputs={ + "LnMean": ln_mean_out, + "LnVariance": ln_variance_out, + "LnOut": ln_out, + "QKVOut": qkv_out, + "QKVBiasOut": qkv_bias_out, + "TransposeOut2": transpose_out_2, + "QKOut": qk_out, + "QKTVOut": qktv_out, + "SoftmaxOut": softmax_out, + "AttnDropoutMaskOut": attn_dropout_mask_out, + "AttnDropoutOut": attn_dropout_out, + "SrcMaskOut": src_mask_out, + "FMHAOut": fmha_out, + "OutLinearOut": out_linear_out, + "DropoutMaskOut": dropout_mask_out, + "Ln2Mean": ln_2_mean_out, + "Ln2Variance": ln_2_variance_out, + "BiasDropoutResidualOut": bias_dropout_residual_out, + 'Y': final_out + }, + attrs=attrs) + return final_out + + def label_smooth(label, prior_dist=None, epsilon=0.1, name=None): r""" Label smoothing is a mechanism to regularize the classifier layer and is called From c6aebef71cd8b47892a1b68ea47c704ef1e05bb3 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 22 Sep 2021 08:07:53 +0000 Subject: [PATCH 04/29] Remove useless code. --- .../operators/fused/fused_attention_op.cc | 298 +---------------- .../operators/fused/fused_attention_op.cu | 307 +----------------- .../unittests/test_fused_attention_op.py | 49 +-- 3 files changed, 27 insertions(+), 627 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 1c3db42a6177b3..19f87472eb5184 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -10,10 +10,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/fused/fused_attention_op.h" - #include #include - #include "paddle/fluid/framework/op_registry.h" namespace paddle { @@ -26,8 +24,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - // std::cout << "i am in op infershape\n"; - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask", "FusedAttentionOp"); @@ -39,13 +35,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", "FusedAttentionOp"); - // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut", "FusedAttentionOp"); + // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut", @@ -68,7 +64,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut", "FusedAttentionOp"); -#if 1 OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance", @@ -77,21 +72,12 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "BiasDropoutResidualOut", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut", "FusedAttentionOp"); -#endif OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp"); // x: qkv's input [batch_size, seq_len, dim_embed] // y: qkv's weight: [3, num_head, dim_head, dim_embed] auto x_dim = ctx->GetInputDim("X"); auto y_dim = ctx->GetInputDim("QKVW"); - // auto qkv_bias_dim = ctx->GetInputDim("QKVBias"); - // auto src_mask_dim = ctx->GetInputDim("SrcMask"); - // std::cout << "x_dim = " << x_dim << std::endl; - // std::cout << "qkv_weight_dim = " << y_dim << std::endl; - // std::cout << "qkv_bias_dim = " << qkv_bias_dim << std::endl; - // // src_mask_dim = 32, 16, 128, 128 - // std::cout << "src_mask_dim = " << src_mask_dim << std::endl; - PADDLE_ENFORCE_EQ(x_dim.size(), 3, platform::errors::InvalidArgument( "The dimensions of QKV_input must be 3" @@ -99,7 +85,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "but received dimensions of" "Input is [%d]", x_dim.size())); - PADDLE_ENFORCE_EQ(y_dim.size(), 4, platform::errors::InvalidArgument( "The dimensions of QKV_weight must be 4" @@ -107,8 +92,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "but received dimensions of" "Input is [%d]", y_dim.size())); - - // limin-todo: polish the expression. PADDLE_ENFORCE_EQ(x_dim[2], y_dim[3], platform::errors::InvalidArgument( "ShapeError: the dimension of x_dim[2] and y_dim[3]" @@ -125,11 +108,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); ctx->SetOutputDim("QKVBiasOut", {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); - // limin-todo: [3, batch_size, seq_len, num_head, head_size] - // check shape: [3, batch_size, num_head, seq_len, head_size] + // [3, batch_size, num_head, seq_len, head_size] ctx->SetOutputDim("TransposeOut2", {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); - // check shape: batch, num_head, seq_len, seq_len + // [batch, num_head, seq_len, seq_len] ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); // the same as QKOut's shape. @@ -140,20 +122,18 @@ class FusedAttentionOp : public framework::OperatorWithKernel { {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); } ctx->SetOutputDim("SoftmaxOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); - // check shape [batch_size, num_heads, seq_len, head_dim] + // [batch_size, num_heads, seq_len, head_dim] ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); - // check shape, [batch_size, seq_len, number of heads*head size] + // [batch_size, seq_len, number of heads*head size] ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]}); ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X")); -#if 1 ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); if (ctx->Attrs().Get("is_test") == false) { ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X")); } ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X")); -#endif ctx->SetOutputDim("Y", ctx->GetInputDim("X")); } @@ -186,7 +166,6 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddInput("OutLinearW", "The out_linear weight tensor."); AddInput("OutLinearBias", "The out_linear bias tensor."); -#if 1 AddInput("Ln2Scale", "(optional) Scale is a 1-dimensional tensor of size " "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." @@ -197,23 +176,17 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." "It is applied to the output.") .AsDispensable(); -#endif -#if 1 -// todo: -// AddInput("Seed", -// "The seed of dropout op, it has higher priority than the attr " -// "fix_seed and seed") -// .AsDispensable(); -#endif + // AddInput("Seed", + // "The seed of dropout op, it has higher priority than the attr + // " + // "fix_seed and seed") + // .AsDispensable(); AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); AddOutput("LnVariance", "Variance of the current mini batch.") .AsIntermediate(); AddOutput("LnOut", "The output of pre layer_norm.").AsIntermediate(); - AddOutput("QKVOut", "Result after qkv.").AsIntermediate(); AddOutput("QKVBiasOut", "Result after qkv and bias op.").AsIntermediate(); - - // fma AddOutput("TransposeOut2", "Result in fmha.").AsIntermediate(); AddOutput("QKOut", "Result in fmha.").AsIntermediate(); AddOutput("QKTVOut", "Result in fmha.").AsIntermediate(); @@ -222,10 +195,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("AttnDropoutOut", "Result in fmha.").AsIntermediate(); AddOutput("SrcMaskOut", "Result in fmha.").AsIntermediate(); AddOutput("FMHAOut", "Result after fmha.").AsIntermediate(); - AddOutput("OutLinearOut", "Result after out_linear.").AsIntermediate(); - -#if 1 AddOutput("DropoutMaskOut", "The random sampled dropout mask.") .AsIntermediate(); AddOutput("Ln2Mean", "Mean of the current mini batch.").AsIntermediate(); @@ -234,8 +204,6 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("BiasDropoutResidualOut", "Result of residual + dropout(src + bias).") .AsIntermediate(); -#endif - AddOutput("Y", "Result after attention."); AddAttr("pre_layer_norm", @@ -253,19 +221,6 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "0.0 and 0.001, But received [%s].", epsilon)); }); - // AddAttr("begin_norm_axis", - // "the axis of `begin_norm_axis ... Rank(X) - 1` will be " - // "normalized. `begin_norm_axis` splits the tensor(`X`) to a " - // "matrix [N,H]. [default 1].") - // .SetDefault(1) - // .AddCustomChecker([](const int &begin_norm_axis) { - // PADDLE_ENFORCE_GT(begin_norm_axis, 0, - // platform::errors::InvalidArgument( - // "'begin_norm_axis' in Op(LayerNorm) should - // be" - // "greater than zero. But received [%d].", - // begin_norm_axis)); - // }); // for dropout in fmha. AddAttr("attn_dropout_prob", "Probability of setting units to zero.") @@ -313,7 +268,6 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "upscale_in_train")); }); -#if 1 AddAttr("dropout_prob", "Probability of setting units to zero.") .SetDefault(.5f) .AddCustomChecker([](const float &drop_p) { @@ -369,7 +323,6 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "0.0 and 0.001, But received [%s].", ln2epsilon)); }); -#endif AddComment(R"DOC( Fused attention: @@ -383,238 +336,9 @@ bias_add + dropout + residual + layer_norm; } }; -class FusedAttentionGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { -// auto x_dim = ctx->GetInputDim("X"); -// auto y_dim = ctx->GetInputDim("QKVW"); -// std::cout << "x_dim = " << x_dim << std::endl; -// std::cout << "y_dim = " << y_dim << std::endl; -// int batch_size = x_dim[0]; -// int seq_len = x_dim[1]; -// int embed_dim = x_dim[2]; -// std::cout << "batch_size, seq_len, embed_dim= " << batch_size << ", " << -// seq_len << ", " << embed_dim << std::endl; - -#if 1 - PADDLE_ENFORCE_EQ(ctx->Attrs().Get("is_test"), false, - platform::errors::InvalidArgument( - "GradOp is only callable when is_test is false")); - - OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean", - "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance", - "FusedAttentionGrad"); - if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) { - ctx->SetOutputDim(framework::GradVarName("Ln2Scale"), - ctx->GetInputDim("Ln2Scale")); - } - if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) { - ctx->SetOutputDim(framework::GradVarName("Ln2Bias"), - ctx->GetInputDim("Ln2Bias")); - } -#endif - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean", - "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance", - "FusedAttentionGrad"); - if (ctx->Attrs().Get("pre_layer_norm") == true) { - OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut", - "FusedAttentionGrad"); - } - OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", - "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", - "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask", - "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", - "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", - "FusedAttentionGrad"); - - if (ctx->HasOutput(framework::GradVarName("LnScale"))) { - ctx->SetOutputDim(framework::GradVarName("LnScale"), - ctx->GetInputDim("LnScale")); - } - if (ctx->HasOutput(framework::GradVarName("LnBias"))) { - ctx->SetOutputDim(framework::GradVarName("LnBias"), - ctx->GetInputDim("LnBias")); - } - if (ctx->HasOutput(framework::GradVarName("X"))) { - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - } - - ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), - ctx->GetInputDim("OutLinearBias")); - ctx->SetOutputDim(framework::GradVarName("OutLinearW"), - ctx->GetInputDim("OutLinearW")); - ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW")); - ctx->SetOutputDim(framework::GradVarName("QKVBias"), - ctx->GetInputDim("QKVBias")); - - ctx->SetOutputDim(framework::GradVarName("LnOut"), - ctx->GetInputDim("LnOut")); - ctx->SetOutputDim(framework::GradVarName("FMHAOut"), - ctx->GetInputDim("FMHAOut")); - ctx->SetOutputDim(framework::GradVarName("QKTVOut"), - ctx->GetInputDim("QKTVOut")); - ctx->SetOutputDim(framework::GradVarName("TransposeOut2"), - ctx->GetInputDim("TransposeOut2")); - ctx->SetOutputDim(framework::GradVarName("QKOut"), - ctx->GetInputDim("QKOut")); - ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"), - ctx->GetInputDim("SoftmaxOut")); - ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"), - ctx->GetInputDim("AttnDropoutOut")); - ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"), - ctx->GetInputDim("SrcMaskOut")); - ctx->SetOutputDim(framework::GradVarName("QKVOut"), - ctx->GetInputDim("QKVOut")); - ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"), - ctx->GetInputDim("QKVBiasOut")); -#if 1 - ctx->SetOutputDim(framework::GradVarName("OutLinearOut"), - ctx->GetInputDim("OutLinearOut")); - // ctx->SetOutputDim(framework::GradVarName("DropoutMaskOut"), - // ctx->GetInputDim("DropoutMaskOut")); - ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"), - ctx->GetInputDim("BiasDropoutResidualOut")); -#endif - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - auto input = ctx.Input("X"); - auto input_data_type = input->type(); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); - } -}; - -template -class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("fused_attention_grad"); - op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); - - // inputs x, parameters and their grad. - op->SetInput("X", this->Input("X")); - op->SetInput("QKVW", this->Input("QKVW")); - op->SetInput("QKVBias", this->Input("QKVBias")); - op->SetInput("SrcMask", this->Input("SrcMask")); - op->SetInput("OutLinearW", this->Input("OutLinearW")); - op->SetInput("OutLinearBias", this->Input("OutLinearBias")); - if (this->HasInput("LnScale")) { - op->SetInput("LnScale", this->Input("LnScale")); - op->SetOutput(framework::GradVarName("LnScale"), - this->InputGrad("LnScale")); - } - if (this->HasInput("LnBias")) { - op->SetInput("LnBias", this->Input("LnBias")); - op->SetOutput(framework::GradVarName("LnBias"), - this->InputGrad("LnBias")); - } -#if 1 - if (this->HasInput("Ln2Scale")) { - op->SetInput("Ln2Scale", this->Input("Ln2Scale")); - op->SetOutput(framework::GradVarName("Ln2Scale"), - this->InputGrad("Ln2Scale")); - } - if (this->HasInput("Ln2Bias")) { - op->SetInput("Ln2Bias", this->Input("Ln2Bias")); - op->SetOutput(framework::GradVarName("Ln2Bias"), - this->InputGrad("Ln2Bias")); - } -#endif - - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW")); - op->SetOutput(framework::GradVarName("QKVBias"), - this->InputGrad("QKVBias")); - op->SetOutput(framework::GradVarName("OutLinearBias"), - this->InputGrad("OutLinearBias")); - op->SetOutput(framework::GradVarName("OutLinearW"), - this->InputGrad("OutLinearW")); - - // use forward's output as bw's input. - op->SetInput("LnOut", this->Output("LnOut")); - op->SetInput("LnMean", this->Output("LnMean")); - op->SetInput("LnVariance", this->Output("LnVariance")); - op->SetInput("QKVOut", this->Output("QKVOut")); - op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); - op->SetInput("TransposeOut2", this->Output("TransposeOut2")); - op->SetInput("QKOut", this->Output("QKOut")); - op->SetInput("QKTVOut", this->Output("QKTVOut")); - op->SetInput("SoftmaxOut", this->Output("SoftmaxOut")); - op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut")); - op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut")); - op->SetInput("SrcMaskOut", this->Output("SrcMaskOut")); - op->SetInput("FMHAOut", this->Output("FMHAOut")); - op->SetInput("OutLinearOut", this->Output("OutLinearOut")); - -#if 1 - op->SetInput("Ln2Mean", this->Output("Ln2Mean")); - op->SetInput("Ln2Variance", this->Output("Ln2Variance")); - op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut")); - op->SetInput("BiasDropoutResidualOut", - this->Output("BiasDropoutResidualOut")); -#endif - // op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); - op->SetInput("QKVOut", this->Output("QKVOut")); - - // bw's output: dinput - op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut")); - op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut")); - op->SetOutput(framework::GradVarName("QKVBiasOut"), - this->OutputGrad("QKVBiasOut")); - op->SetOutput(framework::GradVarName("QKTVOut"), - this->OutputGrad("QKTVOut")); - op->SetOutput(framework::GradVarName("TransposeOut2"), - this->OutputGrad("TransposeOut2")); - op->SetOutput(framework::GradVarName("QKOut"), this->OutputGrad("QKOut")); - op->SetOutput(framework::GradVarName("SoftmaxOut"), - this->OutputGrad("SoftmaxOut")); - op->SetOutput(framework::GradVarName("AttnDropoutOut"), - this->OutputGrad("AttnDropoutOut")); - op->SetOutput(framework::GradVarName("SrcMaskOut"), - this->OutputGrad("SrcMaskOut")); - op->SetOutput(framework::GradVarName("FMHAOut"), - this->OutputGrad("FMHAOut")); -#if 1 - // op->SetOutput(framework::GradVarName("DropoutMaskOut"), - // this->OutputGrad("DropoutMaskOut")); - op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"), - this->OutputGrad("BiasDropoutResidualOut")); -#endif - op->SetOutput(framework::GradVarName("OutLinearOut"), - this->OutputGrad("OutLinearOut")); - // op->SetOutput(framework::GradVarName("OutLinearBiasOut"), - // this->OutputGrad("OutLinearBiasOut")); - - op->SetAttrMap(this->Attrs()); - } -}; - -// DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseAddLayerNormGradNoNeedBufferVarInferer, -// "Bias"); - } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp, - ops::FusedAttentionOpMaker, - ops::FusedAttentionGradOpMaker, - ops::FusedAttentionGradOpMaker); - -REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp); -// REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp, -// ops::FusedAttentionGradNoNeedBufferVarInferer); + ops::FusedAttentionOpMaker); diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 83a9287bf2e232..b6e385deb8dde3 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -66,7 +66,6 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *qkv_out = ctx.Output("QKVOut"); auto *qkv_bias_out = ctx.Output("QKVBiasOut"); - // FMHA-ref: auto *src_mask = ctx.Input("SrcMask"); auto *transpose_out_2 = ctx.Output("TransposeOut2"); auto *qk_out = ctx.Output("QKOut"); @@ -77,13 +76,10 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *src_mask_out = ctx.Output("SrcMaskOut"); auto *fmha_out = ctx.Output("FMHAOut"); - // out_linear auto *out_linear_weight = ctx.Input("OutLinearW"); auto *out_linear_bias = ctx.Input("OutLinearBias"); auto *out_linear_out = ctx.Output("OutLinearOut"); -// bias+dropout+residual+layernorm -#if 1 auto *ln_scale_2 = ctx.Input("Ln2Scale"); auto *ln_bias_2 = ctx.Input("Ln2Bias"); auto *dropout_mask_out = ctx.Output("DropoutMaskOut"); @@ -92,12 +88,8 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *ln_mean_2 = ctx.Output("Ln2Mean"); auto *ln_var_2 = ctx.Output("Ln2Variance"); const float ln2epsilon = ctx.Attr("ln2epsilon"); -#endif -#if 1 float attn_dropout_prob = ctx.Attr("attn_dropout_prob"); - std::cout << "limin: attn_dropout_prob = " << attn_dropout_prob - << std::endl; bool is_test_1 = ctx.Attr("is_test1"); auto &dropout_implementation_1 = ctx.Attr("dropout_implementation1"); @@ -106,7 +98,6 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; bool is_fix_seed_1 = ctx.Attr("fix_seed1"); int seed_val_1 = ctx.Attr("seed1"); -#endif // final output. auto *out = ctx.Output("Y"); @@ -146,8 +137,7 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *out_linear_bias_data = out_linear_bias->data(); auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace()); -// get data ptr for bias+dropout+residual+layernorm -#if 1 + // get data ptr for bias+dropout+residual+layernorm auto *ln_scale_2_data = (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data()); auto *ln_bias_2_data = @@ -158,7 +148,6 @@ class FusedAttentionOpKernel : public framework::OpKernel { bias_dropout_residual_out->mutable_data(ctx.GetPlace()); auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace()); auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace()); -#endif auto *final_out_data = out->mutable_data(ctx.GetPlace()); int batch_size = input_x_dims[0]; @@ -182,16 +171,13 @@ class FusedAttentionOpKernel : public framework::OpKernel { AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, output_size, input_size, compute_bias); - // AttnDropoutParam(bool is_test, const std::string dropout_implementation, - // float dropout_prob, bool is_upscale_in_train, - // bool is_fix_seed, int seed_val, const Tensor* seed) { AttnDropoutParam attn_dropout_param( is_test_1, dropout_implementation_1, attn_dropout_prob, is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); auto fmha_ref_compute = FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head, dim_head, attn_dropout_param); - // out_linear + output_size = hidden_size; transA = false; transB = false; @@ -199,14 +185,11 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto out_linear_compute = AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, output_size, input_size, compute_bias); -#if 1 DropoutParam dropout_param2(ctx, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, ln2epsilon); -#endif - // compute if (pre_layer_norm) { layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, ln_out_data, ln_mean_data, ln_var_data); @@ -216,297 +199,21 @@ class FusedAttentionOpKernel : public framework::OpKernel { qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data, qkv_out_data, qkv_bias_out_data); } - // compute FMHA fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2, qk_out, src_mask_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, qktv_out, fmha_out); -// fmha_out: [batch_size, seq_len, num_head, head_dim] -// weight: [1024, 1024], [embed_dim, embed_dim] -// out_linear_out: [batch_size, seq_len, embed_dim] -#if 1 + // fmha_out: [batch_size, seq_len, num_head, head_dim] + // weight: [1024, 1024], [embed_dim, embed_dim] + // out_linear_out: [batch_size, seq_len, embed_dim] out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data, nullptr, out_linear_out_data, nullptr); -#endif -#if 1 // out = layernorm(residual + dropout(src + bias)) fused_dropout_layernorm_helper.LayernormResidualDropoutBias( ctx.cuda_device_context(), out_linear_out_data, x_data, out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, ln_mean_2_data, ln_var_2_data); -#endif - } -}; - -template -class FusedAttentionGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - using U = LayerNormParamType; - const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); - const float epsilon = ctx.Attr("epsilon"); -#if 1 - const float ln2epsilon = ctx.Attr("ln2epsilon"); -#endif - -#if 1 - float attn_dropout_prob = ctx.Attr("attn_dropout_prob"); - bool is_test_1 = ctx.Attr("is_test1"); - auto &dropout_implementation_1 = - ctx.Attr("dropout_implementation1"); - bool is_upscale_in_train_1 = - (dropout_implementation_1 == "upscale_in_train"); - auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; - bool is_fix_seed_1 = ctx.Attr("fix_seed1"); - int seed_val_1 = ctx.Attr("seed1"); -#endif - - // get inputs. - auto *d_y = ctx.Input(framework::GradVarName("Y")); - auto *d_y_data = d_y->data(); - - // fw input - auto *input_x = ctx.Input("X"); - auto *ln_scale = ctx.Input("LnScale"); -#if 1 - auto *ln_2_scale = ctx.Input("Ln2Scale"); -#endif - auto *x_data = input_x->data(); - auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); -#if 1 - auto *ln_2_scale_data = - (ln_2_scale == nullptr ? nullptr : ln_2_scale->data()); -#endif - // fw parameters. - auto *src_mask = ctx.Input("SrcMask"); - auto *qkv_weight = ctx.Input("QKVW"); - auto *qkv_bias = ctx.Input("QKVBias"); - auto *out_linear_weight = ctx.Input("OutLinearW"); - auto *out_linear_bias = ctx.Input("OutLinearBias"); - auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data()); - auto *qkv_weight_data = qkv_weight->data(); - auto *qkv_bias_data = qkv_bias->data(); - auto *out_linear_weight_data = out_linear_weight->data(); - auto *out_linear_bias_data = out_linear_bias->data(); - - // fw output - auto *ln_mean = ctx.Input("LnMean"); - auto *ln_var = ctx.Input("LnVariance"); - auto *ln_out = ctx.Input("LnOut"); - auto *fmha_out = ctx.Input("FMHAOut"); - auto *transpose_out_2 = ctx.Input("TransposeOut2"); - auto *qk_out = ctx.Input("QKOut"); - auto *qktv_out = ctx.Input("QKTVOut"); - auto *softmax_out = ctx.Input("SoftmaxOut"); - auto *attn_dropout_mask_out = ctx.Input("AttnDropoutMaskOut"); - auto *attn_dropout_out = ctx.Input("AttnDropoutOut"); - auto *src_mask_out = ctx.Input("SrcMaskOut"); - auto *out_linear_out = ctx.Input("OutLinearOut"); -#if 1 - auto *ln_2_mean = ctx.Input("Ln2Mean"); - auto *ln_2_var = ctx.Input("Ln2Variance"); - auto *dropout_mask_out = ctx.Input("DropoutMaskOut"); - auto *bias_dropout_residual_out = - ctx.Input("BiasDropoutResidualOut"); -#endif - auto *ln_mean_data = ln_mean->data(); - auto *ln_var_data = ln_var->data(); - auto *ln_out_data = ln_out->data(); - auto *fmha_out_data = fmha_out->data(); - auto *transpose_out_2_data = transpose_out_2->data(); - auto *qk_out_data = qk_out->data(); - auto *qktv_out_data = qktv_out->data(); - auto *softmax_out_data = softmax_out->data(); - auto *src_mask_out_data = src_mask_out->data(); - auto *out_linear_out_data = out_linear_out->data(); -#if 1 - auto *ln_2_mean_data = ln_2_mean->data(); - auto *ln_2_var_data = ln_2_var->data(); - auto *dropout_mask_out_data = dropout_mask_out->data(); - auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data(); -#endif - - // bw output's grad - auto *d_x = ctx.Output(framework::GradVarName("X")); - auto *d_ln_out = ctx.Output(framework::GradVarName("LnOut")); - auto *d_qkv_out = ctx.Output(framework::GradVarName("QKVOut")); - auto *d_qkv_bias_out = - ctx.Output(framework::GradVarName("QKVBiasOut")); - auto *d_qktv_out = ctx.Output(framework::GradVarName("QKTVOut")); - auto *d_transpose_out_2 = - ctx.Output(framework::GradVarName("TransposeOut2")); - auto *d_qk_out = ctx.Output(framework::GradVarName("QKOut")); - auto *d_softmax_out = - ctx.Output(framework::GradVarName("SoftmaxOut")); - auto *d_attn_dropout_out = - ctx.Output(framework::GradVarName("AttnDropoutOut")); - auto *d_src_mask_out = - ctx.Output(framework::GradVarName("SrcMaskOut")); - auto *d_fmha_out = ctx.Output(framework::GradVarName("FMHAOut")); - auto *d_out_linear_out = - ctx.Output(framework::GradVarName("OutLinearOut")); -#if 1 - // auto *d_dropout_mask_out = - // ctx.Output(framework::GradVarName("DropoutMaskOut")); - auto *d_bias_dropout_residual_out = - ctx.Output(framework::GradVarName("BiasDropoutResidualOut")); -#endif - auto *d_x_data = d_x->mutable_data(ctx.GetPlace()); - auto *d_ln_out_data = d_ln_out->mutable_data(ctx.GetPlace()); - auto *d_qkv_out_data = d_qkv_out->mutable_data(ctx.GetPlace()); - auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data(ctx.GetPlace()); - auto *d_qktv_out_data = d_qktv_out->mutable_data(ctx.GetPlace()); - auto *d_transpose_out_2_data = - d_transpose_out_2->mutable_data(ctx.GetPlace()); - auto *d_qk_out_data = d_qk_out->mutable_data(ctx.GetPlace()); - auto *d_softmax_out_data = d_softmax_out->mutable_data(ctx.GetPlace()); - auto *d_attn_dropout_out_data = - d_attn_dropout_out->mutable_data(ctx.GetPlace()); - auto *d_src_mask_out_data = d_src_mask_out->mutable_data(ctx.GetPlace()); - auto *d_fmha_out_data = d_fmha_out->mutable_data(ctx.GetPlace()); - auto *d_out_linear_out_data = - d_out_linear_out->mutable_data(ctx.GetPlace()); -#if 1 - // auto *d_dropout_mask_out_data = - // d_dropout_mask_out->mutable_data(ctx.GetPlace()); - auto *d_bias_dropout_residual_out_data = - d_bias_dropout_residual_out->mutable_data(ctx.GetPlace()); -#endif - - // bw parameter's grad - auto *d_ln_scale = ctx.Output(framework::GradVarName("LnScale")); - auto *d_ln_bias = ctx.Output(framework::GradVarName("LnBias")); - auto *d_qkv_weight = ctx.Output(framework::GradVarName("QKVW")); - auto *d_qkv_bias = ctx.Output(framework::GradVarName("QKVBias")); - auto *d_out_linear_weight = - ctx.Output(framework::GradVarName("OutLinearW")); - auto *d_out_linear_bias = - ctx.Output(framework::GradVarName("OutLinearBias")); -#if 1 - auto *d_ln_2_scale = ctx.Output(framework::GradVarName("Ln2Scale")); - auto *d_ln_2_bias = ctx.Output(framework::GradVarName("Ln2Bias")); -#endif - auto *d_ln_scale_data = - (d_ln_scale == nullptr ? nullptr - : d_ln_scale->mutable_data(ctx.GetPlace())); - auto *d_ln_bias_data = - (d_ln_bias == nullptr ? nullptr - : d_ln_bias->mutable_data(ctx.GetPlace())); - auto *d_qkv_weight_data = d_qkv_weight->mutable_data(ctx.GetPlace()); - auto *d_qkv_bias_data = d_qkv_bias->mutable_data(ctx.GetPlace()); - auto *d_out_linear_weight_data = - d_out_linear_weight->mutable_data(ctx.GetPlace()); - auto *d_out_linear_bias_data = - d_out_linear_bias->mutable_data(ctx.GetPlace()); -#if 1 - auto *d_ln_2_scale_data = - (d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data( - ctx.GetPlace())); - auto *d_ln_2_bias_data = - (d_ln_2_bias == nullptr ? nullptr - : d_ln_2_bias->mutable_data(ctx.GetPlace())); -#endif - - // get data ptr for qkv part. - const auto input_x_dims = input_x->dims(); - const auto qkv_w_dims = qkv_weight->dims(); - - int batch_size = input_x_dims[0]; - int max_seq_len = input_x_dims[1]; - int dim_embed = input_x_dims[2]; - int num_head = qkv_w_dims[1]; - int dim_head = qkv_w_dims[2]; - - int bsz_seq = batch_size * max_seq_len; - int hidden_size = num_head * dim_head; - int output_size = 3 * hidden_size; - int input_size = dim_embed; - - Tensor d_residual; - d_residual.Resize(input_x_dims); - T *d_residual_data = d_residual.mutable_data(ctx.GetPlace()); - - bool transA = false; - bool transB = true; - bool compute_bias = true; - auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), - epsilon, bsz_seq, dim_embed); - auto qkv_compute = - AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, - output_size, input_size, compute_bias); - // fmha - AttnDropoutParam attn_dropout_param( - is_test_1, dropout_implementation_1, attn_dropout_prob, - is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); - auto fmha_ref_compute = - FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head, - dim_head, attn_dropout_param); - // out_linear - output_size = hidden_size; - transA = false; - transB = false; - compute_bias = false; - auto out_linear_compute = - AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, - output_size, input_size, compute_bias); -#if 1 - // bias + dropout + residual + layernorm. - DropoutParam dropout_param2(ctx, 0); - FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, - ln2epsilon); -#endif -#if 1 - // dout -> dlayernorm_dsrc, dscale, layernorm_dbias - // dlayernorm_dsrc -> dsrc, dbias, dresidual - fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( - ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data, - dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data, - d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data, - d_out_linear_out_data, d_out_linear_bias_data, d_residual_data); -#endif -#if 1 - out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data, - d_out_linear_out_data, d_fmha_out_data, - d_out_linear_weight_data, nullptr); -#endif -#if 1 - fmha_ref_compute.ComputeBackward( - *transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out, - *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out, - d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out, - d_transpose_out_2, nullptr, d_qkv_bias_out); - // d_qkv_bias_out->d_qkv_out - // batch_size, seq_len, 3, num_head, head_size - cudaMemcpyAsync(d_qkv_out_data, d_qkv_bias_out_data, - bsz_seq * 3 * num_head * dim_head * sizeof(T), - cudaMemcpyDeviceToDevice); -#endif -#if 1 - // get qkv - if (pre_layer_norm) { - qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data, - d_qkv_bias_out_data, d_ln_out_data, - d_qkv_weight_data, d_qkv_bias_data); - layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data, - ln_mean_data, ln_var_data, d_x_data, - d_ln_scale_data, d_ln_bias_data); - } else { - qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data, - d_x_data, d_qkv_weight_data, d_qkv_bias_data); - } - // gradient accumulation: d_x[] + d_residual[] = d_x[] - std::vector ins; - std::vector outs; - ins.emplace_back(&d_residual); - ins.emplace_back(d_x); - outs.emplace_back(d_x); - int elewise_add_axis = -1; - LaunchElementwiseCudaKernel( - ctx.cuda_device_context(), ins, &outs, elewise_add_axis, - AddFunctor()); -#endif } }; @@ -518,7 +225,3 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel, ops::FusedAttentionOpKernel, ops::FusedAttentionOpKernel); -REGISTER_OP_CUDA_KERNEL(fused_attention_grad, - ops::FusedAttentionGradKernel, - ops::FusedAttentionGradKernel, - ops::FusedAttentionGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 7510ea2a5fe263..6295114fb8ac96 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -105,7 +105,6 @@ def config(self): self.attn_dropout_prob = 0.0 self.weight_attr = None self.bias_attr = None - self.kdim, self.vdim = self.embed_dim, self.embed_dim self.key_length, self.value_length = self.query_length, self.query_length @@ -122,7 +121,6 @@ def generate_input_data(self): self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 else: raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") - self.key, self.value = self.query, self.query self.dout = np.random.random((self.batch_size, self.query_length, @@ -131,7 +129,6 @@ def generate_input_data(self): def GetBaselineOut(self): tensor_query = paddle.to_tensor(self.query, stop_gradient=False) attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) - #residual = paddle.to_tensor(self.query) residual = tensor_query for i in range(1): @@ -186,10 +183,7 @@ def GetBaselineOut(self): final_out = self.norm1(residual_out) if self.pre_layer_norm: final_out = self.norm2(residual_out) - - paddle.autograd.backward( - [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) - return final_out, tensor_query.grad + return final_out def GetFusedAttentionOut(self): q_proj_weight = paddle.to_tensor( @@ -231,33 +225,21 @@ def GetFusedAttentionOut(self): ln2_epsilon = 1e-05 if attn_mask is not None: - # Support bool or int mask attn_mask = _convert_attention_mask(attn_mask, x.dtype) - for i in range(1): - final_out = F.fused_multihead_attention( - x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, - ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, - qkv_bias_tensor, out_linear_bias, attn_mask, self.dropout_prob, - self.attn_dropout_prob, ln2_epsilon) - paddle.autograd.backward( - [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) - - return final_out, x.grad + final_out = F.fused_multihead_attention( + x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, + ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, + out_linear_bias, attn_mask, self.dropout_prob, + self.attn_dropout_prob, ln2_epsilon) + return final_out def test_fused_attention_op(self): - print( - "self.batch_size, self.query_length, self.embed_dim, self.num_heads, self.head_dim = " - ) - print(self.batch_size, self.query_length, self.embed_dim, - self.num_heads, self.head_dim) - final_out_ref, x_grad_ref = self.GetBaselineOut() - final_out, x_grad = self.GetFusedAttentionOut() + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedAttentionOut() np.testing.assert_allclose( final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) - np.testing.assert_allclose( - x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -277,26 +259,17 @@ def config(self): self.dropout_prob = 0.0 self.attn_dropout_prob = 0.0 - self.weight_attr = None self.bias_attr = None - self.kdim, self.vdim = self.embed_dim, self.embed_dim self.key_length, self.value_length = self.query_length, self.query_length def test_fused_attention_op(self): - print( - "self.batch_size, self.query_length, self.embed_dim, self.num_heads, self.head_dim = " - ) - print(self.batch_size, self.query_length, self.embed_dim, - self.num_heads, self.head_dim) - final_out_ref, x_grad_ref = self.GetBaselineOut() - final_out, x_grad = self.GetFusedAttentionOut() + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedAttentionOut() np.testing.assert_allclose( final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) - np.testing.assert_allclose( - x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1) if __name__ == "__main__": From 2c0ab6cc9d2255eee2f6e5759de71223c15b7b4f Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 22 Sep 2021 08:39:03 +0000 Subject: [PATCH 05/29] Remove useless code. --- paddle/fluid/operators/fused/CMakeLists.txt | 6 +++--- .../tests/unittests/test_fused_attention_op.py | 18 ++++-------------- python/paddle/nn/functional/common.py | 11 ----------- 3 files changed, 7 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index d16033e9e16153..b993645031054a 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -60,9 +60,6 @@ if (WITH_GPU OR WITH_ROCM) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(skip_layernorm);\n") op_library(fused_embedding_eltwise_layernorm_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_embedding_eltwise_layernorm);\n") - # fused_attention_op - op_library(fused_attention_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n") # fusion_group if(NOT APPLE AND NOT WIN32) op_library(fusion_group_op DEPS device_code) @@ -81,5 +78,8 @@ if (WITH_GPU OR WITH_ROCM) nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) + # fused_attention_op + op_library(fused_attention_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n") endif() endif() diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 6295114fb8ac96..cb23ad1b1292ac 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -23,7 +23,6 @@ from paddle.fluid.data_feeder import convert_dtype from paddle import tensor from paddle.fluid import layers - import unittest place = paddle.CUDAPlace(0) @@ -136,7 +135,6 @@ def GetBaselineOut(self): if self.pre_layer_norm: ln1_out = self.norm1(tensor_query) - # get q, k, v q = self.q_proj(ln1_out) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) @@ -147,12 +145,10 @@ def GetBaselineOut(self): v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) - # q_out * k^t qk_out = layers.matmul( x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5) if attn_mask is not None: - # Support bool or int mask attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) attn_mask_out = qk_out + attn_mask softmax_out = F.softmax(attn_mask_out) @@ -169,13 +165,10 @@ def GetBaselineOut(self): else: qktv_out = tensor.matmul(softmax_out, v_out) - # combine heads fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) - out_linear_in = tensor.reshape( x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) - # project to output out = self.out_proj(out_linear_in) residual_out = residual + self.dropout(out) @@ -208,13 +201,13 @@ def GetFusedAttentionOut(self): q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) - qkv_weight = np.concatenate((q_proj_weight, k_proj_weight)) - qkv_weight = np.concatenate((qkv_weight, v_proj_weight)) + qkv_weight = np.concatenate( + (q_proj_weight, k_proj_weight, v_proj_weight)) qkv_weight = qkv_weight.reshape( (3, self.num_heads, self.head_dim, self.embed_dim)) - qkv_bias = np.concatenate((q_proj_bias.numpy(), k_proj_bias.numpy())) - qkv_bias = np.concatenate((qkv_bias, v_proj_bias.numpy())) + qkv_bias = np.concatenate( + (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) x = paddle.to_tensor(self.query, stop_gradient=False) @@ -226,7 +219,6 @@ def GetFusedAttentionOut(self): if attn_mask is not None: attn_mask = _convert_attention_mask(attn_mask, x.dtype) - final_out = F.fused_multihead_attention( x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, @@ -237,7 +229,6 @@ def GetFusedAttentionOut(self): def test_fused_attention_op(self): final_out_ref = self.GetBaselineOut() final_out = self.GetFusedAttentionOut() - np.testing.assert_allclose( final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) @@ -267,7 +258,6 @@ def config(self): def test_fused_attention_op(self): final_out_ref = self.GetBaselineOut() final_out = self.GetFusedAttentionOut() - np.testing.assert_allclose( final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 717a9361d7736d..ab75c12e21ff00 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1564,16 +1564,11 @@ def fused_multihead_attention(x, import paddle """ if in_dygraph_mode(): - ## before integrate kaihuo's code - #ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, src_mask_out, fmha_out, out_linear_out, final_out = _C_ops.fused_attention(x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, out_linear_weight, out_linear_bias, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, 'dropout_prob', dropout) - - ## finally code ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, out_linear_weight, out_linear_bias, ln_2_scale, ln_2_bias, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) - #return ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, src_mask_out, fmha_out, out_linear_out, bias_dropout_residual_out, final_out return final_out else: helper = LayerHelper('fused_multihead_attention', **locals()) @@ -1583,7 +1578,6 @@ def fused_multihead_attention(x, 'fused_multihead_attention') check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'fused_multihead_attention') - # set inputs inputs = dict() inputs['X'] = [x] @@ -1600,7 +1594,6 @@ def fused_multihead_attention(x, inputs['Ln2Scale'] = [ln_2_scale] if ln_2_bias: inputs['Ln2Bias'] = [ln_2_bias] - # set attrs attrs = { 'pre_layer_norm': pre_layer_norm, @@ -1609,17 +1602,14 @@ def fused_multihead_attention(x, 'dropout_prob': dropout, 'attn_dropout_prob': attn_dropout } - # set outputs ln_mean_out = helper.create_variable_for_type_inference( dtype=dtype, stop_gradient=True) ln_variance_out = helper.create_variable_for_type_inference( dtype=dtype, stop_gradient=True) ln_out = helper.create_variable_for_type_inference(dtype=dtype) - qkv_out = helper.create_variable_for_type_inference(dtype=dtype) qkv_bias_out = helper.create_variable_for_type_inference(dtype=dtype) - transpose_out_2 = helper.create_variable_for_type_inference(dtype=dtype) qk_out = helper.create_variable_for_type_inference(dtype=dtype) qktv_out = helper.create_variable_for_type_inference(dtype=dtype) @@ -1628,7 +1618,6 @@ def fused_multihead_attention(x, dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) attn_dropout_out = helper.create_variable_for_type_inference( dtype=dtype) - # todo: stop_gradient? src_mask_out = helper.create_variable_for_type_inference(dtype=dtype) fmha_out = helper.create_variable_for_type_inference(dtype=dtype) out_linear_out = helper.create_variable_for_type_inference(dtype=dtype) From ece3c08b43c76640cb261181caf37d33d55fe3de Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 22 Sep 2021 09:02:56 +0000 Subject: [PATCH 06/29] Remove docs. --- python/paddle/nn/functional/common.py | 43 --------------------------- 1 file changed, 43 deletions(-) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index ab75c12e21ff00..8e33cc7b33608b 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1519,49 +1519,6 @@ def fused_multihead_attention(x, ln2_epsilon=1e-05, name=None): r""" - Fused multihead_attention operator. For each input :math:`X` , - the equation is: - .. math:: - if (pre_layer_norm) - ln_out = layer_norm(X) - qkv_out = get_qkv(ln_out, qkv_weight, qkv_bias) - else - qkv_out = get_qkv(X, qkv_weight, qkv_bias) - fmha_out = fmha(qkv_out, xxx) - out_linear_out = out_linear(fmha_out, out_linear_weight) - out = bias_dropout_residual_layer_norm(out_linear_out, out_linear_bias) - # where :math:`W` is the weight and :math:`b` is the bias. - # If the weight is a 2-D tensor of shape :math:`[in\_features, out\_features]` , - # input should be a multi-dimensional tensor of shape - # :math:`[batch\_size, *, in\_features]` , where :math:`*` means any number of - # additional dimensions. The linear operator multiplies input tensor with - # weight and produces an output tensor of shape :math:`[batch\_size, *, out\_features]` , - # If :math:`bias` is not None, the bias should be a 1-D tensor of shape - # :math:`[out\_features]` and will be added to the output. - Parameters: - x (Tensor): Input tensor with shape [batch\_size, seq\_len, dim_embed]. - The data type should be float16, float32 or float64. - qkv_weight (Tensor): QKV Weight tensor with shape [3 * num_head * dim_head, dim_embed]. - The data type should be float16, float32 or float64. - ## tood: check shape! - out_linear_weight (Tensor): Out_linear Weight tensor with shape [xx, dim_embed]. - The data type should be float16, float32 or float64. - qkv_bias (Tensor, optional): QKV Bias tensor with shape [3 * num_head * dim_head]. - The data type should be float16, float32 or float64. - If it is set to None, no bias will be added to the output units. - ## tood: check shape! - out_linear_bias (Tensor, optional): Out_linear Bias tensor with shape []. - The data type should be float16, float32 or float64. - If it is set to None, no bias will be added to the output units. - name (str, optional): Normally there is no need for user to set this parameter. - For detailed information, please refer to :ref:`api_guide_Name` . - Returns: - Tensor, the shape is :math:`[batch\_size, seq\_len, dim_embed]` and the - data type is the same with input :math:`x` . - Examples: - .. code-block:: python - - import paddle """ if in_dygraph_mode(): ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( From b18b4053c80d244dab15b53072d861db16c6f514 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 22 Sep 2021 13:35:58 +0000 Subject: [PATCH 07/29] Minors. --- .../unittests/test_fused_attention_op.py | 100 ++++++++++-------- 1 file changed, 54 insertions(+), 46 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index cb23ad1b1292ac..5e67fbaf784c00 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -20,6 +20,7 @@ import paddle.nn.functional as F from paddle.nn.layer.norm import LayerNorm from paddle.nn.layer.common import Linear, Dropout +import paddle.fluid as fluid from paddle.fluid.data_feeder import convert_dtype from paddle import tensor from paddle.fluid import layers @@ -126,6 +127,7 @@ def generate_input_data(self): self.embed_dim)).astype(self.x_type) def GetBaselineOut(self): + paddle.disable_static() tensor_query = paddle.to_tensor(self.query, stop_gradient=False) attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) residual = tensor_query @@ -179,52 +181,58 @@ def GetBaselineOut(self): return final_out def GetFusedAttentionOut(self): - q_proj_weight = paddle.to_tensor( - self.q_proj.weight, stop_gradient=False) - q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False) - k_proj_weight = paddle.to_tensor( - self.k_proj.weight, stop_gradient=False) - k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False) - v_proj_weight = paddle.to_tensor( - self.v_proj.weight, stop_gradient=False) - v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False) - out_linear_weight = paddle.to_tensor( - self.out_proj.weight, stop_gradient=False) - out_linear_bias = paddle.to_tensor( - self.out_proj.bias, stop_gradient=False) - - ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) - ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) - ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False) - ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False) - - q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) - k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) - v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) - qkv_weight = np.concatenate( - (q_proj_weight, k_proj_weight, v_proj_weight)) - qkv_weight = qkv_weight.reshape( - (3, self.num_heads, self.head_dim, self.embed_dim)) - - qkv_bias = np.concatenate( - (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) - qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) - - x = paddle.to_tensor(self.query, stop_gradient=False) - attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) - qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) - qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) - epsilon = 1e-05 - ln2_epsilon = 1e-05 - - if attn_mask is not None: - attn_mask = _convert_attention_mask(attn_mask, x.dtype) - final_out = F.fused_multihead_attention( - x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, - ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, - out_linear_bias, attn_mask, self.dropout_prob, - self.attn_dropout_prob, ln2_epsilon) - return final_out + paddle.disable_static() + with fluid.dygraph.guard(fluid.CUDAPlace(0)): + q_proj_weight = paddle.to_tensor( + self.q_proj.weight, stop_gradient=False) + q_proj_bias = paddle.to_tensor( + self.q_proj.bias, stop_gradient=False) + k_proj_weight = paddle.to_tensor( + self.k_proj.weight, stop_gradient=False) + k_proj_bias = paddle.to_tensor( + self.k_proj.bias, stop_gradient=False) + v_proj_weight = paddle.to_tensor( + self.v_proj.weight, stop_gradient=False) + v_proj_bias = paddle.to_tensor( + self.v_proj.bias, stop_gradient=False) + out_linear_weight = paddle.to_tensor( + self.out_proj.weight, stop_gradient=False) + out_linear_bias = paddle.to_tensor( + self.out_proj.bias, stop_gradient=False) + + ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) + ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) + ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False) + ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False) + + q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) + k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) + v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) + qkv_weight = np.concatenate( + (q_proj_weight, k_proj_weight, v_proj_weight)) + qkv_weight = qkv_weight.reshape( + (3, self.num_heads, self.head_dim, self.embed_dim)) + + qkv_bias = np.concatenate( + (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + + x = paddle.to_tensor(self.query, stop_gradient=False) + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + qkv_weight_tensor = paddle.to_tensor( + qkv_weight, stop_gradient=False) + qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) + epsilon = 1e-05 + ln2_epsilon = 1e-05 + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, x.dtype) + final_out = F.fused_multihead_attention( + x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, + ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, + qkv_bias_tensor, out_linear_bias, attn_mask, self.dropout_prob, + self.attn_dropout_prob, ln2_epsilon) + return final_out def test_fused_attention_op(self): final_out_ref = self.GetBaselineOut() From b939159ef594da7642e756e4bc33051ad0b06e15 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 23 Sep 2021 02:11:23 +0000 Subject: [PATCH 08/29] Minors. --- .../paddle/fluid/tests/unittests/test_fused_attention_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 5e67fbaf784c00..19cfcd70a08ed7 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -127,7 +127,7 @@ def generate_input_data(self): self.embed_dim)).astype(self.x_type) def GetBaselineOut(self): - paddle.disable_static() + paddle.disable_static(place=paddle.CUDAPlace(0)) tensor_query = paddle.to_tensor(self.query, stop_gradient=False) attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) residual = tensor_query @@ -181,7 +181,7 @@ def GetBaselineOut(self): return final_out def GetFusedAttentionOut(self): - paddle.disable_static() + paddle.disable_static(place=paddle.CUDAPlace(0)) with fluid.dygraph.guard(fluid.CUDAPlace(0)): q_proj_weight = paddle.to_tensor( self.q_proj.weight, stop_gradient=False) From 07fd753e6db88cc3d7703d7f61f9e9824950c46d Mon Sep 17 00:00:00 2001 From: Li Min <11663212+limin2021@users.noreply.github.com> Date: Thu, 23 Sep 2021 11:12:58 +0800 Subject: [PATCH 09/29] Update test_fused_attention_op.py --- python/paddle/fluid/tests/unittests/test_fused_attention_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 19cfcd70a08ed7..87498c61b84e4c 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -26,7 +26,7 @@ from paddle.fluid import layers import unittest -place = paddle.CUDAPlace(0) +# place = paddle.CUDAPlace(0) def _convert_attention_mask(attn_mask, dtype): From b44d8829707b91335d527265a896abb8d2bf020c Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 23 Sep 2021 08:31:13 +0000 Subject: [PATCH 10/29] Remove static construction of python api. --- .../operators/fused/fused_attention_op.cc | 4 +- .../operators/fused/fused_attention_op.cu | 4 +- .../operators/fused/fused_attention_op.h | 4 +- .../unittests/test_fused_attention_op.py | 101 ++++++++---------- python/paddle/nn/functional/common.py | 87 --------------- 5 files changed, 52 insertions(+), 148 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 19f87472eb5184..133e3c875dcba2 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -325,10 +325,10 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { }); AddComment(R"DOC( -Fused attention: +Fused attention op: if (pre_layernorm) layer_norm; -qkv+bias_add; +compute_qkv + bias_add; fmha; out_linear; bias_add + dropout + residual + layer_norm; diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index b6e385deb8dde3..8210a2a8a61685 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -204,11 +204,11 @@ class FusedAttentionOpKernel : public framework::OpKernel { attn_dropout_mask_out, attn_dropout_out, qktv_out, fmha_out); // fmha_out: [batch_size, seq_len, num_head, head_dim] - // weight: [1024, 1024], [embed_dim, embed_dim] + // weight: [embed_dim, embed_dim] // out_linear_out: [batch_size, seq_len, embed_dim] out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data, nullptr, out_linear_out_data, nullptr); - // out = layernorm(residual + dropout(src + bias)) + // output = layernorm(residual + dropout(input + bias)) fused_dropout_layernorm_helper.LayernormResidualDropoutBias( ctx.cuda_device_context(), out_linear_out_data, x_data, out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, diff --git a/paddle/fluid/operators/fused/fused_attention_op.h b/paddle/fluid/operators/fused/fused_attention_op.h index 44d532960b7a5c..032df7818c77d1 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.h +++ b/paddle/fluid/operators/fused/fused_attention_op.h @@ -12,7 +12,5 @@ limitations under the License. */ #pragma once namespace paddle { -namespace operators { -// todo: -} // namespace operators +namespace operators {} // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 87498c61b84e4c..bd6fa8b81d3266 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -20,14 +20,12 @@ import paddle.nn.functional as F from paddle.nn.layer.norm import LayerNorm from paddle.nn.layer.common import Linear, Dropout -import paddle.fluid as fluid +# import paddle.fluid as fluid from paddle.fluid.data_feeder import convert_dtype from paddle import tensor from paddle.fluid import layers import unittest -# place = paddle.CUDAPlace(0) - def _convert_attention_mask(attn_mask, dtype): """ @@ -182,57 +180,52 @@ def GetBaselineOut(self): def GetFusedAttentionOut(self): paddle.disable_static(place=paddle.CUDAPlace(0)) - with fluid.dygraph.guard(fluid.CUDAPlace(0)): - q_proj_weight = paddle.to_tensor( - self.q_proj.weight, stop_gradient=False) - q_proj_bias = paddle.to_tensor( - self.q_proj.bias, stop_gradient=False) - k_proj_weight = paddle.to_tensor( - self.k_proj.weight, stop_gradient=False) - k_proj_bias = paddle.to_tensor( - self.k_proj.bias, stop_gradient=False) - v_proj_weight = paddle.to_tensor( - self.v_proj.weight, stop_gradient=False) - v_proj_bias = paddle.to_tensor( - self.v_proj.bias, stop_gradient=False) - out_linear_weight = paddle.to_tensor( - self.out_proj.weight, stop_gradient=False) - out_linear_bias = paddle.to_tensor( - self.out_proj.bias, stop_gradient=False) - - ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) - ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) - ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False) - ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False) - - q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) - k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) - v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) - qkv_weight = np.concatenate( - (q_proj_weight, k_proj_weight, v_proj_weight)) - qkv_weight = qkv_weight.reshape( - (3, self.num_heads, self.head_dim, self.embed_dim)) - - qkv_bias = np.concatenate( - (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) - qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) - - x = paddle.to_tensor(self.query, stop_gradient=False) - attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) - qkv_weight_tensor = paddle.to_tensor( - qkv_weight, stop_gradient=False) - qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) - epsilon = 1e-05 - ln2_epsilon = 1e-05 - - if attn_mask is not None: - attn_mask = _convert_attention_mask(attn_mask, x.dtype) - final_out = F.fused_multihead_attention( - x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, - ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, - qkv_bias_tensor, out_linear_bias, attn_mask, self.dropout_prob, - self.attn_dropout_prob, ln2_epsilon) - return final_out + q_proj_weight = paddle.to_tensor( + self.q_proj.weight, stop_gradient=False) + q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False) + k_proj_weight = paddle.to_tensor( + self.k_proj.weight, stop_gradient=False) + k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False) + v_proj_weight = paddle.to_tensor( + self.v_proj.weight, stop_gradient=False) + v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False) + out_linear_weight = paddle.to_tensor( + self.out_proj.weight, stop_gradient=False) + out_linear_bias = paddle.to_tensor( + self.out_proj.bias, stop_gradient=False) + + ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) + ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) + ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False) + ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False) + + q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) + k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) + v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) + qkv_weight = np.concatenate( + (q_proj_weight, k_proj_weight, v_proj_weight)) + qkv_weight = qkv_weight.reshape( + (3, self.num_heads, self.head_dim, self.embed_dim)) + + qkv_bias = np.concatenate( + (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + + x = paddle.to_tensor(self.query, stop_gradient=False) + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) + qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) + epsilon = 1e-05 + ln2_epsilon = 1e-05 + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, x.dtype) + final_out = F.fused_multihead_attention( + x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, + ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, + out_linear_bias, attn_mask, self.dropout_prob, + self.attn_dropout_prob, ln2_epsilon) + return final_out def test_fused_attention_op(self): final_out_ref = self.GetBaselineOut() diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 8e33cc7b33608b..b39aef5fd894d2 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1527,93 +1527,6 @@ def fused_multihead_attention(x, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) return final_out - else: - helper = LayerHelper('fused_multihead_attention', **locals()) - dtype = x.dtype - # check dtypes - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], - 'fused_multihead_attention') - check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], - 'fused_multihead_attention') - # set inputs - inputs = dict() - inputs['X'] = [x] - if ln_scale: - inputs['LnScale'] = [ln_scale] - if ln_bias: - inputs['LnBias'] = [ln_bias] - inputs['QKVW'] = [qkv_weight] - inputs['QKVBias'] = [qkv_bias] - inputs['SrcMask'] = src_mask - inputs['OutLinearW'] = [out_linear_weight] - inputs['OutLinearBias'] = [out_linear_bias] - if ln_2_scale: - inputs['Ln2Scale'] = [ln_2_scale] - if ln_2_bias: - inputs['Ln2Bias'] = [ln_2_bias] - # set attrs - attrs = { - 'pre_layer_norm': pre_layer_norm, - 'epsilon': epsilon, - 'ln2_epsilon': ln2_epsilon, - 'dropout_prob': dropout, - 'attn_dropout_prob': attn_dropout - } - # set outputs - ln_mean_out = helper.create_variable_for_type_inference( - dtype=dtype, stop_gradient=True) - ln_variance_out = helper.create_variable_for_type_inference( - dtype=dtype, stop_gradient=True) - ln_out = helper.create_variable_for_type_inference(dtype=dtype) - qkv_out = helper.create_variable_for_type_inference(dtype=dtype) - qkv_bias_out = helper.create_variable_for_type_inference(dtype=dtype) - transpose_out_2 = helper.create_variable_for_type_inference(dtype=dtype) - qk_out = helper.create_variable_for_type_inference(dtype=dtype) - qktv_out = helper.create_variable_for_type_inference(dtype=dtype) - softmax_out = helper.create_variable_for_type_inference(dtype=dtype) - attn_dropout_mask_out = helper.create_variable_for_type_inference( - dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) - attn_dropout_out = helper.create_variable_for_type_inference( - dtype=dtype) - src_mask_out = helper.create_variable_for_type_inference(dtype=dtype) - fmha_out = helper.create_variable_for_type_inference(dtype=dtype) - out_linear_out = helper.create_variable_for_type_inference(dtype=dtype) - dropout_mask_out = helper.create_variable_for_type_inference( - dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) - ln_2_mean_out = helper.create_variable_for_type_inference( - dtype=dtype, stop_gradient=True) - ln_2_variance_out = helper.create_variable_for_type_inference( - dtype=dtype, stop_gradient=True) - bias_dropout_residual_out = helper.create_variable_for_type_inference( - dtype=dtype) - final_out = helper.create_variable_for_type_inference(dtype=dtype) - - helper.append_op( - type='fused_attention', - inputs=inputs, - outputs={ - "LnMean": ln_mean_out, - "LnVariance": ln_variance_out, - "LnOut": ln_out, - "QKVOut": qkv_out, - "QKVBiasOut": qkv_bias_out, - "TransposeOut2": transpose_out_2, - "QKOut": qk_out, - "QKTVOut": qktv_out, - "SoftmaxOut": softmax_out, - "AttnDropoutMaskOut": attn_dropout_mask_out, - "AttnDropoutOut": attn_dropout_out, - "SrcMaskOut": src_mask_out, - "FMHAOut": fmha_out, - "OutLinearOut": out_linear_out, - "DropoutMaskOut": dropout_mask_out, - "Ln2Mean": ln_2_mean_out, - "Ln2Variance": ln_2_variance_out, - "BiasDropoutResidualOut": bias_dropout_residual_out, - 'Y': final_out - }, - attrs=attrs) - return final_out def label_smooth(label, prior_dist=None, epsilon=0.1, name=None): From ff3df46f5e9f405d0e4fdccd4c0aef12114da9e1 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 23 Sep 2021 11:57:49 +0000 Subject: [PATCH 11/29] Modifications accordding to reviews. --- .../operators/fused/fused_attention_op.cc | 64 ++++++++++--------- .../unittests/test_fused_attention_op.py | 30 +-------- python/paddle/nn/layer/transformer.py | 2 +- 3 files changed, 35 insertions(+), 61 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 133e3c875dcba2..8aee7fb0d824cd 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -78,16 +78,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel { // y: qkv's weight: [3, num_head, dim_head, dim_embed] auto x_dim = ctx->GetInputDim("X"); auto y_dim = ctx->GetInputDim("QKVW"); - PADDLE_ENFORCE_EQ(x_dim.size(), 3, - platform::errors::InvalidArgument( - "The dimensions of QKV_input must be 3" - "(batch_size, seq_len, dim_embed)," - "but received dimensions of" - "Input is [%d]", - x_dim.size())); + PADDLE_ENFORCE_EQ(x_dim.size(), 3, platform::errors::InvalidArgument( + "The dimensions of x must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); PADDLE_ENFORCE_EQ(y_dim.size(), 4, platform::errors::InvalidArgument( - "The dimensions of QKV_weight must be 4" + "The dimensions of qkv_weight must be 4" "(3, num_head, dim_head, dim_embed)," "but received dimensions of" "Input is [%d]", @@ -96,8 +95,8 @@ class FusedAttentionOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "ShapeError: the dimension of x_dim[2] and y_dim[3]" "must be equal. But received: the shape " - "of input X = [%s], and the shape of " - "input Y = [%s]", + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", x_dim, y_dim)); ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); @@ -152,13 +151,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The input tensor."); AddInput("LnScale", "(optional) Scale is a 1-dimensional tensor of size " - "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." - "It is applied to the output.") + "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); AddInput("LnBias", "(optional) Bias is a 1-dimensional tensor of size " - "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." - "It is applied to the output.") + "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); AddInput("QKVW", "The qkv weight tensor."); AddInput("QKVBias", "The qkv bias tensor."); @@ -168,19 +165,12 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("OutLinearBias", "The out_linear bias tensor."); AddInput("Ln2Scale", "(optional) Scale is a 1-dimensional tensor of size " - "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." - "It is applied to the output.") + "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); AddInput("Ln2Bias", "(optional) Bias is a 1-dimensional tensor of size " - "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." - "It is applied to the output.") + "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); - // AddInput("Seed", - // "The seed of dropout op, it has higher priority than the attr - // " - // "fix_seed and seed") - // .AsDispensable(); AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); AddOutput("LnVariance", "Variance of the current mini batch.") .AsIntermediate(); @@ -325,14 +315,26 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { }); AddComment(R"DOC( -Fused attention op: -if (pre_layernorm) - layer_norm; -compute_qkv + bias_add; -fmha; -out_linear; -bias_add + dropout + residual + layer_norm; -)DOC"); + Add fused attention op whose logic is as follows: + // @input: [batch_size, seq_len, 3, num_head, head_dim] + // @final_out: [batch_size, seq_len, num_heads, head_dim] + if (pre_layernorm) + out = layer_norm(input); + out = compute_qkv(out) + bias; + // fmha module + { + out = transpose(out, perm=[2, 0, 3, 1, 4]); + out = q * k^t; + out = attn_mark + out; + out = softmax(out); + out = dropout(out); + out = out * v; + out = transpose(out, perm=[0, 2, 1, 3]); + + } + out = out_linear(out); + final_out = layer_norm(residual + dropout(bias + out)); + )DOC"); } }; diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index bd6fa8b81d3266..bf26e05c844e49 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -20,40 +20,12 @@ import paddle.nn.functional as F from paddle.nn.layer.norm import LayerNorm from paddle.nn.layer.common import Linear, Dropout -# import paddle.fluid as fluid -from paddle.fluid.data_feeder import convert_dtype +from paddle.nn.layer.transformer import _convert_attention_mask from paddle import tensor from paddle.fluid import layers import unittest -def _convert_attention_mask(attn_mask, dtype): - """ - Convert the attention mask to the target dtype we expect. - Parameters: - attn_mask (Tensor, optional): A tensor used in multi-head attention - to prevents attention to some unwanted positions, usually the - paddings or the subsequent positions. It is a tensor with shape - broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. - When the data type is bool, the unwanted positions have `False` - values and the others have `True` values. When the data type is - int, the unwanted positions have 0 values and the others have 1 - values. When the data type is float, the unwanted positions have - `-INF` values and the others have 0 values. It can be None when - nothing wanted or needed to be prevented attention to. Default None. - dtype (VarType): The target type of `attn_mask` we expect. - Returns: - Tensor: A Tensor with shape same as input `attn_mask`, with data type `dtype`. - """ - if attn_mask is not None and attn_mask.dtype != dtype: - attn_mask_dtype = convert_dtype(attn_mask.dtype) - if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype: - attn_mask = (paddle.cast(attn_mask, dtype) - 1.0) * 1e9 - else: - attn_mask = paddle.cast(attn_mask, dtype) - return attn_mask - - @unittest.skipIf(not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA") class TestFusedAttentionOp(unittest.TestCase): diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py index eacf5aac9daa9f..36bc83647965e5 100644 --- a/python/paddle/nn/layer/transformer.py +++ b/python/paddle/nn/layer/transformer.py @@ -26,7 +26,7 @@ from ...fluid import layers from .. import Layer, LayerList from ...framework import ParamAttr -from ...fluid.data_feeder import convert_dtype +from paddle.fluid.data_feeder import convert_dtype __all__ = [] From 8a4c2a81aa19f93e49614bb5a7b18d920cb2d963 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Sun, 26 Sep 2021 14:17:49 +0000 Subject: [PATCH 12/29] Modifications accordding to Xreki's review. --- .../operators/fused/fused_attention_op.cc | 24 +- .../operators/fused/fused_attention_op.cu | 64 ++- .../unittests/test_fused_attention_op.py | 370 ++++++++++-------- python/paddle/nn/functional/__init__.py | 1 - python/paddle/nn/functional/common.py | 27 -- python/paddle/nn/layer/fused_transformer.py | 312 ++++++++++----- 6 files changed, 469 insertions(+), 329 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 8aee7fb0d824cd..e468fe0f3f7dc2 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -116,7 +116,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel { // the same as QKOut's shape. ctx->SetOutputDim("AttnDropoutOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); - if (ctx->Attrs().Get("is_test1") == false) { + if (ctx->Attrs().Get("attn_dropout_is_test") == false) { ctx->SetOutputDim("AttnDropoutMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); } @@ -221,20 +221,20 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { platform::errors::InvalidArgument( "'attn_dropout_prob' must be between 0.0 and 1.0.")); }); - AddAttr("is_test1", + AddAttr("attn_dropout_is_test", "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") .SetDefault(false); - AddAttr("fix_seed1", + AddAttr("attn_dropout_fix_seed", "A flag indicating whether to use a fixed seed to generate " "random mask. NOTE: DO NOT set this flag to true in " "training. Setting this flag to true is only useful in " "unittest or for debug that always the same output units " "will be dropped.") .SetDefault(true); - AddAttr("seed1", "Dropout random seed.").SetDefault(0); + AddAttr("attn_dropout_seed_val", "Dropout random seed.").SetDefault(0); AddAttr( - "dropout_implementation1", + "attn_dropout_implementation", "[\"downgrade_in_infer\"|\"upscale_in_train\"]" "There are two kinds of ways to implement dropout" "(the mask below is a tensor have the same shape with input" @@ -281,19 +281,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr( "dropout_implementation", "[\"downgrade_in_infer\"|\"upscale_in_train\"]" - "There are two kinds of ways to implement dropout" - "(the mask below is a tensor have the same shape with input" - "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)" - "1. downgrade_in_infer(default), downgrade the outcome at inference " - "time" - " train: out = input * mask" - " inference: out = input * (1.0 - dropout_prob)" - "2. upscale_in_train, upscale the outcome at training time, do nothing " - "in inference" - " train: out = input * mask / ( 1.0 - dropout_prob )" - " inference: out = input" - " dropout op can be removed from the program. the program will be " - "efficient") + "The meaning is the same as \"attn_dropout_implementation\" attribute.") .SetDefault("downgrade_in_infer") .AddCustomChecker([](const std::string &type) { PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 8210a2a8a61685..8cea767f9745e1 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -9,31 +9,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef __NVCC__ -#include -#endif -#ifdef __HIPCC__ -#include -namespace cub = hipcub; -#endif +#include "paddle/fluid/operators/fused/fused_attention_op.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" +#include +#include #include "paddle/fluid/platform/cuda_device_function.h" - -#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" -#endif -#ifdef PADDLE_WITH_HIP -#include "paddle/fluid/platform/miopen_helper.h" -#endif -#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/operators/fused/fused_attention_op.h" - #include "paddle/fluid/operators/fused/attention_layer_norm.h" #include "paddle/fluid/operators/fused/attn_gemm.h" #include "paddle/fluid/operators/fused/fmha_ref.h" @@ -90,14 +77,16 @@ class FusedAttentionOpKernel : public framework::OpKernel { const float ln2epsilon = ctx.Attr("ln2epsilon"); float attn_dropout_prob = ctx.Attr("attn_dropout_prob"); - bool is_test_1 = ctx.Attr("is_test1"); - auto &dropout_implementation_1 = - ctx.Attr("dropout_implementation1"); - bool is_upscale_in_train_1 = - (dropout_implementation_1 == "upscale_in_train"); - auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; - bool is_fix_seed_1 = ctx.Attr("fix_seed1"); - int seed_val_1 = ctx.Attr("seed1"); + bool attn_dropout_is_test = ctx.Attr("attn_dropout_is_test"); + auto &attn_dropout_implementation = + ctx.Attr("attn_dropout_implementation"); + bool attn_dropout_is_upscale_in_train = + (attn_dropout_implementation == "upscale_in_train"); + auto *attn_dropout_seed = ctx.HasInput("AttnDropoutSeed") + ? ctx.Input("AttnDropoutSeed") + : nullptr; + bool attn_dropout_fix_seed = ctx.Attr("attn_dropout_fix_seed"); + int attn_dropout_seed_val = ctx.Attr("attn_dropout_seed_val"); // final output. auto *out = ctx.Output("Y"); @@ -119,7 +108,6 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *qkv_bias_out_data = qkv_bias_out->mutable_data(ctx.GetPlace()); // get data ptr for FMHA. - auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data()); auto *transpose_out_2_data = transpose_out_2->mutable_data(ctx.GetPlace()); auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace()); @@ -162,29 +150,25 @@ class FusedAttentionOpKernel : public framework::OpKernel { int output_size = 3 * hidden_size; int input_size = dim_embed; - bool transA = false; - bool transB = true; - bool compute_bias = true; auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed); - auto qkv_compute = - AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, - output_size, input_size, compute_bias); + // (transA, transB, compute_bias) = (false, true, true) + auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), false, true, + bsz_seq, output_size, input_size, true); AttnDropoutParam attn_dropout_param( - is_test_1, dropout_implementation_1, attn_dropout_prob, - is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); + attn_dropout_is_test, attn_dropout_implementation, attn_dropout_prob, + attn_dropout_is_upscale_in_train, attn_dropout_fix_seed, + attn_dropout_seed_val, attn_dropout_seed); auto fmha_ref_compute = FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head, dim_head, attn_dropout_param); output_size = hidden_size; - transA = false; - transB = false; - compute_bias = false; + // (transA, transB, compute_bias) = (false, false, false) auto out_linear_compute = - AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, - output_size, input_size, compute_bias); + AttnMatMul(ctx.cuda_device_context(), false, false, bsz_seq, + output_size, input_size, false); DropoutParam dropout_param2(ctx, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index bf26e05c844e49..980ec91f28d614 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -18,65 +18,130 @@ import paddle.nn as nn import paddle.fluid.core as core import paddle.nn.functional as F -from paddle.nn.layer.norm import LayerNorm -from paddle.nn.layer.common import Linear, Dropout +from paddle.nn.layer.fused_transformer import FusedMultiHeadAttention from paddle.nn.layer.transformer import _convert_attention_mask from paddle import tensor from paddle.fluid import layers +from paddle.static import Program, program_guard import unittest +from op_test import OpTest + + +def GetBaselineOut(pre_layer_norm, training, embed_dim, num_heads, head_dim, + query, attn_mask, ln_scale, ln_bias, ln_2_scale, ln_2_bias, + qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, + attn_dropout_prob, dropout_prob): + paddle.disable_static(place=paddle.CUDAPlace(0)) + tensor_query = paddle.to_tensor(query, stop_gradient=False) + attn_mask = paddle.to_tensor(attn_mask, stop_gradient=False) + residual = tensor_query + ln_scale = paddle.to_tensor(ln_scale) + ln_bias = paddle.to_tensor(ln_bias) + ln_2_scale = paddle.to_tensor(ln_2_scale) + ln_2_bias = paddle.to_tensor(ln_2_bias) + out_linear_weight = paddle.to_tensor(out_linear_weight) + out_linear_bias = paddle.to_tensor(out_linear_bias) + + # qkv_weight: [3, num_heads, self.head_dim, embed_dim] + q_weight = qkv_weight[0:1, ::] + k_weight = qkv_weight[1:2, ::] + v_weight = qkv_weight[2:3, ::] + q_weight = q_weight.reshape(num_heads * head_dim, embed_dim) + k_weight = k_weight.reshape(num_heads * head_dim, embed_dim) + v_weight = v_weight.reshape(num_heads * head_dim, embed_dim) + q_weight = paddle.to_tensor(q_weight.transpose((1, 0))) + k_weight = paddle.to_tensor(k_weight.transpose((1, 0))) + v_weight = paddle.to_tensor(v_weight.transpose((1, 0))) + # qkv_bias: [3, num_heads, self.head_dim] + q_bias = qkv_bias[0:1, ::] + q_bias = q_bias.reshape(num_heads * head_dim) + k_bias = qkv_bias[1:2, ::] + k_bias = k_bias.reshape(num_heads * head_dim) + v_bias = qkv_bias[2:3, ::] + v_bias = v_bias.reshape(num_heads * head_dim) + q_bias = paddle.to_tensor(q_bias) + k_bias = paddle.to_tensor(k_bias) + v_bias = paddle.to_tensor(v_bias) + + for i in range(1): + ln1_out = tensor_query + if pre_layer_norm: + ln1_out = F.layer_norm(tensor_query, embed_dim, ln_scale, ln_bias) + + q = F.linear(ln1_out, q_weight, q_bias) + q = tensor.reshape(x=q, shape=[0, 0, num_heads, head_dim]) + q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + k = F.linear(ln1_out, k_weight, k_bias) + v = F.linear(ln1_out, v_weight, v_bias) + k = tensor.reshape(x=k, shape=[0, 0, num_heads, head_dim]) + k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, num_heads, head_dim]) + v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + qk_out = layers.matmul( + x=q_out, y=k_out, transpose_y=True, alpha=head_dim**-0.5) + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) + attn_mask_out = qk_out + attn_mask + softmax_out = F.softmax(attn_mask_out) + else: + softmax_out = F.softmax(qk_out) + + if attn_dropout_prob: + dropout_out = F.dropout( + softmax_out, + attn_dropout_prob, + training=training, + mode="upscale_in_train") + qktv_out = tensor.matmul(dropout_out, v_out) + else: + qktv_out = tensor.matmul(softmax_out, v_out) + + fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) + out_linear_in = tensor.reshape( + x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) + out = F.linear(out_linear_in, out_linear_weight, out_linear_bias) + + residual_out = residual + F.dropout( + out, dropout_prob, training=training, mode="upscale_in_train") + #if not pre_layer_norm: + final_out = F.layer_norm(residual_out, embed_dim, ln_2_scale, ln_2_bias) + #if pre_layer_norm: + # final_out = F.layer_norm(residual_out, embed_dim, ln_2_scale, ln_2_bias) + return final_out @unittest.skipIf(not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA") -class TestFusedAttentionOp(unittest.TestCase): +class TestFusedAttentionOpFp32(OpTest): def setUp(self): self.config() + self.common_config() self.generate_input_data() - paddle.set_default_dtype(self.x_type) - self.q_proj = Linear( - self.embed_dim, - self.embed_dim, - self.weight_attr, - bias_attr=self.bias_attr) - self.k_proj = Linear( - self.kdim, - self.embed_dim, - self.weight_attr, - bias_attr=self.bias_attr) - self.v_proj = Linear( - self.vdim, - self.embed_dim, - self.weight_attr, - bias_attr=self.bias_attr) - self.out_proj = Linear( - self.embed_dim, - self.embed_dim, - self.weight_attr, - bias_attr=self.bias_attr) - paddle.set_default_dtype(np.float32) - self.norm1 = LayerNorm(self.embed_dim) - self.norm2 = LayerNorm(self.embed_dim) - paddle.set_default_dtype(self.x_type) - self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train") def config(self): self.x_type = np.float32 - self.attn_mask_type = np.float64 - self.pre_layer_norm = True - self.training = True + self.pre_layer_norm = True self.batch_size = 8 self.query_length = 128 self.head_dim = 64 self.num_heads = 16 - self.embed_dim = self.head_dim * self.num_heads + def common_config(self): + self.__class__.op_type = "fused_attention" + paddle.set_default_dtype(self.x_type) + self.embed_dim = self.head_dim * self.num_heads + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + self.attn_mask_type = np.float64 + self.training = True + self.need_weight = False self.dropout_prob = 0.0 self.attn_dropout_prob = 0.0 self.weight_attr = None self.bias_attr = None - self.kdim, self.vdim = self.embed_dim, self.embed_dim - self.key_length, self.value_length = self.query_length, self.query_length def generate_input_data(self): self.query = np.random.rand(self.batch_size, self.query_length, @@ -93,134 +158,75 @@ def generate_input_data(self): raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") self.key, self.value = self.query, self.query - self.dout = np.random.random((self.batch_size, self.query_length, - self.embed_dim)).astype(self.x_type) - - def GetBaselineOut(self): + def compute_result(self): paddle.disable_static(place=paddle.CUDAPlace(0)) - tensor_query = paddle.to_tensor(self.query, stop_gradient=False) - attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) - residual = tensor_query - - for i in range(1): - ln1_out = tensor_query - if self.pre_layer_norm: - ln1_out = self.norm1(tensor_query) - - q = self.q_proj(ln1_out) - q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) - q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) - k = self.k_proj(ln1_out) - v = self.v_proj(ln1_out) - k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) - k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) - v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) - v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) - - qk_out = layers.matmul( - x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5) - - if attn_mask is not None: - attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) - attn_mask_out = qk_out + attn_mask - softmax_out = F.softmax(attn_mask_out) - else: - softmax_out = F.softmax(qk_out) - - if self.dropout_prob: - dropout_out = F.dropout( - softmax_out, - self.dropout_prob, - training=self.training, - mode="upscale_in_train") - qktv_out = tensor.matmul(dropout_out, v_out) - else: - qktv_out = tensor.matmul(softmax_out, v_out) - - fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) - out_linear_in = tensor.reshape( - x=fmha_out, - shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) - out = self.out_proj(out_linear_in) - - residual_out = residual + self.dropout(out) - if not self.pre_layer_norm: - final_out = self.norm1(residual_out) - if self.pre_layer_norm: - final_out = self.norm2(residual_out) - return final_out - - def GetFusedAttentionOut(self): - paddle.disable_static(place=paddle.CUDAPlace(0)) - q_proj_weight = paddle.to_tensor( - self.q_proj.weight, stop_gradient=False) - q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False) - k_proj_weight = paddle.to_tensor( - self.k_proj.weight, stop_gradient=False) - k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False) - v_proj_weight = paddle.to_tensor( - self.v_proj.weight, stop_gradient=False) - v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False) - out_linear_weight = paddle.to_tensor( - self.out_proj.weight, stop_gradient=False) - out_linear_bias = paddle.to_tensor( - self.out_proj.bias, stop_gradient=False) - - ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) - ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) - ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False) - ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False) - - q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) - k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) - v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) - qkv_weight = np.concatenate( - (q_proj_weight, k_proj_weight, v_proj_weight)) - qkv_weight = qkv_weight.reshape( - (3, self.num_heads, self.head_dim, self.embed_dim)) - - qkv_bias = np.concatenate( - (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) - qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) - - x = paddle.to_tensor(self.query, stop_gradient=False) - attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) - qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) - qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) - epsilon = 1e-05 - ln2_epsilon = 1e-05 - - if attn_mask is not None: - attn_mask = _convert_attention_mask(attn_mask, x.dtype) - final_out = F.fused_multihead_attention( - x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, - ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, - out_linear_bias, attn_mask, self.dropout_prob, - self.attn_dropout_prob, ln2_epsilon) - return final_out + fused_attn = FusedMultiHeadAttention( + self.embed_dim, self.num_heads, self.dropout_prob, + self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm, + self.need_weight, self.weight_attr, self.bias_attr) + out = fused_attn( + paddle.to_tensor(self.query), + paddle.to_tensor(self.query), + paddle.to_tensor(self.query), paddle.to_tensor(self.attn_mask)) + ref_out = GetBaselineOut(self.pre_layer_norm, self.training, + self.embed_dim, self.num_heads, self.head_dim, + self.query, self.attn_mask, + fused_attn.ln_scale.numpy(), + fused_attn.ln_bias.numpy(), + fused_attn.ln_2_scale.numpy(), + fused_attn.ln_2_bias.numpy(), + fused_attn.qkv_weight.numpy(), + fused_attn.qkv_bias.numpy(), + fused_attn.out_linear_weight.numpy(), + fused_attn.out_linear_bias.numpy(), + self.attn_dropout_prob, self.dropout_prob) + return ref_out, out def test_fused_attention_op(self): - final_out_ref = self.GetBaselineOut() - final_out = self.GetFusedAttentionOut() - np.testing.assert_allclose( - final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) + ref_out, out = self.compute_result() + self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5)) @unittest.skipIf(not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA") -class TestFusedAttentionOpFp16(TestFusedAttentionOp): +class TestFusedAttentionOpFp16(TestFusedAttentionOpFp32): + def setUp(self): + self.config() + self.common_config() + self.generate_input_data() + def config(self): self.x_type = np.float16 - self.attn_mask_type = np.float64 self.pre_layer_norm = True - self.training = True - self.batch_size = 8 self.query_length = 128 self.head_dim = 64 self.num_heads = 16 - self.embed_dim = self.head_dim * self.num_heads + def test_fused_attention_op(self): + ref_out, out = self.compute_result() + self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-2)) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "Paddle core is not compiled with CUDA") +class TestFusedAttentionStaticAPI(OpTest): + def setUp(self): + self.config() + self.generate_input_data() + + def config(self): + self.x_type = np.float32 + self.__class__.op_type = "fused_attention" + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.training = True + self.need_weight = False + self.batch_size = 1 + self.query_length = 2 + self.head_dim = 2 + self.num_heads = 2 + self.embed_dim = self.head_dim * self.num_heads self.dropout_prob = 0.0 self.attn_dropout_prob = 0.0 self.weight_attr = None @@ -228,11 +234,69 @@ def config(self): self.kdim, self.vdim = self.embed_dim, self.embed_dim self.key_length, self.value_length = self.query_length, self.query_length - def test_fused_attention_op(self): - final_out_ref = self.GetBaselineOut() - final_out = self.GetFusedAttentionOut() - np.testing.assert_allclose( - final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) + def generate_input_data(self): + self.query = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.x_type) + self.attn_mask = np.ones( + (self.batch_size, self.num_heads, self.query_length, + self.key_length), + dtype=self.attn_mask_type) + if self.attn_mask_type == np.int64: + self.attn_mask = np.tril(self.attn_mask) + elif self.attn_mask_type == np.float64: + self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 + else: + raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") + self.key, self.value = self.query, self.query + + def run_static(self): + fused_attn = FusedMultiHeadAttention( + self.embed_dim, self.num_heads, self.dropout_prob, + self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm, + self.need_weight, self.weight_attr, self.bias_attr) + + x = paddle.static.data( + name='X', + shape=[self.batch_size, self.query_length, self.embed_dim], + dtype=self.x_type) + attn_mask = paddle.static.data( + name='SrcMask', + shape=[ + self.batch_size, self.num_heads, self.query_length, + self.key_length + ], + dtype=self.attn_mask_type) + final_out = fused_attn(x, x, x, attn_mask) + + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + out, qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( + paddle.static.default_main_program(), + feed={"X": self.query, + "SrcMask": self.attn_mask}, + fetch_list=[ + final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, + fused_attn.out_linear_weight, fused_attn.out_linear_bias, + fused_attn.ln_scale, fused_attn.ln_bias, fused_attn.ln_2_scale, + fused_attn.ln_2_bias + ]) + return out, qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(Program()): + out, qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = self.run_static( + ) + ref_out = GetBaselineOut( + self.pre_layer_norm, self.training, self.embed_dim, self.num_heads, + self.head_dim, self.query, self.attn_mask, ln_scale, ln_bias, + ln_2_scale, ln_2_bias, qkv_weight, qkv_bias, out_linear_weight, + out_linear_bias, self.attn_dropout_prob, self.dropout_prob) + + self.assertTrue( + np.allclose( + np.array(ref_out), np.array(out), rtol=1e-5, atol=1e-5)) if __name__ == "__main__": diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index d8bec647f2c54c..7965b362b9c55a 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -60,7 +60,6 @@ from .conv import conv1d # noqa: F401 from .conv import conv1d_transpose # noqa: F401 from .common import linear # noqa: F401 -from .common import fused_multihead_attention # noqa: F401 from .conv import conv2d # noqa: F401 from .conv import conv2d_transpose # noqa: F401 from .conv import conv3d # noqa: F401 diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index b39aef5fd894d2..fcfbea438d7cca 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1502,33 +1502,6 @@ def linear(x, weight, bias=None, name=None): return res -def fused_multihead_attention(x, - qkv_weight, - out_linear_weight, - pre_layer_norm=False, - ln_scale=None, - ln_bias=None, - ln_2_scale=None, - ln_2_bias=None, - epsilon=1e-05, - qkv_bias=None, - out_linear_bias=None, - src_mask=None, - dropout=0., - attn_dropout=0., - ln2_epsilon=1e-05, - name=None): - r""" - """ - if in_dygraph_mode(): - ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( - x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, - out_linear_weight, out_linear_bias, ln_2_scale, ln_2_bias, - 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, - 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) - return final_out - - def label_smooth(label, prior_dist=None, epsilon=0.1, name=None): r""" Label smoothing is a mechanism to regularize the classifier layer and is called diff --git a/python/paddle/nn/layer/fused_transformer.py b/python/paddle/nn/layer/fused_transformer.py index 0084f7ff339df3..7f949e49e27345 100644 --- a/python/paddle/nn/layer/fused_transformer.py +++ b/python/paddle/nn/layer/fused_transformer.py @@ -12,115 +12,247 @@ # See the License for the specific language governing permissions and # limitations under the License. - -class FusedMultiHeadAttention(Layer): +import paddle +from ...fluid.layers import core +from ...fluid.framework import in_dygraph_mode +from paddle import _C_ops +from paddle.fluid.layer_helper import LayerHelper +import copy +from .. import functional as F +from paddle.nn import Layer +from ...framework import ParamAttr +from ...framework import get_default_dtype, set_default_dtype +from paddle.nn.layer.transformer import _convert_attention_mask +from ...fluid.data_feeder import check_variable_and_dtype, check_dtype +from ..initializer import Constant + +import collections + + +def fused_multihead_attention(x, + qkv_weight, + out_linear_weight, + pre_layer_norm=False, + ln_scale=None, + ln_bias=None, + ln_2_scale=None, + ln_2_bias=None, + epsilon=1e-05, + qkv_bias=None, + out_linear_bias=None, + src_mask=None, + dropout=0., + attn_dropout=0., + ln2_epsilon=1e-05, + name=None): + r""" """ - Attention mapps queries and a set of key-value pairs to outputs, and - Multi-Head Attention performs multiple parallel attention to jointly attending - to information from different representation subspaces. - - Please refer to `Attention Is All You Need `_ - for more details. - - Parameters: - embed_dim (int): The expected feature size in the input and output. - num_heads (int): The number of heads in multi-head attention. - dropout (float, optional): The dropout probability used on attention - weights to drop some attention targets. 0 for no dropout. Default 0 - kdim (int, optional): The feature size in key. If None, assumed equal to - `embed_dim`. Default None. - vdim (int, optional): The feature size in value. If None, assumed equal to - `embed_dim`. Default None. - need_weights (bool, optional): Indicate whether to return the attention - weights. Default False. - weight_attr(ParamAttr, optional): To specify the weight parameter property. - Default: None, which means the default weight parameter property is used. - See usage for details in :code:`ParamAttr` . - bias_attr (ParamAttr|bool, optional): To specify the bias parameter property. - Default: None, which means the default bias parameter property is used. - If it is set to False, this layer will not have trainable bias parameter. - See usage for details in :code:`ParamAttr` . - - Examples: - - .. code-block:: python + if in_dygraph_mode(): + ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( + x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, + out_linear_weight, out_linear_bias, ln_2_scale, ln_2_bias, + 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, + 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) + return final_out + else: + helper = LayerHelper('FusedMultiHeadAttention', **locals()) + dtype = x.dtype + # check dtypes + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'FusedMultiHeadAttention') + check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], + 'FusedMultiHeadAttention') + + # set inputs + inputs = dict() + inputs['X'] = [x] + if ln_scale: + inputs['LnScale'] = [ln_scale] + if ln_bias: + inputs['LnBias'] = [ln_bias] + inputs['QKVW'] = [qkv_weight] + inputs['QKVBias'] = [qkv_bias] + inputs['SrcMask'] = src_mask + inputs['OutLinearW'] = [out_linear_weight] + inputs['OutLinearBias'] = [out_linear_bias] + if ln_2_scale: + inputs['Ln2Scale'] = [ln_2_scale] + if ln_2_bias: + inputs['Ln2Bias'] = [ln_2_bias] + + # set attrs + attrs = { + 'pre_layer_norm': pre_layer_norm, + 'epsilon': epsilon, + 'ln2_epsilon': ln2_epsilon, + 'dropout_prob': dropout, + 'attn_dropout_prob': attn_dropout + } + + # set outputs + ln_mean_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + ln_variance_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + ln_out = helper.create_variable_for_type_inference(dtype=dtype) + + qkv_out = helper.create_variable_for_type_inference(dtype=dtype) + qkv_bias_out = helper.create_variable_for_type_inference(dtype=dtype) + + transpose_out_2 = helper.create_variable_for_type_inference(dtype=dtype) + qk_out = helper.create_variable_for_type_inference(dtype=dtype) + qktv_out = helper.create_variable_for_type_inference(dtype=dtype) + softmax_out = helper.create_variable_for_type_inference(dtype=dtype) + attn_dropout_mask_out = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + attn_dropout_out = helper.create_variable_for_type_inference( + dtype=dtype) + src_mask_out = helper.create_variable_for_type_inference(dtype=dtype) + fmha_out = helper.create_variable_for_type_inference(dtype=dtype) + out_linear_out = helper.create_variable_for_type_inference(dtype=dtype) + dropout_mask_out = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + ln_2_mean_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + ln_2_variance_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + bias_dropout_residual_out = helper.create_variable_for_type_inference( + dtype=dtype) + final_out = helper.create_variable_for_type_inference(dtype=dtype) + + helper.append_op( + type='fused_attention', + inputs=inputs, + outputs={ + "LnMean": ln_mean_out, + "LnVariance": ln_variance_out, + "LnOut": ln_out, + "QKVOut": qkv_out, + "QKVBiasOut": qkv_bias_out, + "TransposeOut2": transpose_out_2, + "QKOut": qk_out, + "QKTVOut": qktv_out, + "SoftmaxOut": softmax_out, + "AttnDropoutMaskOut": attn_dropout_mask_out, + "AttnDropoutOut": attn_dropout_out, + "SrcMaskOut": src_mask_out, + "FMHAOut": fmha_out, + "OutLinearOut": out_linear_out, + "DropoutMaskOut": dropout_mask_out, + "Ln2Mean": ln_2_mean_out, + "Ln2Variance": ln_2_variance_out, + "BiasDropoutResidualOut": bias_dropout_residual_out, + 'Y': final_out + }, + attrs=attrs) + return final_out - import paddle - # encoder input: [batch_size, sequence_length, d_model] - query = paddle.rand((2, 4, 128)) - # self attention mask: [batch_size, num_heads, query_len, query_len] - attn_mask = paddle.rand((2, 2, 4, 4)) - multi_head_attn = paddle.nn.MultiHeadAttention(128, 2) - output = multi_head_attn(query, None, None, attn_mask=attn_mask) # [2, 4, 128] +class FusedMultiHeadAttention(Layer): + """ """ - Cache = collections.namedtuple("Cache", ["k", "v"]) - StaticCache = collections.namedtuple("StaticCache", ["k", "v"]) + # todo (@limin, do we need cache in FusedMultiHeadAttention layer?) + # Cache = collections.namedtuple("Cache", ["k", "v"]) + # StaticCache = collections.namedtuple("StaticCache", ["k", "v"]) def __init__(self, embed_dim, num_heads, dropout=0., + attn_dropout=0., kdim=None, vdim=None, + normalize_before=False, need_weights=False, weight_attr=None, - bias_attr=None): + bias_attr=None, + name=None): super(FusedMultiHeadAttention, self).__init__() - raise NotImplementedError() + + assert embed_dim > 0, ("Expected embed_dim to be greater than 0, " + "but recieved {}".format(embed_dim)) + assert num_heads > 0, ("Expected nhead to be greater than 0, " + "but recieved {}".format(num_heads)) + + attn_dropout = dropout if attn_dropout is None else attn_dropout + self.normalize_before = normalize_before + self._dtype = self._helper.get_default_dtype() + self._weight_attr = weight_attr + self._bias_attr = bias_attr + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + + ## linear parameters. + self.qkv_weight = self.create_parameter( + shape=[3, num_heads, self.head_dim, embed_dim], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + self.qkv_bias = self.create_parameter( + shape=[3, num_heads, self.head_dim], + attr=self._bias_attr, + dtype=self._dtype, + is_bias=True) + self.out_linear_weight = self.create_parameter( + shape=[embed_dim, embed_dim], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + self.out_linear_bias = self.create_parameter( + shape=[embed_dim], + attr=self._bias_attr, + dtype=self._dtype, + is_bias=True) + + if get_default_dtype() == 'float16': + set_default_dtype('float32') + ## layer_norm parameters. + self.ln_scale = self.create_parameter( + attr=self._weight_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0)) + self.ln_bias = self.create_parameter( + attr=self._bias_attr, shape=[embed_dim], is_bias=True) + self.ln_2_scale = self.create_parameter( + attr=self._weight_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0)) + self.ln_2_bias = self.create_parameter( + attr=self._bias_attr, shape=[embed_dim], is_bias=True) + if get_default_dtype() == 'float16': + set_default_dtype('float16') + + ## dropout parameters + self.dropout = dropout + self.attn_dropout = attn_dropout + + self.name = name def forward(self, query, key=None, value=None, attn_mask=None, cache=None): """ - Applies multi-head attention to map queries and a set of key-value pairs - to outputs. - Parameters: - query (Tensor): The queries for multi-head attention. It is a - tensor with shape `[batch_size, query_length, embed_dim]`. The - data type should be float32 or float64. - key (Tensor, optional): The keys for multi-head attention. It is - a tensor with shape `[batch_size, key_length, kdim]`. The - data type should be float32 or float64. If None, use `query` as - `key`. Default None. - value (Tensor, optional): The values for multi-head attention. It - is a tensor with shape `[batch_size, value_length, vdim]`. - The data type should be float32 or float64. If None, use `query` as - `value`. Default None. - attn_mask (Tensor, optional): A tensor used in multi-head attention - to prevents attention to some unwanted positions, usually the - paddings or the subsequent positions. It is a tensor with shape - broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. - When the data type is bool, the unwanted positions have `False` - values and the others have `True` values. When the data type is - int, the unwanted positions have 0 values and the others have 1 - values. When the data type is float, the unwanted positions have - `-INF` values and the others have 0 values. It can be None when - nothing wanted or needed to be prevented attention to. Default None. - cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional): - It is a namedtuple with `k` and `v` as fields, and stores tensors - shaped `[batch_size, num_heads, length, embed_dim]` which are results - of linear projection, reshape and transpose calculations in - MultiHeadAttention. If it is an instance of `Cache`, `k` and `v` - fields reserve intermediate results of previous positions, which - mostly used for decoder self attention. If it is an instance of - `StaticCache`, `key` and `value` args would be ignored, `k` and - `v` fields would be used as calculated results on `key` and - `value`, which mostly used for decoder-encoder cross attention. - It is only used for inference and should be None for training. - Default None. - Returns: - Tensor|tuple: It is a tensor that has the same shape and data type \ - as `query`, representing attention output. Or a tuple if \ - `need_weights` is True or `cache` is not None. If `need_weights` \ - is True, except for attention output, the tuple also includes \ - the attention weights tensor shaped `[batch_size, num_heads, query_length, key_length]`. \ - If `cache` is not None, the tuple then includes the new cache \ - having the same type as `cache`, and if it is `StaticCache`, it \ - is same as the input `cache`, if it is `Cache`, the new cache \ - reserves tensors concatanating raw tensors with intermediate \ - results of current query. """ - raise NotImplementedError() + if attn_mask is not None: + # Support bool or int mask + attn_mask = _convert_attention_mask(attn_mask, query.dtype) + out = fused_multihead_attention( + x=query, + qkv_weight=self.qkv_weight, + out_linear_weight=self.out_linear_weight, + pre_layer_norm=self.normalize_before, + ln_scale=self.ln_scale, + ln_bias=self.ln_bias, + ln_2_scale=self.ln_2_scale, + ln_2_bias=self.ln_2_bias, + epsilon=1e-05, + qkv_bias=self.qkv_bias, + out_linear_bias=self.out_linear_bias, + src_mask=attn_mask, + dropout=self.dropout, + attn_dropout=self.attn_dropout, + ln2_epsilon=1e-05) + return out class FusedFeedForward(Layer): From 739d9ca7dd039472ca35f83e041b3b6130016d24 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Mon, 27 Sep 2021 02:39:27 +0000 Subject: [PATCH 13/29] Modifications unittest/cmakefile.txt. --- .../fluid/tests/unittests/CMakeLists.txt | 4 + .../unittests/test_fused_attention_op.py | 100 ------------------ 2 files changed, 4 insertions(+), 100 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 3496021892f342..00a86a917b7030 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -88,6 +88,10 @@ foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() +if(NOT WITH_GPU) + LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) +endif() + if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op) LIST(REMOVE_ITEM TEST_OPS test_c_concat) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 980ec91f28d614..277d68dada8bb5 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -13,7 +13,6 @@ # limitations under the License. import numpy as np - import paddle import paddle.nn as nn import paddle.fluid.core as core @@ -105,15 +104,10 @@ def GetBaselineOut(pre_layer_norm, training, embed_dim, num_heads, head_dim, residual_out = residual + F.dropout( out, dropout_prob, training=training, mode="upscale_in_train") - #if not pre_layer_norm: final_out = F.layer_norm(residual_out, embed_dim, ln_2_scale, ln_2_bias) - #if pre_layer_norm: - # final_out = F.layer_norm(residual_out, embed_dim, ln_2_scale, ln_2_bias) return final_out -@unittest.skipIf(not core.is_compiled_with_cuda(), - "Paddle core is not compiled with CUDA") class TestFusedAttentionOpFp32(OpTest): def setUp(self): self.config() @@ -122,7 +116,6 @@ def setUp(self): def config(self): self.x_type = np.float32 - self.pre_layer_norm = True self.batch_size = 8 self.query_length = 128 @@ -187,8 +180,6 @@ def test_fused_attention_op(self): self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5)) -@unittest.skipIf(not core.is_compiled_with_cuda(), - "Paddle core is not compiled with CUDA") class TestFusedAttentionOpFp16(TestFusedAttentionOpFp32): def setUp(self): self.config() @@ -208,96 +199,5 @@ def test_fused_attention_op(self): self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-2)) -@unittest.skipIf(not core.is_compiled_with_cuda(), - "Paddle core is not compiled with CUDA") -class TestFusedAttentionStaticAPI(OpTest): - def setUp(self): - self.config() - self.generate_input_data() - - def config(self): - self.x_type = np.float32 - self.__class__.op_type = "fused_attention" - self.attn_mask_type = np.float64 - self.pre_layer_norm = True - self.training = True - self.need_weight = False - self.batch_size = 1 - self.query_length = 2 - self.head_dim = 2 - self.num_heads = 2 - self.embed_dim = self.head_dim * self.num_heads - self.dropout_prob = 0.0 - self.attn_dropout_prob = 0.0 - self.weight_attr = None - self.bias_attr = None - self.kdim, self.vdim = self.embed_dim, self.embed_dim - self.key_length, self.value_length = self.query_length, self.query_length - - def generate_input_data(self): - self.query = np.random.rand(self.batch_size, self.query_length, - self.embed_dim).astype(self.x_type) - self.attn_mask = np.ones( - (self.batch_size, self.num_heads, self.query_length, - self.key_length), - dtype=self.attn_mask_type) - if self.attn_mask_type == np.int64: - self.attn_mask = np.tril(self.attn_mask) - elif self.attn_mask_type == np.float64: - self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 - else: - raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") - self.key, self.value = self.query, self.query - - def run_static(self): - fused_attn = FusedMultiHeadAttention( - self.embed_dim, self.num_heads, self.dropout_prob, - self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm, - self.need_weight, self.weight_attr, self.bias_attr) - - x = paddle.static.data( - name='X', - shape=[self.batch_size, self.query_length, self.embed_dim], - dtype=self.x_type) - attn_mask = paddle.static.data( - name='SrcMask', - shape=[ - self.batch_size, self.num_heads, self.query_length, - self.key_length - ], - dtype=self.attn_mask_type) - final_out = fused_attn(x, x, x, attn_mask) - - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - exe.run(paddle.static.default_startup_program()) - out, qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( - paddle.static.default_main_program(), - feed={"X": self.query, - "SrcMask": self.attn_mask}, - fetch_list=[ - final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, - fused_attn.out_linear_weight, fused_attn.out_linear_bias, - fused_attn.ln_scale, fused_attn.ln_bias, fused_attn.ln_2_scale, - fused_attn.ln_2_bias - ]) - return out, qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias - - def test_static_api(self): - paddle.enable_static() - with paddle.static.program_guard(Program()): - out, qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = self.run_static( - ) - ref_out = GetBaselineOut( - self.pre_layer_norm, self.training, self.embed_dim, self.num_heads, - self.head_dim, self.query, self.attn_mask, ln_scale, ln_bias, - ln_2_scale, ln_2_bias, qkv_weight, qkv_bias, out_linear_weight, - out_linear_bias, self.attn_dropout_prob, self.dropout_prob) - - self.assertTrue( - np.allclose( - np.array(ref_out), np.array(out), rtol=1e-5, atol=1e-5)) - - if __name__ == "__main__": unittest.main() From 1d9e1251e963ff54f6252ad82e1b36264587bb72 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Mon, 27 Sep 2021 12:55:19 +0000 Subject: [PATCH 14/29] Fetch new fused_dropout_helper.h from #35843. --- paddle/fluid/operators/dropout_impl.cu.h | 26 +-- paddle/fluid/operators/dropout_impl_util.h | 53 ++++++ .../operators/fused/fused_attention_op.cc | 8 +- .../operators/fused/fused_dropout_helper.h | 164 ++++++++---------- python/paddle/nn/layer/fused_transformer.py | 92 ---------- 5 files changed, 132 insertions(+), 211 deletions(-) create mode 100644 paddle/fluid/operators/dropout_impl_util.h diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h index 4261a5f2534c85..695d29b294a51a 100644 --- a/paddle/fluid/operators/dropout_impl.cu.h +++ b/paddle/fluid/operators/dropout_impl.cu.h @@ -30,6 +30,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/dropout_impl_util.h" #include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/gpu_launch_config.h" @@ -196,28 +197,9 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, config.thread_per_block.x * vec_size) + 1) * vec_size; - int device_id = - BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId(); - auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); - - if ((seed) && platform::is_gpu_place(seed->place())) { - framework::Tensor seed_cpu_tensor; - TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor); - seed_data = static_cast(seed_cpu_tensor.data()[0]); - increment = offset; - } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) { - auto seed_offset = gen_cuda->IncrementOffset(offset); - seed_data = seed_offset.first; - increment = seed_offset.second; - } else { - if (seed) { - seed_data = *(seed->data()); - } else { - std::random_device rnd; - seed_data = is_fix_seed ? seed_val : rnd(); - } - increment = offset; - } + + GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset, + &seed_data, &increment); #ifdef __HIPCC__ if (vec_size == 4 && size % 4 == 0) { diff --git a/paddle/fluid/operators/dropout_impl_util.h b/paddle/fluid/operators/dropout_impl_util.h new file mode 100644 index 00000000000000..a7188efe7139c7 --- /dev/null +++ b/paddle/fluid/operators/dropout_impl_util.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/tensor_util.h" + +namespace paddle { +namespace operators { + +inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx, + const framework::Tensor* seed, + const bool is_fix_seed, const int seed_val, + const int offset, uint64_t* seed_data, + uint64_t* increment) { + int device_id = + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId(); + auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); + + if ((seed) && platform::is_gpu_place(seed->place())) { + framework::Tensor seed_cpu_tensor; + TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor); + *seed_data = static_cast(seed_cpu_tensor.data()[0]); + *increment = offset; + } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) { + auto seed_offset = gen_cuda->IncrementOffset(offset); + *seed_data = seed_offset.first; + *increment = seed_offset.second; + } else { + if (seed) { + *seed_data = *(seed->data()); + } else { + std::random_device rnd; + *seed_data = is_fix_seed ? seed_val : rnd(); + } + *increment = offset; + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index e468fe0f3f7dc2..b481dc4be70af6 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -129,7 +129,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); - if (ctx->Attrs().Get("is_test") == false) { + if (ctx->Attrs().Get("dropout_is_test") == false) { ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X")); } ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X")); @@ -266,18 +266,18 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "'dropout_prob' must be between 0.0 and 1.0.")); }); - AddAttr("is_test", + AddAttr("dropout_is_test", "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") .SetDefault(false); - AddAttr("fix_seed", + AddAttr("dropout_fix_seed", "A flag indicating whether to use a fixed seed to generate " "random mask. NOTE: DO NOT set this flag to true in " "training. Setting this flag to true is only useful in " "unittest or for debug that always the same output units " "will be dropped.") .SetDefault(true); - AddAttr("seed", "Dropout random seed.").SetDefault(0); + AddAttr("dropout_seed", "Dropout random seed.").SetDefault(0); AddAttr( "dropout_implementation", "[\"downgrade_in_infer\"|\"upscale_in_train\"]" diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h index 826a06e9bfe00b..90d0e2ae0a5a14 100644 --- a/paddle/fluid/operators/fused/fused_dropout_helper.h +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -15,18 +15,21 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/math/functors.h" -#include "paddle/fluid/operators/math/math_function.h" -#ifdef __NVCC__ +#include "paddle/fluid/operators/dropout_impl_util.h" #include "paddle/fluid/operators/fused/fused_dropout_act_bias.h" #include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h" #include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h" -#endif +#include "paddle/fluid/operators/math/functors.h" namespace paddle { namespace operators { +/** + * Dropout will be called twice in FFN. So there will be two dropout parameters. + * The DropoutParam will be used in the fused_dropout_act_bias, + * fused_residual_dropout_bias(pre_layer_norm=ture) or + * fused_layernorm_residual_dropout_bias(pre_layer_norm=false). + */ struct DropoutParam { uint64_t seed; float dropout_prob; @@ -34,65 +37,53 @@ struct DropoutParam { bool is_test; bool fix_seed; int increment; - bool has_increment; + const framework::Tensor* tensor_seed; + int seed_val; DropoutParam() { fix_seed = false; seed = 0; is_test = false; is_upscale_in_train = false; - has_increment = false; dropout_prob = 0.5; + tensor_seed = nullptr; + seed_val = 0; } /** - * dropout_index: the index of dropout, such as FFN has two dropout, - * so the dropout_index will 1 or 2. - * the dropout param will defined as param1 or param2 + * dropout_index: can be 0, 1, 2. 0 means there is only one dropout, + * 1 and 2 represent two dropout in FFN, the parameter name of dropout + * will be "dropout" + dropout_index + param name, such as dropout1_seed, + * dropout1_is_test. */ DropoutParam(const framework::ExecutionContext& context, const int dropout_index) { + std::string pre_fix = "dropout"; std::string str_index = std::to_string(dropout_index); - if (dropout_index == 0) { - str_index = ""; + if (dropout_index > 0) { + pre_fix = pre_fix + str_index + "_"; + } else { + pre_fix = pre_fix + "_"; } - dropout_prob = context.Attr("dropout_prob" + str_index); + dropout_prob = context.Attr(pre_fix + "prob"); auto& dropout_implementation = - context.Attr("dropout_implementation" + str_index); + context.Attr(pre_fix + "implementation"); is_upscale_in_train = (dropout_implementation == "upscale_in_train"); - is_test = context.Attr("is_test" + str_index); - fix_seed = context.Attr("fix_seed" + str_index); - has_increment = false; + is_test = context.Attr(pre_fix + "is_test"); + fix_seed = context.Attr(pre_fix + "fix_seed"); - std::string str_seed = "Seed" + str_index; - auto* tensor_seed = + std::string str_seed = "Dropout" + str_index + "Seed"; + tensor_seed = context.HasInput(str_seed) ? context.Input(str_seed) : nullptr; - int device_id = - BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId(); - auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); - if (tensor_seed && platform::is_gpu_place(tensor_seed->place())) { - framework::Tensor seed_cpu_tensor; - TensorCopySync(*tensor_seed, platform::CPUPlace(), &seed_cpu_tensor); - seed = static_cast(seed_cpu_tensor.data()[0]); - } else if (gen_cuda->GetIsInitPy() && !fix_seed) { - has_increment = true; - } else { - if (tensor_seed) { - seed = *(tensor_seed->data()); - } else { - std::random_device rnd; - seed = fix_seed ? context.Attr("seed" + str_index) : rnd(); - } - } + seed_val = context.Attr(pre_fix + "seed"); } + int UpdateSeedAndIncrement(const platform::CUDADeviceContext& ctx, const int offset) { - int device_id = - BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId(); - auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); - auto seed_offset = gen_cuda->IncrementOffset(offset); - seed = seed_offset.first; - increment = static_cast(seed_offset.second); + uint64_t tmp_increment; + GetSeedDataAndIncrement(ctx, tensor_seed, fix_seed, seed_val, offset, &seed, + &tmp_increment); + increment = static_cast(tmp_increment); return increment; } }; @@ -110,9 +101,7 @@ class FusedDropoutHelper { config.block_per_grid.x * real_vec_size) + 1) * real_vec_size; - if (dropout_param_.has_increment) { - increment = dropout_param_.UpdateSeedAndIncrement(ctx, increment); - } + increment = dropout_param_.UpdateSeedAndIncrement(ctx, increment); return increment; } @@ -132,18 +121,20 @@ class FusedDropoutHelper { auto increment = GetIncrement(ctx); LaunchResidualDropoutBias( rows_, cols_, increment, dropout_param_.seed, - dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train, - dropout_param_.is_test, src, residual, bias, mask, out, ctx); + dropout_param_.dropout_prob, dropout_param_.is_test, + dropout_param_.is_upscale_in_train, src, residual, bias, mask, out, + ctx); } void ResidualDropoutBiasGrad(const platform::CUDADeviceContext& ctx, - const T* dout, const MaskType* mask, T* dsrc, - T* dresidual, T* dbias) { + const T* d_out, const MaskType* mask, T* d_src, + T* d_residual, T* d_bias) { LaunchResidualDropoutBiasGrad( - dout, mask, dropout_param_.dropout_prob, - dropout_param_.is_upscale_in_train, rows_, cols_, dsrc, dbias, ctx); - cudaMemcpyAsync(dresidual, dout, rows_ * cols_ * sizeof(T), - cudaMemcpyDeviceToDevice); + d_out, mask, dropout_param_.dropout_prob, + dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx); + auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); + memory::Copy(cuda_place, d_residual, cuda_place, d_out, + rows_ * cols_ * sizeof(T), ctx.stream()); } // out = dropout(activation(src + bias)) @@ -165,26 +156,26 @@ class FusedDropoutHelper { dropout_param_.is_test, src, bias, out, mask, ctx); } else { PADDLE_THROW(platform::errors::InvalidArgument( - "the activation only support gelu or relu!")); + "Currently only supports gelu or relu activation functions!")); } } void DropoutActBiasGrad(const platform::CUDADeviceContext& ctx, const T* dout, const T* src, const T* bias, const MaskType* mask, - T* dsrc, T* dbias, const std::string& act_method) { + T* d_src, T* d_bias, const std::string& act_method) { if (act_method == "gelu") { GeluGradFunctor gelu_grad; LaunchDropoutActBiasGrad>( gelu_grad, dout, mask, src, bias, dropout_param_.dropout_prob, - dropout_param_.is_upscale_in_train, rows_, cols_, dsrc, dbias, ctx); + dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx); } else if (act_method == "relu") { math::ReluGradFunctor relu_grad; LaunchDropoutActBiasGrad>( relu_grad, dout, mask, src, bias, dropout_param_.dropout_prob, - dropout_param_.is_upscale_in_train, rows_, cols_, dsrc, dbias, ctx); + dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx); } else { PADDLE_THROW(platform::errors::InvalidArgument( - "the activation only support gelu or relu!")); + "Currently only supports gelu or relu activation functions!")); } } @@ -220,30 +211,24 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper { const LayerNormParamType* gamma, const LayerNormParamType* beta, T* out, LayerNormParamType* mean, LayerNormParamType* variance) { -#ifdef __NVCC__ using U = LayerNormParamType; switch (GetDesiredBlockDim(this->cols_)) { FIXED_BLOCK_DIM_CASE( LayerNormForward< T, U, kBlockDim><<rows_, kBlockDim, 0, ctx.stream()>>>( src, gamma, beta, out, mean, variance, epsilon_, this->cols_)); - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Product from begin_norm_axis to end must be larger than 1")); - break; } -#endif } void LayerNormGrad(const platform::CUDADeviceContext& ctx, const T* dout, const T* src, const LayerNormParamType* gamma, const LayerNormParamType* mean, - const LayerNormParamType* variance, T* dsrc, - LayerNormParamType* dscale, - LayerNormParamType* dbias) { + const LayerNormParamType* variance, T* d_src, + LayerNormParamType* d_scale, + LayerNormParamType* d_bias) { using U = LayerNormParamType; - LayerNormBackward(src, dout, gamma, mean, variance, dsrc, dscale, - dbias, epsilon_, this->rows_, this->cols_, ctx); + LayerNormBackward(src, dout, gamma, mean, variance, d_src, d_scale, + d_bias, epsilon_, this->rows_, this->cols_, ctx); } // out = layernorm(residual + dropout(src + bias)) @@ -252,42 +237,35 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper { const T* bias, const LayerNormParamType* gamma, const LayerNormParamType* beta, T* dropout_out, MaskType* mask, T* out, LayerNormParamType* mean, LayerNormParamType* variance) { -#ifdef __NVCC__ using U = LayerNormParamType; - int VecSize = MAX_CACHE_BYTES / sizeof(T); - if (this->cols_ % VecSize != 0) { - VecSize = 1; + int vec_size = MAX_CACHE_BYTES / sizeof(T); + if (this->cols_ % vec_size != 0) { + vec_size = 1; } - int threads = GetDesiredBlockDim(this->cols_ / VecSize); - - int increment = ((this->cols_ - 1) / (threads * VecSize) + 1) * VecSize; - if (this->dropout_param_.has_increment) { - increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment); - } - + int threads = GetDesiredBlockDim(this->cols_ / vec_size); + int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size; + increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment); LaunchLayernormResidualDropoutBias( this->rows_, this->cols_, increment, this->dropout_param_.seed, this->dropout_param_.dropout_prob, epsilon_, this->dropout_param_.is_upscale_in_train, this->dropout_param_.is_test, src, residual, bias, gamma, beta, mask, dropout_out, out, mean, variance, ctx); -#endif } void LayernormResidualDropoutBiasGrad( - const platform::CUDADeviceContext& ctx, const T* dout, const T* src, - const MaskType* mask, const LayerNormParamType* gamma, - const LayerNormParamType* mean, const LayerNormParamType* variance, - T* layernorm_dsrc, LayerNormParamType* dscale, - LayerNormParamType* layernorm_dbias, T* dsrc, T* dbias, T* dresidual) { -#ifdef __NVCC__ + const platform::CUDADeviceContext& ctx, const T* d_out, + const T* layernorm_src, const MaskType* mask, + const LayerNormParamType* gamma, const LayerNormParamType* mean, + const LayerNormParamType* variance, T* d_layernorm_src, + LayerNormParamType* d_scale, LayerNormParamType* d_layernorm_bias, + T* d_dropout_src, T* d_bias, T* d_residual) { using U = LayerNormParamType; - LayerNormBackward(src, dout, gamma, mean, variance, layernorm_dsrc, - dscale, layernorm_dbias, epsilon_, this->rows_, - this->cols_, ctx); - this->ResidualDropoutBiasGrad(ctx, layernorm_dsrc, mask, dsrc, dresidual, - dbias); -#endif + LayerNormBackward(layernorm_src, d_out, gamma, mean, variance, + d_layernorm_src, d_scale, d_layernorm_bias, + epsilon_, this->rows_, this->cols_, ctx); + this->ResidualDropoutBiasGrad(ctx, d_layernorm_src, mask, d_dropout_src, + d_residual, d_bias); } protected: diff --git a/python/paddle/nn/layer/fused_transformer.py b/python/paddle/nn/layer/fused_transformer.py index 7f949e49e27345..6f040e96c5ca70 100644 --- a/python/paddle/nn/layer/fused_transformer.py +++ b/python/paddle/nn/layer/fused_transformer.py @@ -54,98 +54,6 @@ def fused_multihead_attention(x, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) return final_out - else: - helper = LayerHelper('FusedMultiHeadAttention', **locals()) - dtype = x.dtype - # check dtypes - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], - 'FusedMultiHeadAttention') - check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], - 'FusedMultiHeadAttention') - - # set inputs - inputs = dict() - inputs['X'] = [x] - if ln_scale: - inputs['LnScale'] = [ln_scale] - if ln_bias: - inputs['LnBias'] = [ln_bias] - inputs['QKVW'] = [qkv_weight] - inputs['QKVBias'] = [qkv_bias] - inputs['SrcMask'] = src_mask - inputs['OutLinearW'] = [out_linear_weight] - inputs['OutLinearBias'] = [out_linear_bias] - if ln_2_scale: - inputs['Ln2Scale'] = [ln_2_scale] - if ln_2_bias: - inputs['Ln2Bias'] = [ln_2_bias] - - # set attrs - attrs = { - 'pre_layer_norm': pre_layer_norm, - 'epsilon': epsilon, - 'ln2_epsilon': ln2_epsilon, - 'dropout_prob': dropout, - 'attn_dropout_prob': attn_dropout - } - - # set outputs - ln_mean_out = helper.create_variable_for_type_inference( - dtype=dtype, stop_gradient=True) - ln_variance_out = helper.create_variable_for_type_inference( - dtype=dtype, stop_gradient=True) - ln_out = helper.create_variable_for_type_inference(dtype=dtype) - - qkv_out = helper.create_variable_for_type_inference(dtype=dtype) - qkv_bias_out = helper.create_variable_for_type_inference(dtype=dtype) - - transpose_out_2 = helper.create_variable_for_type_inference(dtype=dtype) - qk_out = helper.create_variable_for_type_inference(dtype=dtype) - qktv_out = helper.create_variable_for_type_inference(dtype=dtype) - softmax_out = helper.create_variable_for_type_inference(dtype=dtype) - attn_dropout_mask_out = helper.create_variable_for_type_inference( - dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) - attn_dropout_out = helper.create_variable_for_type_inference( - dtype=dtype) - src_mask_out = helper.create_variable_for_type_inference(dtype=dtype) - fmha_out = helper.create_variable_for_type_inference(dtype=dtype) - out_linear_out = helper.create_variable_for_type_inference(dtype=dtype) - dropout_mask_out = helper.create_variable_for_type_inference( - dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) - ln_2_mean_out = helper.create_variable_for_type_inference( - dtype=dtype, stop_gradient=True) - ln_2_variance_out = helper.create_variable_for_type_inference( - dtype=dtype, stop_gradient=True) - bias_dropout_residual_out = helper.create_variable_for_type_inference( - dtype=dtype) - final_out = helper.create_variable_for_type_inference(dtype=dtype) - - helper.append_op( - type='fused_attention', - inputs=inputs, - outputs={ - "LnMean": ln_mean_out, - "LnVariance": ln_variance_out, - "LnOut": ln_out, - "QKVOut": qkv_out, - "QKVBiasOut": qkv_bias_out, - "TransposeOut2": transpose_out_2, - "QKOut": qk_out, - "QKTVOut": qktv_out, - "SoftmaxOut": softmax_out, - "AttnDropoutMaskOut": attn_dropout_mask_out, - "AttnDropoutOut": attn_dropout_out, - "SrcMaskOut": src_mask_out, - "FMHAOut": fmha_out, - "OutLinearOut": out_linear_out, - "DropoutMaskOut": dropout_mask_out, - "Ln2Mean": ln_2_mean_out, - "Ln2Variance": ln_2_variance_out, - "BiasDropoutResidualOut": bias_dropout_residual_out, - 'Y': final_out - }, - attrs=attrs) - return final_out class FusedMultiHeadAttention(Layer): From 4dd4260f363a3abb00b6443e0a00c4bf8e399ca5 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Mon, 27 Sep 2021 12:58:57 +0000 Subject: [PATCH 15/29] Remove include fused_attention_op.h. --- .../fluid/operators/fused/fused_attention_op.cc | 1 - .../fluid/operators/fused/fused_attention_op.cu | 2 -- .../fluid/operators/fused/fused_attention_op.h | 16 ---------------- 3 files changed, 19 deletions(-) delete mode 100644 paddle/fluid/operators/fused/fused_attention_op.h diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index b481dc4be70af6..630f5491a99f92 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fused/fused_attention_op.h" #include #include #include "paddle/fluid/framework/op_registry.h" diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 8cea767f9745e1..32695d66310772 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -9,8 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fused/fused_attention_op.h" - #include #include #include "paddle/fluid/platform/cuda_device_function.h" diff --git a/paddle/fluid/operators/fused/fused_attention_op.h b/paddle/fluid/operators/fused/fused_attention_op.h deleted file mode 100644 index 032df7818c77d1..00000000000000 --- a/paddle/fluid/operators/fused/fused_attention_op.h +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -namespace paddle { -namespace operators {} // namespace operators -} // namespace paddle From 2e3f4f26fbf0bdc6284e7261b1cb61a548d05a5c Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Mon, 27 Sep 2021 14:11:01 +0000 Subject: [PATCH 16/29] Polish names of variants. --- .../operators/fused/fused_attention_op.cc | 68 +++++++------- .../operators/fused/fused_attention_op.cu | 94 +++++++++---------- paddle/fluid/pybind/op_function_generator.cc | 12 +-- .../unittests/test_fused_attention_op.py | 29 +++--- python/paddle/nn/layer/fused_transformer.py | 38 ++++---- 5 files changed, 121 insertions(+), 120 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 630f5491a99f92..f819686e743f38 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -29,23 +29,23 @@ class FusedAttentionOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", + OP_INOUT_CHECK(ctx->HasInput("LinearW"), "Input", "LinearW", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", + OP_INOUT_CHECK(ctx->HasInput("LinearBias"), "Input", "LinearBias", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", + OP_INOUT_CHECK(ctx->HasOutput("PreLnMean"), "Output", "PreLnMean", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", + OP_INOUT_CHECK(ctx->HasOutput("PreLnVariance"), "Output", "PreLnVariance", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut", + OP_INOUT_CHECK(ctx->HasOutput("PreLnOut"), "Output", "PreLnOut", "FusedAttentionOp"); // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2", + OP_INOUT_CHECK(ctx->HasOutput("TransposeOut"), "Output", "TransposeOut", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut", "FusedAttentionOp"); @@ -61,11 +61,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("FMHAOut"), "Output", "FMHAOut", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut", + OP_INOUT_CHECK(ctx->HasOutput("LinearOut"), "Output", "LinearOut", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean", + OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance", + OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output", "BiasDropoutResidualOut", "FusedAttentionOp"); @@ -98,16 +98,16 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "input qkv_weight = [%s]", x_dim, y_dim)); - ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); - ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); - ctx->SetOutputDim("LnOut", ctx->GetInputDim("X")); + ctx->SetOutputDim("PreLnMean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("PreLnVariance", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("PreLnOut", ctx->GetInputDim("X")); // [batch_size, seq_len, 3, num_head, head_size] ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); ctx->SetOutputDim("QKVBiasOut", {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); // [3, batch_size, num_head, seq_len, head_size] - ctx->SetOutputDim("TransposeOut2", + ctx->SetOutputDim("TransposeOut", {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); // [batch, num_head, seq_len, seq_len] ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); @@ -124,10 +124,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); // [batch_size, seq_len, number of heads*head size] ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]}); - ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X")); + ctx->SetOutputDim("LinearOut", ctx->GetInputDim("X")); - ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]}); - ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); if (ctx->Attrs().Get("dropout_is_test") == false) { ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X")); } @@ -148,11 +148,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "The input tensor."); - AddInput("LnScale", + AddInput("PreLnScale", "(optional) Scale is a 1-dimensional tensor of size " "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); - AddInput("LnBias", + AddInput("PreLnBias", "(optional) Bias is a 1-dimensional tensor of size " "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); @@ -160,23 +160,23 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("QKVBias", "The qkv bias tensor."); AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") .AsDispensable(); - AddInput("OutLinearW", "The out_linear weight tensor."); - AddInput("OutLinearBias", "The out_linear bias tensor."); - AddInput("Ln2Scale", + AddInput("LinearW", "The linear weight tensor."); + AddInput("LinearBias", "The linear bias tensor."); + AddInput("LnScale", "(optional) Scale is a 1-dimensional tensor of size " "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); - AddInput("Ln2Bias", + AddInput("LnBias", "(optional) Bias is a 1-dimensional tensor of size " "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); - AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); - AddOutput("LnVariance", "Variance of the current mini batch.") + AddOutput("PreLnMean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("PreLnVariance", "Variance of the current mini batch.") .AsIntermediate(); - AddOutput("LnOut", "The output of pre layer_norm.").AsIntermediate(); + AddOutput("PreLnOut", "The output of pre layer_norm.").AsIntermediate(); AddOutput("QKVOut", "Result after qkv.").AsIntermediate(); AddOutput("QKVBiasOut", "Result after qkv and bias op.").AsIntermediate(); - AddOutput("TransposeOut2", "Result in fmha.").AsIntermediate(); + AddOutput("TransposeOut", "Result in fmha.").AsIntermediate(); AddOutput("QKOut", "Result in fmha.").AsIntermediate(); AddOutput("QKTVOut", "Result in fmha.").AsIntermediate(); AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate(); @@ -184,11 +184,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("AttnDropoutOut", "Result in fmha.").AsIntermediate(); AddOutput("SrcMaskOut", "Result in fmha.").AsIntermediate(); AddOutput("FMHAOut", "Result after fmha.").AsIntermediate(); - AddOutput("OutLinearOut", "Result after out_linear.").AsIntermediate(); + AddOutput("LinearOut", "Result after linear.").AsIntermediate(); AddOutput("DropoutMaskOut", "The random sampled dropout mask.") .AsIntermediate(); - AddOutput("Ln2Mean", "Mean of the current mini batch.").AsIntermediate(); - AddOutput("Ln2Variance", "Variance of the current mini batch.") + AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("LnVariance", "Variance of the current mini batch.") .AsIntermediate(); AddOutput("BiasDropoutResidualOut", "Result of residual + dropout(src + bias).") @@ -289,16 +289,16 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "dropout_implementation can only be downgrade_in_infer or " "upscale_in_train")); }); - AddAttr("ln2epsilon", + AddAttr("ln_epsilon", "Constant for numerical stability [default 1e-5].") .SetDefault(1e-5) - .AddCustomChecker([](const float &ln2epsilon) { - PADDLE_ENFORCE_EQ(ln2epsilon >= 0.0f && ln2epsilon <= 0.001f, true, + .AddCustomChecker([](const float &ln_epsilon) { + PADDLE_ENFORCE_EQ(ln_epsilon >= 0.0f && ln_epsilon <= 0.001f, true, platform::errors::InvalidArgument( "'epsilon' of the second LayerNorm in Fused " "attention op should be between" "0.0 and 0.001, But received [%s].", - ln2epsilon)); + ln_epsilon)); }); AddComment(R"DOC( @@ -319,7 +319,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { out = transpose(out, perm=[0, 2, 1, 3]); } - out = out_linear(out); + out = linear(out); final_out = layer_norm(residual + dropout(bias + out)); )DOC"); } diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 32695d66310772..3c85809b75c69d 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -38,11 +38,11 @@ class FusedAttentionOpKernel : public framework::OpKernel { const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); const float epsilon = ctx.Attr("epsilon"); - auto *ln_scale = ctx.Input("LnScale"); - auto *ln_bias = ctx.Input("LnBias"); - auto *ln_mean = ctx.Output("LnMean"); - auto *ln_var = ctx.Output("LnVariance"); - auto *ln_out = ctx.Output("LnOut"); + auto *pre_ln_scale = ctx.Input("PreLnScale"); + auto *pre_ln_bias = ctx.Input("PreLnBias"); + auto *pre_ln_mean = ctx.Output("PreLnMean"); + auto *pre_ln_var = ctx.Output("PreLnVariance"); + auto *pre_ln_out = ctx.Output("PreLnOut"); // x: qkv's input [batch_size, seq_len, dim_embed] // y: qkv's weight: [3, num_head, dim_head, dim_embed] @@ -52,7 +52,7 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *qkv_bias_out = ctx.Output("QKVBiasOut"); auto *src_mask = ctx.Input("SrcMask"); - auto *transpose_out_2 = ctx.Output("TransposeOut2"); + auto *transpose_out = ctx.Output("TransposeOut"); auto *qk_out = ctx.Output("QKOut"); auto *qktv_out = ctx.Output("QKTVOut"); auto *softmax_out = ctx.Output("SoftmaxOut"); @@ -61,18 +61,18 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *src_mask_out = ctx.Output("SrcMaskOut"); auto *fmha_out = ctx.Output("FMHAOut"); - auto *out_linear_weight = ctx.Input("OutLinearW"); - auto *out_linear_bias = ctx.Input("OutLinearBias"); - auto *out_linear_out = ctx.Output("OutLinearOut"); + auto *linear_weight = ctx.Input("LinearW"); + auto *linear_bias = ctx.Input("LinearBias"); + auto *linear_out = ctx.Output("LinearOut"); - auto *ln_scale_2 = ctx.Input("Ln2Scale"); - auto *ln_bias_2 = ctx.Input("Ln2Bias"); + auto *ln_scale = ctx.Input("LnScale"); + auto *ln_bias = ctx.Input("LnBias"); auto *dropout_mask_out = ctx.Output("DropoutMaskOut"); auto *bias_dropout_residual_out = ctx.Output("BiasDropoutResidualOut"); - auto *ln_mean_2 = ctx.Output("Ln2Mean"); - auto *ln_var_2 = ctx.Output("Ln2Variance"); - const float ln2epsilon = ctx.Attr("ln2epsilon"); + auto *ln_mean = ctx.Output("LnMean"); + auto *ln_var = ctx.Output("LnVariance"); + const float ln_epsilon = ctx.Attr("ln_epsilon"); float attn_dropout_prob = ctx.Attr("attn_dropout_prob"); bool attn_dropout_is_test = ctx.Attr("attn_dropout_is_test"); @@ -94,11 +94,13 @@ class FusedAttentionOpKernel : public framework::OpKernel { const auto qkv_w_dims = qkv_weight->dims(); auto *x_data = input_x->data(); - auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); - auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); - auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace()); - auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace()); - auto *ln_out_data = ln_out->mutable_data(ctx.GetPlace()); + auto *pre_ln_scale_data = + (pre_ln_scale == nullptr ? nullptr : pre_ln_scale->data()); + auto *pre_ln_bias_data = + (pre_ln_bias == nullptr ? nullptr : pre_ln_bias->data()); + auto *pre_ln_mean_data = pre_ln_mean->mutable_data(ctx.GetPlace()); + auto *pre_ln_var_data = pre_ln_var->mutable_data(ctx.GetPlace()); + auto *pre_ln_out_data = pre_ln_out->mutable_data(ctx.GetPlace()); auto *qkv_weight_data = qkv_weight->data(); auto *qkv_bias_data = qkv_bias->data(); @@ -106,8 +108,7 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *qkv_bias_out_data = qkv_bias_out->mutable_data(ctx.GetPlace()); // get data ptr for FMHA. - auto *transpose_out_2_data = - transpose_out_2->mutable_data(ctx.GetPlace()); + auto *transpose_out_data = transpose_out->mutable_data(ctx.GetPlace()); auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace()); auto *qktv_out_data = qktv_out->mutable_data(ctx.GetPlace()); auto *src_mask_out_data = src_mask_out->mutable_data(ctx.GetPlace()); @@ -118,22 +119,20 @@ class FusedAttentionOpKernel : public framework::OpKernel { attn_dropout_out->mutable_data(ctx.GetPlace()); auto *fmha_out_data = fmha_out->mutable_data(ctx.GetPlace()); - // get data ptr for out_linear. - auto *out_linear_weight_data = out_linear_weight->data(); - auto *out_linear_bias_data = out_linear_bias->data(); - auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace()); + // get data ptr for linear. + auto *linear_weight_data = linear_weight->data(); + auto *linear_bias_data = linear_bias->data(); + auto *linear_out_data = linear_out->mutable_data(ctx.GetPlace()); // get data ptr for bias+dropout+residual+layernorm - auto *ln_scale_2_data = - (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data()); - auto *ln_bias_2_data = - (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data()); + auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); + auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); auto *dropout_mask_out_data = dropout_mask_out->mutable_data(ctx.GetPlace()); auto *bias_dropout_residual_out_data = bias_dropout_residual_out->mutable_data(ctx.GetPlace()); - auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace()); - auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace()); + auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace()); + auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace()); auto *final_out_data = out->mutable_data(ctx.GetPlace()); int batch_size = input_x_dims[0]; @@ -164,38 +163,39 @@ class FusedAttentionOpKernel : public framework::OpKernel { output_size = hidden_size; // (transA, transB, compute_bias) = (false, false, false) - auto out_linear_compute = + auto linear_compute = AttnMatMul(ctx.cuda_device_context(), false, false, bsz_seq, output_size, input_size, false); - DropoutParam dropout_param2(ctx, 0); + DropoutParam dropout_param(ctx, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, - ln2epsilon); + ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param, + ln_epsilon); if (pre_layer_norm) { - layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, - ln_out_data, ln_mean_data, ln_var_data); - qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data, - qkv_out_data, qkv_bias_out_data); + layer_norm_compute.ComputeForward(x_data, pre_ln_scale_data, + pre_ln_bias_data, pre_ln_out_data, + pre_ln_mean_data, pre_ln_var_data); + qkv_compute.ComputeForward(qkv_weight_data, pre_ln_out_data, + qkv_bias_data, qkv_out_data, + qkv_bias_out_data); } else { qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data, qkv_out_data, qkv_bias_out_data); } - fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2, + fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out, qk_out, src_mask_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, qktv_out, fmha_out); // fmha_out: [batch_size, seq_len, num_head, head_dim] // weight: [embed_dim, embed_dim] - // out_linear_out: [batch_size, seq_len, embed_dim] - out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data, - nullptr, out_linear_out_data, nullptr); + // linear_out: [batch_size, seq_len, embed_dim] + linear_compute.ComputeForward(linear_weight_data, fmha_out_data, nullptr, + linear_out_data, nullptr); // output = layernorm(residual + dropout(input + bias)) fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - ctx.cuda_device_context(), out_linear_out_data, x_data, - out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, - bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, - ln_mean_2_data, ln_var_2_data); + ctx.cuda_device_context(), linear_out_data, x_data, linear_bias_data, + ln_scale_data, ln_bias_data, bias_dropout_residual_out_data, + dropout_mask_out_data, final_out_data, ln_mean_data, ln_var_data); } }; diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 5a3d99239750bd..282bc3b8e4cd59 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -41,8 +41,8 @@ std::map> op_ins_map = { {"layer_norm", {"X", "Scale", "Bias"}}, {"fused_attention", - {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW", - "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, + {"X", "PreLnScale", "PreLnBias", "QKVW", "QKVBias", "SrcMask", "LinearW", + "LinearBias", "LnScale", "LnBias"}}, {"instance_norm", {"X", "Scale", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"label_smooth", {"X", "PriorDist"}}, @@ -91,10 +91,10 @@ std::map> op_outs_map = { {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, {"fused_attention", - {"LnMean", "LnVariance", "LnOut", "QKVOut", "QKVBiasOut", "TransposeOut2", - "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", "AttnDropoutOut", - "SrcMaskOut", "FMHAOut", "OutLinearOut", "DropoutMaskOut", "Ln2Mean", - "Ln2Variance", "BiasDropoutResidualOut", "Y"}}, + {"PreLnMean", "PreLnVariance", "PreLnOut", "QKVOut", "QKVBiasOut", + "TransposeOut", "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", + "AttnDropoutOut", "SrcMaskOut", "FMHAOut", "LinearOut", "DropoutMaskOut", + "LnMean", "LnVariance", "BiasDropoutResidualOut", "Y"}}, {"sync_batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 277d68dada8bb5..39733d3e3d7796 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -27,19 +27,19 @@ def GetBaselineOut(pre_layer_norm, training, embed_dim, num_heads, head_dim, - query, attn_mask, ln_scale, ln_bias, ln_2_scale, ln_2_bias, - qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, + query, attn_mask, pre_ln_scale, pre_ln_bias, ln_scale, + ln_bias, qkv_weight, qkv_bias, linear_weight, linear_bias, attn_dropout_prob, dropout_prob): paddle.disable_static(place=paddle.CUDAPlace(0)) tensor_query = paddle.to_tensor(query, stop_gradient=False) attn_mask = paddle.to_tensor(attn_mask, stop_gradient=False) residual = tensor_query + pre_ln_scale = paddle.to_tensor(pre_ln_scale) + pre_ln_bias = paddle.to_tensor(pre_ln_bias) ln_scale = paddle.to_tensor(ln_scale) ln_bias = paddle.to_tensor(ln_bias) - ln_2_scale = paddle.to_tensor(ln_2_scale) - ln_2_bias = paddle.to_tensor(ln_2_bias) - out_linear_weight = paddle.to_tensor(out_linear_weight) - out_linear_bias = paddle.to_tensor(out_linear_bias) + linear_weight = paddle.to_tensor(linear_weight) + linear_bias = paddle.to_tensor(linear_bias) # qkv_weight: [3, num_heads, self.head_dim, embed_dim] q_weight = qkv_weight[0:1, ::] @@ -65,7 +65,8 @@ def GetBaselineOut(pre_layer_norm, training, embed_dim, num_heads, head_dim, for i in range(1): ln1_out = tensor_query if pre_layer_norm: - ln1_out = F.layer_norm(tensor_query, embed_dim, ln_scale, ln_bias) + ln1_out = F.layer_norm(tensor_query, embed_dim, pre_ln_scale, + pre_ln_bias) q = F.linear(ln1_out, q_weight, q_bias) q = tensor.reshape(x=q, shape=[0, 0, num_heads, head_dim]) @@ -98,13 +99,13 @@ def GetBaselineOut(pre_layer_norm, training, embed_dim, num_heads, head_dim, qktv_out = tensor.matmul(softmax_out, v_out) fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) - out_linear_in = tensor.reshape( + linear_in = tensor.reshape( x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) - out = F.linear(out_linear_in, out_linear_weight, out_linear_bias) + out = F.linear(linear_in, linear_weight, linear_bias) residual_out = residual + F.dropout( out, dropout_prob, training=training, mode="upscale_in_train") - final_out = F.layer_norm(residual_out, embed_dim, ln_2_scale, ln_2_bias) + final_out = F.layer_norm(residual_out, embed_dim, ln_scale, ln_bias) return final_out @@ -164,14 +165,14 @@ def compute_result(self): ref_out = GetBaselineOut(self.pre_layer_norm, self.training, self.embed_dim, self.num_heads, self.head_dim, self.query, self.attn_mask, + fused_attn.pre_ln_scale.numpy(), + fused_attn.pre_ln_bias.numpy(), fused_attn.ln_scale.numpy(), fused_attn.ln_bias.numpy(), - fused_attn.ln_2_scale.numpy(), - fused_attn.ln_2_bias.numpy(), fused_attn.qkv_weight.numpy(), fused_attn.qkv_bias.numpy(), - fused_attn.out_linear_weight.numpy(), - fused_attn.out_linear_bias.numpy(), + fused_attn.linear_weight.numpy(), + fused_attn.linear_bias.numpy(), self.attn_dropout_prob, self.dropout_prob) return ref_out, out diff --git a/python/paddle/nn/layer/fused_transformer.py b/python/paddle/nn/layer/fused_transformer.py index 6f040e96c5ca70..a2d3246588efe7 100644 --- a/python/paddle/nn/layer/fused_transformer.py +++ b/python/paddle/nn/layer/fused_transformer.py @@ -31,15 +31,15 @@ def fused_multihead_attention(x, qkv_weight, - out_linear_weight, + linear_weight, pre_layer_norm=False, + pre_ln_scale=None, + pre_ln_bias=None, ln_scale=None, ln_bias=None, - ln_2_scale=None, - ln_2_bias=None, epsilon=1e-05, qkv_bias=None, - out_linear_bias=None, + linear_bias=None, src_mask=None, dropout=0., attn_dropout=0., @@ -48,11 +48,11 @@ def fused_multihead_attention(x, r""" """ if in_dygraph_mode(): - ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( - x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, - out_linear_weight, out_linear_bias, ln_2_scale, ln_2_bias, - 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, - 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) + ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( + x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, src_mask, + linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', + pre_layer_norm, 'epsilon', epsilon, 'dropout_prob', dropout, + 'attn_dropout_prob', attn_dropout) return final_out @@ -103,12 +103,12 @@ def __init__(self, attr=self._bias_attr, dtype=self._dtype, is_bias=True) - self.out_linear_weight = self.create_parameter( + self.linear_weight = self.create_parameter( shape=[embed_dim, embed_dim], attr=self._weight_attr, dtype=self._dtype, is_bias=False) - self.out_linear_bias = self.create_parameter( + self.linear_bias = self.create_parameter( shape=[embed_dim], attr=self._bias_attr, dtype=self._dtype, @@ -117,17 +117,17 @@ def __init__(self, if get_default_dtype() == 'float16': set_default_dtype('float32') ## layer_norm parameters. - self.ln_scale = self.create_parameter( + self.pre_ln_scale = self.create_parameter( attr=self._weight_attr, shape=[embed_dim], default_initializer=Constant(value=1.0)) - self.ln_bias = self.create_parameter( + self.pre_ln_bias = self.create_parameter( attr=self._bias_attr, shape=[embed_dim], is_bias=True) - self.ln_2_scale = self.create_parameter( + self.ln_scale = self.create_parameter( attr=self._weight_attr, shape=[embed_dim], default_initializer=Constant(value=1.0)) - self.ln_2_bias = self.create_parameter( + self.ln_bias = self.create_parameter( attr=self._bias_attr, shape=[embed_dim], is_bias=True) if get_default_dtype() == 'float16': set_default_dtype('float16') @@ -147,15 +147,15 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None): out = fused_multihead_attention( x=query, qkv_weight=self.qkv_weight, - out_linear_weight=self.out_linear_weight, + linear_weight=self.linear_weight, pre_layer_norm=self.normalize_before, + pre_ln_scale=self.pre_ln_scale, + pre_ln_bias=self.pre_ln_bias, ln_scale=self.ln_scale, ln_bias=self.ln_bias, - ln_2_scale=self.ln_2_scale, - ln_2_bias=self.ln_2_bias, epsilon=1e-05, qkv_bias=self.qkv_bias, - out_linear_bias=self.out_linear_bias, + linear_bias=self.linear_bias, src_mask=attn_mask, dropout=self.dropout, attn_dropout=self.attn_dropout, From 13d4ff3903f2c3c0567e23d403162787d88553d9 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 13 Oct 2021 09:31:17 +0000 Subject: [PATCH 17/29] Revert "Polish names of variants." This reverts commit 2e3f4f26fbf0bdc6284e7261b1cb61a548d05a5c. --- .../operators/fused/fused_attention_op.cc | 68 +++++++------- .../operators/fused/fused_attention_op.cu | 94 +++++++++---------- paddle/fluid/pybind/op_function_generator.cc | 12 +-- .../unittests/test_fused_attention_op.py | 29 +++--- python/paddle/nn/layer/fused_transformer.py | 38 ++++---- 5 files changed, 120 insertions(+), 121 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index f819686e743f38..630f5491a99f92 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -29,23 +29,23 @@ class FusedAttentionOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasInput("LinearW"), "Input", "LinearW", + OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasInput("LinearBias"), "Input", "LinearBias", + OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("PreLnMean"), "Output", "PreLnMean", + OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("PreLnVariance"), "Output", "PreLnVariance", + OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("PreLnOut"), "Output", "PreLnOut", + OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut", "FusedAttentionOp"); // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("TransposeOut"), "Output", "TransposeOut", + OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut", "FusedAttentionOp"); @@ -61,11 +61,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("FMHAOut"), "Output", "FMHAOut", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("LinearOut"), "Output", "LinearOut", + OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", + OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", + OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output", "BiasDropoutResidualOut", "FusedAttentionOp"); @@ -98,16 +98,16 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "input qkv_weight = [%s]", x_dim, y_dim)); - ctx->SetOutputDim("PreLnMean", {x_dim[0] * x_dim[1]}); - ctx->SetOutputDim("PreLnVariance", {x_dim[0] * x_dim[1]}); - ctx->SetOutputDim("PreLnOut", ctx->GetInputDim("X")); + ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnOut", ctx->GetInputDim("X")); // [batch_size, seq_len, 3, num_head, head_size] ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); ctx->SetOutputDim("QKVBiasOut", {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); // [3, batch_size, num_head, seq_len, head_size] - ctx->SetOutputDim("TransposeOut", + ctx->SetOutputDim("TransposeOut2", {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); // [batch, num_head, seq_len, seq_len] ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); @@ -124,10 +124,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); // [batch_size, seq_len, number of heads*head size] ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]}); - ctx->SetOutputDim("LinearOut", ctx->GetInputDim("X")); + ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X")); - ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); - ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); if (ctx->Attrs().Get("dropout_is_test") == false) { ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X")); } @@ -148,11 +148,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "The input tensor."); - AddInput("PreLnScale", + AddInput("LnScale", "(optional) Scale is a 1-dimensional tensor of size " "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); - AddInput("PreLnBias", + AddInput("LnBias", "(optional) Bias is a 1-dimensional tensor of size " "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); @@ -160,23 +160,23 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("QKVBias", "The qkv bias tensor."); AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") .AsDispensable(); - AddInput("LinearW", "The linear weight tensor."); - AddInput("LinearBias", "The linear bias tensor."); - AddInput("LnScale", + AddInput("OutLinearW", "The out_linear weight tensor."); + AddInput("OutLinearBias", "The out_linear bias tensor."); + AddInput("Ln2Scale", "(optional) Scale is a 1-dimensional tensor of size " "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); - AddInput("LnBias", + AddInput("Ln2Bias", "(optional) Bias is a 1-dimensional tensor of size " "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); - AddOutput("PreLnMean", "Mean of the current mini batch.").AsIntermediate(); - AddOutput("PreLnVariance", "Variance of the current mini batch.") + AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("LnVariance", "Variance of the current mini batch.") .AsIntermediate(); - AddOutput("PreLnOut", "The output of pre layer_norm.").AsIntermediate(); + AddOutput("LnOut", "The output of pre layer_norm.").AsIntermediate(); AddOutput("QKVOut", "Result after qkv.").AsIntermediate(); AddOutput("QKVBiasOut", "Result after qkv and bias op.").AsIntermediate(); - AddOutput("TransposeOut", "Result in fmha.").AsIntermediate(); + AddOutput("TransposeOut2", "Result in fmha.").AsIntermediate(); AddOutput("QKOut", "Result in fmha.").AsIntermediate(); AddOutput("QKTVOut", "Result in fmha.").AsIntermediate(); AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate(); @@ -184,11 +184,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("AttnDropoutOut", "Result in fmha.").AsIntermediate(); AddOutput("SrcMaskOut", "Result in fmha.").AsIntermediate(); AddOutput("FMHAOut", "Result after fmha.").AsIntermediate(); - AddOutput("LinearOut", "Result after linear.").AsIntermediate(); + AddOutput("OutLinearOut", "Result after out_linear.").AsIntermediate(); AddOutput("DropoutMaskOut", "The random sampled dropout mask.") .AsIntermediate(); - AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); - AddOutput("LnVariance", "Variance of the current mini batch.") + AddOutput("Ln2Mean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("Ln2Variance", "Variance of the current mini batch.") .AsIntermediate(); AddOutput("BiasDropoutResidualOut", "Result of residual + dropout(src + bias).") @@ -289,16 +289,16 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "dropout_implementation can only be downgrade_in_infer or " "upscale_in_train")); }); - AddAttr("ln_epsilon", + AddAttr("ln2epsilon", "Constant for numerical stability [default 1e-5].") .SetDefault(1e-5) - .AddCustomChecker([](const float &ln_epsilon) { - PADDLE_ENFORCE_EQ(ln_epsilon >= 0.0f && ln_epsilon <= 0.001f, true, + .AddCustomChecker([](const float &ln2epsilon) { + PADDLE_ENFORCE_EQ(ln2epsilon >= 0.0f && ln2epsilon <= 0.001f, true, platform::errors::InvalidArgument( "'epsilon' of the second LayerNorm in Fused " "attention op should be between" "0.0 and 0.001, But received [%s].", - ln_epsilon)); + ln2epsilon)); }); AddComment(R"DOC( @@ -319,7 +319,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { out = transpose(out, perm=[0, 2, 1, 3]); } - out = linear(out); + out = out_linear(out); final_out = layer_norm(residual + dropout(bias + out)); )DOC"); } diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 3c85809b75c69d..32695d66310772 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -38,11 +38,11 @@ class FusedAttentionOpKernel : public framework::OpKernel { const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); const float epsilon = ctx.Attr("epsilon"); - auto *pre_ln_scale = ctx.Input("PreLnScale"); - auto *pre_ln_bias = ctx.Input("PreLnBias"); - auto *pre_ln_mean = ctx.Output("PreLnMean"); - auto *pre_ln_var = ctx.Output("PreLnVariance"); - auto *pre_ln_out = ctx.Output("PreLnOut"); + auto *ln_scale = ctx.Input("LnScale"); + auto *ln_bias = ctx.Input("LnBias"); + auto *ln_mean = ctx.Output("LnMean"); + auto *ln_var = ctx.Output("LnVariance"); + auto *ln_out = ctx.Output("LnOut"); // x: qkv's input [batch_size, seq_len, dim_embed] // y: qkv's weight: [3, num_head, dim_head, dim_embed] @@ -52,7 +52,7 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *qkv_bias_out = ctx.Output("QKVBiasOut"); auto *src_mask = ctx.Input("SrcMask"); - auto *transpose_out = ctx.Output("TransposeOut"); + auto *transpose_out_2 = ctx.Output("TransposeOut2"); auto *qk_out = ctx.Output("QKOut"); auto *qktv_out = ctx.Output("QKTVOut"); auto *softmax_out = ctx.Output("SoftmaxOut"); @@ -61,18 +61,18 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *src_mask_out = ctx.Output("SrcMaskOut"); auto *fmha_out = ctx.Output("FMHAOut"); - auto *linear_weight = ctx.Input("LinearW"); - auto *linear_bias = ctx.Input("LinearBias"); - auto *linear_out = ctx.Output("LinearOut"); + auto *out_linear_weight = ctx.Input("OutLinearW"); + auto *out_linear_bias = ctx.Input("OutLinearBias"); + auto *out_linear_out = ctx.Output("OutLinearOut"); - auto *ln_scale = ctx.Input("LnScale"); - auto *ln_bias = ctx.Input("LnBias"); + auto *ln_scale_2 = ctx.Input("Ln2Scale"); + auto *ln_bias_2 = ctx.Input("Ln2Bias"); auto *dropout_mask_out = ctx.Output("DropoutMaskOut"); auto *bias_dropout_residual_out = ctx.Output("BiasDropoutResidualOut"); - auto *ln_mean = ctx.Output("LnMean"); - auto *ln_var = ctx.Output("LnVariance"); - const float ln_epsilon = ctx.Attr("ln_epsilon"); + auto *ln_mean_2 = ctx.Output("Ln2Mean"); + auto *ln_var_2 = ctx.Output("Ln2Variance"); + const float ln2epsilon = ctx.Attr("ln2epsilon"); float attn_dropout_prob = ctx.Attr("attn_dropout_prob"); bool attn_dropout_is_test = ctx.Attr("attn_dropout_is_test"); @@ -94,13 +94,11 @@ class FusedAttentionOpKernel : public framework::OpKernel { const auto qkv_w_dims = qkv_weight->dims(); auto *x_data = input_x->data(); - auto *pre_ln_scale_data = - (pre_ln_scale == nullptr ? nullptr : pre_ln_scale->data()); - auto *pre_ln_bias_data = - (pre_ln_bias == nullptr ? nullptr : pre_ln_bias->data()); - auto *pre_ln_mean_data = pre_ln_mean->mutable_data(ctx.GetPlace()); - auto *pre_ln_var_data = pre_ln_var->mutable_data(ctx.GetPlace()); - auto *pre_ln_out_data = pre_ln_out->mutable_data(ctx.GetPlace()); + auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); + auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); + auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace()); + auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace()); + auto *ln_out_data = ln_out->mutable_data(ctx.GetPlace()); auto *qkv_weight_data = qkv_weight->data(); auto *qkv_bias_data = qkv_bias->data(); @@ -108,7 +106,8 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *qkv_bias_out_data = qkv_bias_out->mutable_data(ctx.GetPlace()); // get data ptr for FMHA. - auto *transpose_out_data = transpose_out->mutable_data(ctx.GetPlace()); + auto *transpose_out_2_data = + transpose_out_2->mutable_data(ctx.GetPlace()); auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace()); auto *qktv_out_data = qktv_out->mutable_data(ctx.GetPlace()); auto *src_mask_out_data = src_mask_out->mutable_data(ctx.GetPlace()); @@ -119,20 +118,22 @@ class FusedAttentionOpKernel : public framework::OpKernel { attn_dropout_out->mutable_data(ctx.GetPlace()); auto *fmha_out_data = fmha_out->mutable_data(ctx.GetPlace()); - // get data ptr for linear. - auto *linear_weight_data = linear_weight->data(); - auto *linear_bias_data = linear_bias->data(); - auto *linear_out_data = linear_out->mutable_data(ctx.GetPlace()); + // get data ptr for out_linear. + auto *out_linear_weight_data = out_linear_weight->data(); + auto *out_linear_bias_data = out_linear_bias->data(); + auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace()); // get data ptr for bias+dropout+residual+layernorm - auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); - auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); + auto *ln_scale_2_data = + (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data()); + auto *ln_bias_2_data = + (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data()); auto *dropout_mask_out_data = dropout_mask_out->mutable_data(ctx.GetPlace()); auto *bias_dropout_residual_out_data = bias_dropout_residual_out->mutable_data(ctx.GetPlace()); - auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace()); - auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace()); + auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace()); + auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace()); auto *final_out_data = out->mutable_data(ctx.GetPlace()); int batch_size = input_x_dims[0]; @@ -163,39 +164,38 @@ class FusedAttentionOpKernel : public framework::OpKernel { output_size = hidden_size; // (transA, transB, compute_bias) = (false, false, false) - auto linear_compute = + auto out_linear_compute = AttnMatMul(ctx.cuda_device_context(), false, false, bsz_seq, output_size, input_size, false); - DropoutParam dropout_param(ctx, 0); + DropoutParam dropout_param2(ctx, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param, - ln_epsilon); + ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, + ln2epsilon); if (pre_layer_norm) { - layer_norm_compute.ComputeForward(x_data, pre_ln_scale_data, - pre_ln_bias_data, pre_ln_out_data, - pre_ln_mean_data, pre_ln_var_data); - qkv_compute.ComputeForward(qkv_weight_data, pre_ln_out_data, - qkv_bias_data, qkv_out_data, - qkv_bias_out_data); + layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, + ln_out_data, ln_mean_data, ln_var_data); + qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data, + qkv_out_data, qkv_bias_out_data); } else { qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data, qkv_out_data, qkv_bias_out_data); } - fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out, + fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2, qk_out, src_mask_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, qktv_out, fmha_out); // fmha_out: [batch_size, seq_len, num_head, head_dim] // weight: [embed_dim, embed_dim] - // linear_out: [batch_size, seq_len, embed_dim] - linear_compute.ComputeForward(linear_weight_data, fmha_out_data, nullptr, - linear_out_data, nullptr); + // out_linear_out: [batch_size, seq_len, embed_dim] + out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data, + nullptr, out_linear_out_data, nullptr); // output = layernorm(residual + dropout(input + bias)) fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - ctx.cuda_device_context(), linear_out_data, x_data, linear_bias_data, - ln_scale_data, ln_bias_data, bias_dropout_residual_out_data, - dropout_mask_out_data, final_out_data, ln_mean_data, ln_var_data); + ctx.cuda_device_context(), out_linear_out_data, x_data, + out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, + bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, + ln_mean_2_data, ln_var_2_data); } }; diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index b42ed2eb552f74..53c7e165d84333 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -41,8 +41,8 @@ std::map> op_ins_map = { {"layer_norm", {"X", "Scale", "Bias"}}, {"fused_attention", - {"X", "PreLnScale", "PreLnBias", "QKVW", "QKVBias", "SrcMask", "LinearW", - "LinearBias", "LnScale", "LnBias"}}, + {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW", + "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, {"instance_norm", {"X", "Scale", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"label_smooth", {"X", "PriorDist"}}, @@ -95,10 +95,10 @@ std::map> op_outs_map = { {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, {"fused_attention", - {"PreLnMean", "PreLnVariance", "PreLnOut", "QKVOut", "QKVBiasOut", - "TransposeOut", "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", - "AttnDropoutOut", "SrcMaskOut", "FMHAOut", "LinearOut", "DropoutMaskOut", - "LnMean", "LnVariance", "BiasDropoutResidualOut", "Y"}}, + {"LnMean", "LnVariance", "LnOut", "QKVOut", "QKVBiasOut", "TransposeOut2", + "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", "AttnDropoutOut", + "SrcMaskOut", "FMHAOut", "OutLinearOut", "DropoutMaskOut", "Ln2Mean", + "Ln2Variance", "BiasDropoutResidualOut", "Y"}}, {"sync_batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 39733d3e3d7796..277d68dada8bb5 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -27,19 +27,19 @@ def GetBaselineOut(pre_layer_norm, training, embed_dim, num_heads, head_dim, - query, attn_mask, pre_ln_scale, pre_ln_bias, ln_scale, - ln_bias, qkv_weight, qkv_bias, linear_weight, linear_bias, + query, attn_mask, ln_scale, ln_bias, ln_2_scale, ln_2_bias, + qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, attn_dropout_prob, dropout_prob): paddle.disable_static(place=paddle.CUDAPlace(0)) tensor_query = paddle.to_tensor(query, stop_gradient=False) attn_mask = paddle.to_tensor(attn_mask, stop_gradient=False) residual = tensor_query - pre_ln_scale = paddle.to_tensor(pre_ln_scale) - pre_ln_bias = paddle.to_tensor(pre_ln_bias) ln_scale = paddle.to_tensor(ln_scale) ln_bias = paddle.to_tensor(ln_bias) - linear_weight = paddle.to_tensor(linear_weight) - linear_bias = paddle.to_tensor(linear_bias) + ln_2_scale = paddle.to_tensor(ln_2_scale) + ln_2_bias = paddle.to_tensor(ln_2_bias) + out_linear_weight = paddle.to_tensor(out_linear_weight) + out_linear_bias = paddle.to_tensor(out_linear_bias) # qkv_weight: [3, num_heads, self.head_dim, embed_dim] q_weight = qkv_weight[0:1, ::] @@ -65,8 +65,7 @@ def GetBaselineOut(pre_layer_norm, training, embed_dim, num_heads, head_dim, for i in range(1): ln1_out = tensor_query if pre_layer_norm: - ln1_out = F.layer_norm(tensor_query, embed_dim, pre_ln_scale, - pre_ln_bias) + ln1_out = F.layer_norm(tensor_query, embed_dim, ln_scale, ln_bias) q = F.linear(ln1_out, q_weight, q_bias) q = tensor.reshape(x=q, shape=[0, 0, num_heads, head_dim]) @@ -99,13 +98,13 @@ def GetBaselineOut(pre_layer_norm, training, embed_dim, num_heads, head_dim, qktv_out = tensor.matmul(softmax_out, v_out) fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) - linear_in = tensor.reshape( + out_linear_in = tensor.reshape( x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) - out = F.linear(linear_in, linear_weight, linear_bias) + out = F.linear(out_linear_in, out_linear_weight, out_linear_bias) residual_out = residual + F.dropout( out, dropout_prob, training=training, mode="upscale_in_train") - final_out = F.layer_norm(residual_out, embed_dim, ln_scale, ln_bias) + final_out = F.layer_norm(residual_out, embed_dim, ln_2_scale, ln_2_bias) return final_out @@ -165,14 +164,14 @@ def compute_result(self): ref_out = GetBaselineOut(self.pre_layer_norm, self.training, self.embed_dim, self.num_heads, self.head_dim, self.query, self.attn_mask, - fused_attn.pre_ln_scale.numpy(), - fused_attn.pre_ln_bias.numpy(), fused_attn.ln_scale.numpy(), fused_attn.ln_bias.numpy(), + fused_attn.ln_2_scale.numpy(), + fused_attn.ln_2_bias.numpy(), fused_attn.qkv_weight.numpy(), fused_attn.qkv_bias.numpy(), - fused_attn.linear_weight.numpy(), - fused_attn.linear_bias.numpy(), + fused_attn.out_linear_weight.numpy(), + fused_attn.out_linear_bias.numpy(), self.attn_dropout_prob, self.dropout_prob) return ref_out, out diff --git a/python/paddle/nn/layer/fused_transformer.py b/python/paddle/nn/layer/fused_transformer.py index a2d3246588efe7..6f040e96c5ca70 100644 --- a/python/paddle/nn/layer/fused_transformer.py +++ b/python/paddle/nn/layer/fused_transformer.py @@ -31,15 +31,15 @@ def fused_multihead_attention(x, qkv_weight, - linear_weight, + out_linear_weight, pre_layer_norm=False, - pre_ln_scale=None, - pre_ln_bias=None, ln_scale=None, ln_bias=None, + ln_2_scale=None, + ln_2_bias=None, epsilon=1e-05, qkv_bias=None, - linear_bias=None, + out_linear_bias=None, src_mask=None, dropout=0., attn_dropout=0., @@ -48,11 +48,11 @@ def fused_multihead_attention(x, r""" """ if in_dygraph_mode(): - ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( - x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, src_mask, - linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', - pre_layer_norm, 'epsilon', epsilon, 'dropout_prob', dropout, - 'attn_dropout_prob', attn_dropout) + ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( + x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, + out_linear_weight, out_linear_bias, ln_2_scale, ln_2_bias, + 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, + 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) return final_out @@ -103,12 +103,12 @@ def __init__(self, attr=self._bias_attr, dtype=self._dtype, is_bias=True) - self.linear_weight = self.create_parameter( + self.out_linear_weight = self.create_parameter( shape=[embed_dim, embed_dim], attr=self._weight_attr, dtype=self._dtype, is_bias=False) - self.linear_bias = self.create_parameter( + self.out_linear_bias = self.create_parameter( shape=[embed_dim], attr=self._bias_attr, dtype=self._dtype, @@ -117,17 +117,17 @@ def __init__(self, if get_default_dtype() == 'float16': set_default_dtype('float32') ## layer_norm parameters. - self.pre_ln_scale = self.create_parameter( + self.ln_scale = self.create_parameter( attr=self._weight_attr, shape=[embed_dim], default_initializer=Constant(value=1.0)) - self.pre_ln_bias = self.create_parameter( + self.ln_bias = self.create_parameter( attr=self._bias_attr, shape=[embed_dim], is_bias=True) - self.ln_scale = self.create_parameter( + self.ln_2_scale = self.create_parameter( attr=self._weight_attr, shape=[embed_dim], default_initializer=Constant(value=1.0)) - self.ln_bias = self.create_parameter( + self.ln_2_bias = self.create_parameter( attr=self._bias_attr, shape=[embed_dim], is_bias=True) if get_default_dtype() == 'float16': set_default_dtype('float16') @@ -147,15 +147,15 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None): out = fused_multihead_attention( x=query, qkv_weight=self.qkv_weight, - linear_weight=self.linear_weight, + out_linear_weight=self.out_linear_weight, pre_layer_norm=self.normalize_before, - pre_ln_scale=self.pre_ln_scale, - pre_ln_bias=self.pre_ln_bias, ln_scale=self.ln_scale, ln_bias=self.ln_bias, + ln_2_scale=self.ln_2_scale, + ln_2_bias=self.ln_2_bias, epsilon=1e-05, qkv_bias=self.qkv_bias, - linear_bias=self.linear_bias, + out_linear_bias=self.out_linear_bias, src_mask=attn_mask, dropout=self.dropout, attn_dropout=self.attn_dropout, From 300ec354ea8e61919ed425e0414fa85e6b27d2ee Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 13 Oct 2021 11:31:33 +0000 Subject: [PATCH 18/29] Revert "Modifications accordding to Xreki's review." This reverts commit 8a4c2a81aa19f93e49614bb5a7b18d920cb2d963. --- .../operators/fused/fused_attention_op.cc | 24 +- .../operators/fused/fused_attention_op.cu | 60 ++-- .../unittests/test_fused_attention_op.py | 292 ++++++++++-------- python/paddle/nn/functional/__init__.py | 1 + python/paddle/nn/functional/common.py | 27 ++ python/paddle/nn/layer/fused_transformer.py | 220 ++++++------- .../static_mode_white_list.cpython-37.pyc | Bin 0 -> 21041 bytes 7 files changed, 338 insertions(+), 286 deletions(-) create mode 100644 tools/__pycache__/static_mode_white_list.cpython-37.pyc diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 630f5491a99f92..8e5263091e48e7 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -115,7 +115,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel { // the same as QKOut's shape. ctx->SetOutputDim("AttnDropoutOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); - if (ctx->Attrs().Get("attn_dropout_is_test") == false) { + if (ctx->Attrs().Get("is_test1") == false) { ctx->SetOutputDim("AttnDropoutMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); } @@ -220,20 +220,20 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { platform::errors::InvalidArgument( "'attn_dropout_prob' must be between 0.0 and 1.0.")); }); - AddAttr("attn_dropout_is_test", + AddAttr("is_test1", "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") .SetDefault(false); - AddAttr("attn_dropout_fix_seed", + AddAttr("fix_seed1", "A flag indicating whether to use a fixed seed to generate " "random mask. NOTE: DO NOT set this flag to true in " "training. Setting this flag to true is only useful in " "unittest or for debug that always the same output units " "will be dropped.") .SetDefault(true); - AddAttr("attn_dropout_seed_val", "Dropout random seed.").SetDefault(0); + AddAttr("seed1", "Dropout random seed.").SetDefault(0); AddAttr( - "attn_dropout_implementation", + "dropout_implementation1", "[\"downgrade_in_infer\"|\"upscale_in_train\"]" "There are two kinds of ways to implement dropout" "(the mask below is a tensor have the same shape with input" @@ -280,7 +280,19 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr( "dropout_implementation", "[\"downgrade_in_infer\"|\"upscale_in_train\"]" - "The meaning is the same as \"attn_dropout_implementation\" attribute.") + "There are two kinds of ways to implement dropout" + "(the mask below is a tensor have the same shape with input" + "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)" + "1. downgrade_in_infer(default), downgrade the outcome at inference " + "time" + " train: out = input * mask" + " inference: out = input * (1.0 - dropout_prob)" + "2. upscale_in_train, upscale the outcome at training time, do nothing " + "in inference" + " train: out = input * mask / ( 1.0 - dropout_prob )" + " inference: out = input" + " dropout op can be removed from the program. the program will be " + "efficient") .SetDefault("downgrade_in_infer") .AddCustomChecker([](const std::string &type) { PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 32695d66310772..e99fc1c7b94af4 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -9,13 +9,26 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include +#ifdef __NVCC__ #include -#include "paddle/fluid/platform/cuda_device_function.h" -#include "paddle/fluid/platform/cudnn_helper.h" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/cuda_device_function.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#endif + +#include #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/math/math_function.h" @@ -75,16 +88,14 @@ class FusedAttentionOpKernel : public framework::OpKernel { const float ln2epsilon = ctx.Attr("ln2epsilon"); float attn_dropout_prob = ctx.Attr("attn_dropout_prob"); - bool attn_dropout_is_test = ctx.Attr("attn_dropout_is_test"); - auto &attn_dropout_implementation = - ctx.Attr("attn_dropout_implementation"); - bool attn_dropout_is_upscale_in_train = - (attn_dropout_implementation == "upscale_in_train"); - auto *attn_dropout_seed = ctx.HasInput("AttnDropoutSeed") - ? ctx.Input("AttnDropoutSeed") - : nullptr; - bool attn_dropout_fix_seed = ctx.Attr("attn_dropout_fix_seed"); - int attn_dropout_seed_val = ctx.Attr("attn_dropout_seed_val"); + bool is_test_1 = ctx.Attr("is_test1"); + auto &dropout_implementation_1 = + ctx.Attr("dropout_implementation1"); + bool is_upscale_in_train_1 = + (dropout_implementation_1 == "upscale_in_train"); + auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; + bool is_fix_seed_1 = ctx.Attr("fix_seed1"); + int seed_val_1 = ctx.Attr("seed1"); // final output. auto *out = ctx.Output("Y"); @@ -106,6 +117,7 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *qkv_bias_out_data = qkv_bias_out->mutable_data(ctx.GetPlace()); // get data ptr for FMHA. + auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data()); auto *transpose_out_2_data = transpose_out_2->mutable_data(ctx.GetPlace()); auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace()); @@ -148,25 +160,29 @@ class FusedAttentionOpKernel : public framework::OpKernel { int output_size = 3 * hidden_size; int input_size = dim_embed; + bool transA = false; + bool transB = true; + bool compute_bias = true; auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed); - // (transA, transB, compute_bias) = (false, true, true) - auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), false, true, - bsz_seq, output_size, input_size, true); + auto qkv_compute = + AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, + output_size, input_size, compute_bias); AttnDropoutParam attn_dropout_param( - attn_dropout_is_test, attn_dropout_implementation, attn_dropout_prob, - attn_dropout_is_upscale_in_train, attn_dropout_fix_seed, - attn_dropout_seed_val, attn_dropout_seed); + is_test_1, dropout_implementation_1, attn_dropout_prob, + is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); auto fmha_ref_compute = FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head, dim_head, attn_dropout_param); output_size = hidden_size; - // (transA, transB, compute_bias) = (false, false, false) + transA = false; + transB = false; + compute_bias = false; auto out_linear_compute = - AttnMatMul(ctx.cuda_device_context(), false, false, bsz_seq, - output_size, input_size, false); + AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, + output_size, input_size, compute_bias); DropoutParam dropout_param2(ctx, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 277d68dada8bb5..bf26e05c844e49 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -13,128 +13,70 @@ # limitations under the License. import numpy as np + import paddle import paddle.nn as nn import paddle.fluid.core as core import paddle.nn.functional as F -from paddle.nn.layer.fused_transformer import FusedMultiHeadAttention +from paddle.nn.layer.norm import LayerNorm +from paddle.nn.layer.common import Linear, Dropout from paddle.nn.layer.transformer import _convert_attention_mask from paddle import tensor from paddle.fluid import layers -from paddle.static import Program, program_guard import unittest -from op_test import OpTest - - -def GetBaselineOut(pre_layer_norm, training, embed_dim, num_heads, head_dim, - query, attn_mask, ln_scale, ln_bias, ln_2_scale, ln_2_bias, - qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, - attn_dropout_prob, dropout_prob): - paddle.disable_static(place=paddle.CUDAPlace(0)) - tensor_query = paddle.to_tensor(query, stop_gradient=False) - attn_mask = paddle.to_tensor(attn_mask, stop_gradient=False) - residual = tensor_query - ln_scale = paddle.to_tensor(ln_scale) - ln_bias = paddle.to_tensor(ln_bias) - ln_2_scale = paddle.to_tensor(ln_2_scale) - ln_2_bias = paddle.to_tensor(ln_2_bias) - out_linear_weight = paddle.to_tensor(out_linear_weight) - out_linear_bias = paddle.to_tensor(out_linear_bias) - - # qkv_weight: [3, num_heads, self.head_dim, embed_dim] - q_weight = qkv_weight[0:1, ::] - k_weight = qkv_weight[1:2, ::] - v_weight = qkv_weight[2:3, ::] - q_weight = q_weight.reshape(num_heads * head_dim, embed_dim) - k_weight = k_weight.reshape(num_heads * head_dim, embed_dim) - v_weight = v_weight.reshape(num_heads * head_dim, embed_dim) - q_weight = paddle.to_tensor(q_weight.transpose((1, 0))) - k_weight = paddle.to_tensor(k_weight.transpose((1, 0))) - v_weight = paddle.to_tensor(v_weight.transpose((1, 0))) - # qkv_bias: [3, num_heads, self.head_dim] - q_bias = qkv_bias[0:1, ::] - q_bias = q_bias.reshape(num_heads * head_dim) - k_bias = qkv_bias[1:2, ::] - k_bias = k_bias.reshape(num_heads * head_dim) - v_bias = qkv_bias[2:3, ::] - v_bias = v_bias.reshape(num_heads * head_dim) - q_bias = paddle.to_tensor(q_bias) - k_bias = paddle.to_tensor(k_bias) - v_bias = paddle.to_tensor(v_bias) - - for i in range(1): - ln1_out = tensor_query - if pre_layer_norm: - ln1_out = F.layer_norm(tensor_query, embed_dim, ln_scale, ln_bias) - - q = F.linear(ln1_out, q_weight, q_bias) - q = tensor.reshape(x=q, shape=[0, 0, num_heads, head_dim]) - q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) - k = F.linear(ln1_out, k_weight, k_bias) - v = F.linear(ln1_out, v_weight, v_bias) - k = tensor.reshape(x=k, shape=[0, 0, num_heads, head_dim]) - k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) - v = tensor.reshape(x=v, shape=[0, 0, num_heads, head_dim]) - v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) - - qk_out = layers.matmul( - x=q_out, y=k_out, transpose_y=True, alpha=head_dim**-0.5) - - if attn_mask is not None: - attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) - attn_mask_out = qk_out + attn_mask - softmax_out = F.softmax(attn_mask_out) - else: - softmax_out = F.softmax(qk_out) - - if attn_dropout_prob: - dropout_out = F.dropout( - softmax_out, - attn_dropout_prob, - training=training, - mode="upscale_in_train") - qktv_out = tensor.matmul(dropout_out, v_out) - else: - qktv_out = tensor.matmul(softmax_out, v_out) - fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) - out_linear_in = tensor.reshape( - x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) - out = F.linear(out_linear_in, out_linear_weight, out_linear_bias) - residual_out = residual + F.dropout( - out, dropout_prob, training=training, mode="upscale_in_train") - final_out = F.layer_norm(residual_out, embed_dim, ln_2_scale, ln_2_bias) - return final_out - - -class TestFusedAttentionOpFp32(OpTest): +@unittest.skipIf(not core.is_compiled_with_cuda(), + "Paddle core is not compiled with CUDA") +class TestFusedAttentionOp(unittest.TestCase): def setUp(self): self.config() - self.common_config() self.generate_input_data() + paddle.set_default_dtype(self.x_type) + self.q_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.k_proj = Linear( + self.kdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.v_proj = Linear( + self.vdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.out_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + paddle.set_default_dtype(np.float32) + self.norm1 = LayerNorm(self.embed_dim) + self.norm2 = LayerNorm(self.embed_dim) + paddle.set_default_dtype(self.x_type) + self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train") def config(self): self.x_type = np.float32 + self.attn_mask_type = np.float64 self.pre_layer_norm = True + self.training = True + self.batch_size = 8 self.query_length = 128 self.head_dim = 64 self.num_heads = 16 - - def common_config(self): - self.__class__.op_type = "fused_attention" - paddle.set_default_dtype(self.x_type) self.embed_dim = self.head_dim * self.num_heads - self.kdim, self.vdim = self.embed_dim, self.embed_dim - self.key_length, self.value_length = self.query_length, self.query_length - self.attn_mask_type = np.float64 - self.training = True - self.need_weight = False + self.dropout_prob = 0.0 self.attn_dropout_prob = 0.0 self.weight_attr = None self.bias_attr = None + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length def generate_input_data(self): self.query = np.random.rand(self.batch_size, self.query_length, @@ -151,52 +93,146 @@ def generate_input_data(self): raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") self.key, self.value = self.query, self.query - def compute_result(self): + self.dout = np.random.random((self.batch_size, self.query_length, + self.embed_dim)).astype(self.x_type) + + def GetBaselineOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + tensor_query = paddle.to_tensor(self.query, stop_gradient=False) + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + residual = tensor_query + + for i in range(1): + ln1_out = tensor_query + if self.pre_layer_norm: + ln1_out = self.norm1(tensor_query) + + q = self.q_proj(ln1_out) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + k = self.k_proj(ln1_out) + v = self.v_proj(ln1_out) + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + qk_out = layers.matmul( + x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5) + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) + attn_mask_out = qk_out + attn_mask + softmax_out = F.softmax(attn_mask_out) + else: + softmax_out = F.softmax(qk_out) + + if self.dropout_prob: + dropout_out = F.dropout( + softmax_out, + self.dropout_prob, + training=self.training, + mode="upscale_in_train") + qktv_out = tensor.matmul(dropout_out, v_out) + else: + qktv_out = tensor.matmul(softmax_out, v_out) + + fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) + out_linear_in = tensor.reshape( + x=fmha_out, + shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) + out = self.out_proj(out_linear_in) + + residual_out = residual + self.dropout(out) + if not self.pre_layer_norm: + final_out = self.norm1(residual_out) + if self.pre_layer_norm: + final_out = self.norm2(residual_out) + return final_out + + def GetFusedAttentionOut(self): paddle.disable_static(place=paddle.CUDAPlace(0)) - fused_attn = FusedMultiHeadAttention( - self.embed_dim, self.num_heads, self.dropout_prob, - self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm, - self.need_weight, self.weight_attr, self.bias_attr) - out = fused_attn( - paddle.to_tensor(self.query), - paddle.to_tensor(self.query), - paddle.to_tensor(self.query), paddle.to_tensor(self.attn_mask)) - ref_out = GetBaselineOut(self.pre_layer_norm, self.training, - self.embed_dim, self.num_heads, self.head_dim, - self.query, self.attn_mask, - fused_attn.ln_scale.numpy(), - fused_attn.ln_bias.numpy(), - fused_attn.ln_2_scale.numpy(), - fused_attn.ln_2_bias.numpy(), - fused_attn.qkv_weight.numpy(), - fused_attn.qkv_bias.numpy(), - fused_attn.out_linear_weight.numpy(), - fused_attn.out_linear_bias.numpy(), - self.attn_dropout_prob, self.dropout_prob) - return ref_out, out + q_proj_weight = paddle.to_tensor( + self.q_proj.weight, stop_gradient=False) + q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False) + k_proj_weight = paddle.to_tensor( + self.k_proj.weight, stop_gradient=False) + k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False) + v_proj_weight = paddle.to_tensor( + self.v_proj.weight, stop_gradient=False) + v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False) + out_linear_weight = paddle.to_tensor( + self.out_proj.weight, stop_gradient=False) + out_linear_bias = paddle.to_tensor( + self.out_proj.bias, stop_gradient=False) + + ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) + ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) + ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False) + ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False) + + q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) + k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) + v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) + qkv_weight = np.concatenate( + (q_proj_weight, k_proj_weight, v_proj_weight)) + qkv_weight = qkv_weight.reshape( + (3, self.num_heads, self.head_dim, self.embed_dim)) + + qkv_bias = np.concatenate( + (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + + x = paddle.to_tensor(self.query, stop_gradient=False) + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) + qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) + epsilon = 1e-05 + ln2_epsilon = 1e-05 - def test_fused_attention_op(self): - ref_out, out = self.compute_result() - self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5)) + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, x.dtype) + final_out = F.fused_multihead_attention( + x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, + ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, + out_linear_bias, attn_mask, self.dropout_prob, + self.attn_dropout_prob, ln2_epsilon) + return final_out + def test_fused_attention_op(self): + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) -class TestFusedAttentionOpFp16(TestFusedAttentionOpFp32): - def setUp(self): - self.config() - self.common_config() - self.generate_input_data() +@unittest.skipIf(not core.is_compiled_with_cuda(), + "Paddle core is not compiled with CUDA") +class TestFusedAttentionOpFp16(TestFusedAttentionOp): def config(self): self.x_type = np.float16 + self.attn_mask_type = np.float64 self.pre_layer_norm = True + self.training = True + self.batch_size = 8 self.query_length = 128 self.head_dim = 64 self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length def test_fused_attention_op(self): - ref_out, out = self.compute_result() - self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-2)) + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) if __name__ == "__main__": diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 7965b362b9c55a..d8bec647f2c54c 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -60,6 +60,7 @@ from .conv import conv1d # noqa: F401 from .conv import conv1d_transpose # noqa: F401 from .common import linear # noqa: F401 +from .common import fused_multihead_attention # noqa: F401 from .conv import conv2d # noqa: F401 from .conv import conv2d_transpose # noqa: F401 from .conv import conv3d # noqa: F401 diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index fdd370d7f81e72..319f36ac94384f 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1503,6 +1503,33 @@ def linear(x, weight, bias=None, name=None): return res +def fused_multihead_attention(x, + qkv_weight, + out_linear_weight, + pre_layer_norm=False, + ln_scale=None, + ln_bias=None, + ln_2_scale=None, + ln_2_bias=None, + epsilon=1e-05, + qkv_bias=None, + out_linear_bias=None, + src_mask=None, + dropout=0., + attn_dropout=0., + ln2_epsilon=1e-05, + name=None): + r""" + """ + if in_dygraph_mode(): + ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( + x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, + out_linear_weight, out_linear_bias, ln_2_scale, ln_2_bias, + 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, + 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) + return final_out + + def label_smooth(label, prior_dist=None, epsilon=0.1, name=None): r""" Label smoothing is a mechanism to regularize the classifier layer and is called diff --git a/python/paddle/nn/layer/fused_transformer.py b/python/paddle/nn/layer/fused_transformer.py index 6f040e96c5ca70..0084f7ff339df3 100644 --- a/python/paddle/nn/layer/fused_transformer.py +++ b/python/paddle/nn/layer/fused_transformer.py @@ -12,155 +12,115 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle -from ...fluid.layers import core -from ...fluid.framework import in_dygraph_mode -from paddle import _C_ops -from paddle.fluid.layer_helper import LayerHelper -import copy -from .. import functional as F -from paddle.nn import Layer -from ...framework import ParamAttr -from ...framework import get_default_dtype, set_default_dtype -from paddle.nn.layer.transformer import _convert_attention_mask -from ...fluid.data_feeder import check_variable_and_dtype, check_dtype -from ..initializer import Constant - -import collections - - -def fused_multihead_attention(x, - qkv_weight, - out_linear_weight, - pre_layer_norm=False, - ln_scale=None, - ln_bias=None, - ln_2_scale=None, - ln_2_bias=None, - epsilon=1e-05, - qkv_bias=None, - out_linear_bias=None, - src_mask=None, - dropout=0., - attn_dropout=0., - ln2_epsilon=1e-05, - name=None): - r""" - """ - if in_dygraph_mode(): - ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( - x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, - out_linear_weight, out_linear_bias, ln_2_scale, ln_2_bias, - 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, - 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) - return final_out - class FusedMultiHeadAttention(Layer): """ + Attention mapps queries and a set of key-value pairs to outputs, and + Multi-Head Attention performs multiple parallel attention to jointly attending + to information from different representation subspaces. + + Please refer to `Attention Is All You Need `_ + for more details. + + Parameters: + embed_dim (int): The expected feature size in the input and output. + num_heads (int): The number of heads in multi-head attention. + dropout (float, optional): The dropout probability used on attention + weights to drop some attention targets. 0 for no dropout. Default 0 + kdim (int, optional): The feature size in key. If None, assumed equal to + `embed_dim`. Default None. + vdim (int, optional): The feature size in value. If None, assumed equal to + `embed_dim`. Default None. + need_weights (bool, optional): Indicate whether to return the attention + weights. Default False. + weight_attr(ParamAttr, optional): To specify the weight parameter property. + Default: None, which means the default weight parameter property is used. + See usage for details in :code:`ParamAttr` . + bias_attr (ParamAttr|bool, optional): To specify the bias parameter property. + Default: None, which means the default bias parameter property is used. + If it is set to False, this layer will not have trainable bias parameter. + See usage for details in :code:`ParamAttr` . + + Examples: + + .. code-block:: python + + import paddle + + # encoder input: [batch_size, sequence_length, d_model] + query = paddle.rand((2, 4, 128)) + # self attention mask: [batch_size, num_heads, query_len, query_len] + attn_mask = paddle.rand((2, 2, 4, 4)) + multi_head_attn = paddle.nn.MultiHeadAttention(128, 2) + output = multi_head_attn(query, None, None, attn_mask=attn_mask) # [2, 4, 128] """ - # todo (@limin, do we need cache in FusedMultiHeadAttention layer?) - # Cache = collections.namedtuple("Cache", ["k", "v"]) - # StaticCache = collections.namedtuple("StaticCache", ["k", "v"]) + Cache = collections.namedtuple("Cache", ["k", "v"]) + StaticCache = collections.namedtuple("StaticCache", ["k", "v"]) def __init__(self, embed_dim, num_heads, dropout=0., - attn_dropout=0., kdim=None, vdim=None, - normalize_before=False, need_weights=False, weight_attr=None, - bias_attr=None, - name=None): + bias_attr=None): super(FusedMultiHeadAttention, self).__init__() - - assert embed_dim > 0, ("Expected embed_dim to be greater than 0, " - "but recieved {}".format(embed_dim)) - assert num_heads > 0, ("Expected nhead to be greater than 0, " - "but recieved {}".format(num_heads)) - - attn_dropout = dropout if attn_dropout is None else attn_dropout - self.normalize_before = normalize_before - self._dtype = self._helper.get_default_dtype() - self._weight_attr = weight_attr - self._bias_attr = bias_attr - - self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" - - ## linear parameters. - self.qkv_weight = self.create_parameter( - shape=[3, num_heads, self.head_dim, embed_dim], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False) - self.qkv_bias = self.create_parameter( - shape=[3, num_heads, self.head_dim], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True) - self.out_linear_weight = self.create_parameter( - shape=[embed_dim, embed_dim], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False) - self.out_linear_bias = self.create_parameter( - shape=[embed_dim], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True) - - if get_default_dtype() == 'float16': - set_default_dtype('float32') - ## layer_norm parameters. - self.ln_scale = self.create_parameter( - attr=self._weight_attr, - shape=[embed_dim], - default_initializer=Constant(value=1.0)) - self.ln_bias = self.create_parameter( - attr=self._bias_attr, shape=[embed_dim], is_bias=True) - self.ln_2_scale = self.create_parameter( - attr=self._weight_attr, - shape=[embed_dim], - default_initializer=Constant(value=1.0)) - self.ln_2_bias = self.create_parameter( - attr=self._bias_attr, shape=[embed_dim], is_bias=True) - if get_default_dtype() == 'float16': - set_default_dtype('float16') - - ## dropout parameters - self.dropout = dropout - self.attn_dropout = attn_dropout - - self.name = name + raise NotImplementedError() def forward(self, query, key=None, value=None, attn_mask=None, cache=None): """ + Applies multi-head attention to map queries and a set of key-value pairs + to outputs. + Parameters: + query (Tensor): The queries for multi-head attention. It is a + tensor with shape `[batch_size, query_length, embed_dim]`. The + data type should be float32 or float64. + key (Tensor, optional): The keys for multi-head attention. It is + a tensor with shape `[batch_size, key_length, kdim]`. The + data type should be float32 or float64. If None, use `query` as + `key`. Default None. + value (Tensor, optional): The values for multi-head attention. It + is a tensor with shape `[batch_size, value_length, vdim]`. + The data type should be float32 or float64. If None, use `query` as + `value`. Default None. + attn_mask (Tensor, optional): A tensor used in multi-head attention + to prevents attention to some unwanted positions, usually the + paddings or the subsequent positions. It is a tensor with shape + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. + When the data type is bool, the unwanted positions have `False` + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. + cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional): + It is a namedtuple with `k` and `v` as fields, and stores tensors + shaped `[batch_size, num_heads, length, embed_dim]` which are results + of linear projection, reshape and transpose calculations in + MultiHeadAttention. If it is an instance of `Cache`, `k` and `v` + fields reserve intermediate results of previous positions, which + mostly used for decoder self attention. If it is an instance of + `StaticCache`, `key` and `value` args would be ignored, `k` and + `v` fields would be used as calculated results on `key` and + `value`, which mostly used for decoder-encoder cross attention. + It is only used for inference and should be None for training. + Default None. + Returns: + Tensor|tuple: It is a tensor that has the same shape and data type \ + as `query`, representing attention output. Or a tuple if \ + `need_weights` is True or `cache` is not None. If `need_weights` \ + is True, except for attention output, the tuple also includes \ + the attention weights tensor shaped `[batch_size, num_heads, query_length, key_length]`. \ + If `cache` is not None, the tuple then includes the new cache \ + having the same type as `cache`, and if it is `StaticCache`, it \ + is same as the input `cache`, if it is `Cache`, the new cache \ + reserves tensors concatanating raw tensors with intermediate \ + results of current query. """ - if attn_mask is not None: - # Support bool or int mask - attn_mask = _convert_attention_mask(attn_mask, query.dtype) - out = fused_multihead_attention( - x=query, - qkv_weight=self.qkv_weight, - out_linear_weight=self.out_linear_weight, - pre_layer_norm=self.normalize_before, - ln_scale=self.ln_scale, - ln_bias=self.ln_bias, - ln_2_scale=self.ln_2_scale, - ln_2_bias=self.ln_2_bias, - epsilon=1e-05, - qkv_bias=self.qkv_bias, - out_linear_bias=self.out_linear_bias, - src_mask=attn_mask, - dropout=self.dropout, - attn_dropout=self.attn_dropout, - ln2_epsilon=1e-05) - return out + raise NotImplementedError() class FusedFeedForward(Layer): diff --git a/tools/__pycache__/static_mode_white_list.cpython-37.pyc b/tools/__pycache__/static_mode_white_list.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04711804109d2b5acb2ddf6acefee91dfd28780b GIT binary patch literal 21041 zcmeI4b(}0$mB(MQND>GH4;CB}APMii5ZnR*f&@vhL5i9_)jh>@cXg_|M}oUUaCdii zUuWHS$@;SH&MqtL_sFfP?w;^j_TMG>{PJF1zvmvmw{y;kC!TPKe;)Ro3-WJ0^3b7g z>qq+Mo5vnHbWc8-Q;t7$NDj#{a;*G_Tuv@8SCA{padNzzAXkzr%ZYLoxvHEbSCgyD zHRPIdExEQ_M^2XO%Jt;>as#=coFX@p8_P}PrgAg6x!gi-DYueS=xAJ%L_wo<&kMd9Q&+;$w zuks!FH~Dw@uKb7mr~H@vw|wuCW8@Gx1{@211Y8bW9$W!j5gZ4O2Pc3lfh&U(!BxOj z!Aan1;OgKS;F{oC;M(9i;AC)Ja6NE+a075da0<8)xG}g1xGA_9xH-55xFxt1I2D`* zZVhe&ZVOHaXMi)o?Z9Dh1RMotf!l+#!5zRI!JWXJ!Ck;z!QH^!!9Bn|!M(sa;NIXq z;9PKDa6fQ=a2|L7cp!KXcrZ90JOn%xJPbS>TmUWvj{uJZj{=Vdj{zItvEXsw@!$y{ z15X4w*aQP0pa4TK0$ZR2PXgOu3?^U-Dlh{Vff_Vm4i;bs?1DY84=x5z22TM`1y2J{ z2hRY{1kVD`2G0S{1%VfcJv;f%k(CfDeKXfe(X^fRBQYfscbvfKP%? zflq_afX{-@fzN|4fG>hCfiHuvfUknDfvmGB^=j1zZ)J1g-|I4z2;N39bdM z4Xy)DR$IcK*9F(Zo@k%zgBySwVqJ8cQ^1YX%JAon!A)$moAS4tft!O{fLnrFfm6Y0 z;MU+a;I`m&a0WON+zuQDN5D~V7Pvh)8{7fh5!?yf8QcZj72FNn9oz%l6Wj}&1MUs( z1I`8a1@{B@2j_tYfCqvHfd_;0!9&19!Nb7A!3E$#@Cfio@F?(T@EEWG9t$1^9uJ-X zGVnx@gH1310tzq$Bd`TZ@FcJe#$W=bpaL^+5vV}}=3oJKz%JMW``}{mWbhR5RPZ$L zbnp!DOzbeDDJBLhvH+V(=31Qt&eHa&QTF1$ZTR6?ipx4R|eh9e6!> z19&5N6L>Rt3wSGd8+bc-2Y4rV7kD>#4|p$lA9z3b0QeyI5cn|o2>2-Y82C8&1o$NQ z6!%E-vb- z$;Rb&aipsmGh<%soaHnd)I+Vgn)Ml1lGR01*u>Y2qH23F%jWs!xX^~zT`%1n9zI*E zZm?dpb0kZ(CtB?+$8XcdSGCWEP0ov23w*k!xM)#K2T;TauyJ*7)+^mvEt^%I}h*w+$gg?unCck+2zO|$7V8`Zg_rZ}y| zc%ExhCvPsf)@#9)eKg4@`gOzp*SwsRFiJL@9X-N-&OWQ_J<0d7d{zciGMpCB-}k5A zyZY$_;&aB*V9=dWrl)dPR%^6-)TAW{1x;z&;ovJNrz0Kmda*)fNVZig(Tb_iWWSwV zCbpStcvF|8QC-@8qi{SPj4NbjobMO)#!bEDWCrsWgY692E9a~iR^d>c!DBDe!B$mg zqhear`JBJ8q$_#vI-ib;jTpZkWfP<*$P!eqM@?0m48>q|o$uQ|vfWxIWH1-=ylDz= z*?yO!sn)wVP_N%R`FLTyLTMl~Q={0pnHvdWX^Vk&4L0GUftc)SrrbrvjeG+;6(mBH zz-GGtG9Q;08$ShMH}k>vZeC05aWikqK~}?S(|l5d^lj$z!B#e{YMcIAHt-&7F$S5{Mn>JGYehJAVDewW4VNBoqYUz<2wCT^ zwo>mtpbgX6K}5+`H7=TM(=0LI)?&Jy6>ytQ;xaz4GCrLlZ5x~W?a8}ZNt<(-9#rFT z0hNZcsZtE_&c}_8=;t}8CJ?Pa9-yTizO#>(mt#1UK4`7UM(n7XO51ZPrYTt9=C~Sc zr$trlZEwyQO!E<9?li`N>>Ng`n&i8tjiNas+_0>q7P|J>CU5WJsO{Zcb^ZCi?EXjB z`%}c$yZ(Hy_IC!*$(AaWwT!8|^|czm8?QH7G6!**jF&ZA<$YyiO#=Eii5jbsXo+Y3goVS_k?SlF^+m4o8#0Y*g;|M3Tl`xwBy$`%(qt z1m9UPpK%iNL57|hP1E;c(kyzaZ?K~uAq0}o^Psr?bzGSNLFJ593*+8|==UnOi-Mii zNMyJ%y*h-VtM%1MvAGy2tGTX)Qg9jmMrUtb^D7A#Gh=Eu0Cc&T3F-kgna=5!DJ`%V zqG8STQtrrh4lDIt%A~8CHFZ^*UFEbwJ7OuZs{062iQ-ZXM9*8t)EERDFBUv0CRME( zzD7%ZGuqB%QZxhfT{SI28A>tF2V3pO4wffm zR0{pbMmRtcUWb(VC@UtLg>2l|my1tSmzxXNbWNkUB~Y))TVTf#u6Z%q4;|T7;Vw04 z1yxTkO0_$kRg1bANpPOb_`6aocB2|>VLjYn$UfR>wv|m5<9V5JKxQjlKxBWamU6=- zny-8kwzHWdEb5G@vQEdlKUK4uO5^B*bG#0BaBjA5wwQ0+!st<`jzc)dg}DGQ#)?vO z+qKfn+7~L}O^C)ygHgT6rig}n@hg?)XK3w#W8)^9+on7BGh68nyDjRPjVoCXoM3kg z)vvX#6biPqmmNK*Eb~_uEF{Wf#b@E3LL>NEuGV zH0;O5ZG6vZ^tzZY>QqTP7-y%b)p=G8=k~CT@pO~!opeR+#kuW>Qd%u_^d|1E75hdi zKayb%;d;uAj!qM+X!Rk*C)lco=5@Pd*uFYFj{VJN#C-d+BBX9usN<}Gc393_gM+ZB zUb&b~ih2aieED9era4xjfqg#h4!Bp0GSC3h#38OUOEpvO!W|_VqR-3GqFOYW#Y7#mr4}8DjD4Lujg>i5HQ9&Ly0kuUG`Q@+{oY|QFUP_CT|pCm<`4%x5rbtE zyH&jn7hvbrri6$L8}omHwDa0T+_`GF;8>E(MFTB@25v5giz$(NHO0?TV}2yrx3%)l zqdU0g(aY%BcOJtF7^;ZotRkH48Bs;rq*f)vPiUhx9v@q$x)?9mK8Xs;Iq8bM%}t!= z=`ES6HBU)(ISKabc2b9-$z+pg?;`U~lolZ-sV)>5B9ubeV3I|n2@W+%baH=_0kteO zY)8_l|Rr&qfhZNM6A51Jq$f0o*~T8fcrqsfWLJ?X^lW@8AihH)7y@^fzL z(+H8uAsWmg8|P~uO?N6A*^&AM@nw1Y;K%{CFqX5j&p%U@fY zf~#H&C;eIr3%}f1>QWk=iyO|;NM&%W-ky$vl3O*O1pW5byb)OD<@U2Ep&)tXfe z1%vr471QJs{NF4wd5Cqd%0l2Y2+ndFTDn>r)xENeq=GB#U<_AL`k_Lvt9Hp#4WeH@ z?iNPFa}ZkQ;<|;;Xklir*5DEW zGff+5F3hqgA;-O9A)*)zgHXNlVNyNq{Ow4ie0o43E|)WDZx9W5P+I0W5=0V8yLnw| z(i$KVl1qoWzNl79gqrP*Gn;Rk!+R^!3hNb_X2Rf>#2+mM+ii8e;i7@8;VAS*s*7@p zmp!lq`1($k)~?yr&bDrxvAQaVLQ)Tj+bzk!rbT6B7W46Ze|>O^)HMc!F;;it*14Px zYG0YVASF4SKAH|0cUs|hL!!;aP(Pb-HE$B3wSlF>CnSr;Sy_eo zYV&rO#9-8^NiI>EIjWu?C{v+`o^lF};;^E!sfKg)dyOPZVR{%(Mj`%bX2l>p5OWbU zOP%t6?nA5=tF*fFSx)iZ%-n?Zm8>R5I+RnBa8wITO^~_6%d@0Z@O8Z&r4G@Drbk^} zmg`N-oVj)=Nk^DN4Bb;nBz`Gz4Y57efI*XEdOEH{wrk7=U%rL$;dZxvq^6a#W>Qt2 z%lA&>9DCOeYq_MhUtFE2p762>(-IBP+>J=Dx!sy{^{!TW^#qepcWp$qxo)PNGGme+ zQ%&J$P)v?{i=5jSo{8~17uqFO{Tw4jx|9#~v(1e%WjaWa37!$G3Di|R$^FcVi%yabZP4ap~ zK#AGMx(CxaBjPxy;o=Ml6JmLk5rFB?Zo1;zr*6B|G+Z*##dgbX^J|{iQr)p%G51xld_>;C*R~fh+0h#k~REvUYQYp_e5{ zTt(#P0X+&c*j2Az_40Q~!hDD1SS;``Wh@pq`ktm5UBD_4hod%`o|V672e_@GO-#=1!hh0+ynxMk{paUH@LrQC*S z+T0n#=Rc?k+2X}C{E_u~6I*uH=ykDP%&Q0=F=uDJG2T-+5#CKNtl$~-{zKob_nAx? zdE@lU7?GuMo>0YxPAQ-bcPrKhvr#<3hX+zlH;yik+cdelVx&~R-%{P(($OacsfzY& zd^cS@)~>>Kxsr(N`E%hnF`zY!+XPp7>OJ6(m1DgRaE_dndmLi3g_AZeg6n71cz=YG zpxO0y_p9@hfyj=J0QD1d{z288g-4Je?2>q84ZBSRUSV>jpcQPY9|Q^O6D zHF~&G(>#A^L9-_OJsIHwNDc8l4c(dX&(7aSjnGh?-C4RyNG&QdM@4<9Ylw7n7gY3n zJY+{w_A6xzXRwgQMsd;XsyS7x+(fT-$ebcd8;AU=??r_-BWV4)F;1|AISEnAzxY#jnLC=RS z?&@dyUcd}(I266A1;%l9zWBZBH0kT$^sO!_CCcrVx1k%Y%L ziz2qKNr^arKWr^9569^oNejCLWXaB>T8Lt@T!vGb5~4ci2>#wVJW=G<)34=HQSY|E z#eVg!IKDfE^#h9H73bN#Nz=Yi#l?%=SwVUW0r-g&q_LwrY*oCq?1wZ53nALYEP(xh1qqR4wx8%A|rx2<}XOv}#P7#F56` zqL#RrV;{|U2lcqMK>hBt&HYoG$BK9dK-{gFB+Z!k`S-mK|yHoS%i*29z_`&#R$40(jFJ0jxr{oZ<2=y}4)*FM?Fp68^z8An%Wya%&>*RGACeb8dT09{kX>$c5RI)y<8y z=7ruUk`J2`s3I(he^MafG?q)N>P|6i;MT z717HL)Iy^q=dqZ3hI~fuU3I5=BD8i=qn6d>ue#6DDD7H!wjlN#O$lhST8dvGji}RS4>>O?Nj%$$ad33iZFJmNMg}!m6GG2Skw=~)hs#}@YVq31sPR4kvOZ>ZML33`xyHeN2ls4PEL+<+ZuKmq|t^6LP8^A*rYqMGA zyae>Si>`sIK5w9TM^4uUKh@aJ%KNXNxP(7q0?@Ve%*{b9R~oveIj~_jntK~dLG zH&@WdxA7paKgQsFL#tYRjnr#&30ucws5HDiU2^8&S^=M*|LRRb_)o2Tws#Qj1Zbi> zOsyJmI|?4OxXJAO4xn>wUCY~>^sdkiQ>+CmddI|8%)3;#;tV@2qTd0->7(tHF70qy zw=z{boYr;2bu3SDcS2nQcT{!rOAS_L?nWn>*L8MU^`4n^Req)`($eed{J&biGVS^x z`XJ@|O``WCUH$qbTE8+$D`Z{2R+?6tx~5!ljk<3Bv}L6lN4RU@8(MX`x_bAQSL+WL zB7{MezzQR@!lc5Y^(%YuGwb#5W_SLIDtn^T)j1>e>sIDnZ?H14FKjge=w@&`qZg1} zne$x;ZddFJjpCJQ+!jaO66HLyR1R2KdUUb3&`PmxqSYGqzjQZkmuN<1)KX8;`2r(5 zzB1cV-OA#9ztOK-hOZp`i$hEHpsVo1th)#9U6u22@(Q|gSCpxIW&ZF=xUcLwL_kj2 zN@JIXPN}TybG3S9DEFI%GPP@gDDRqkm=eXEDTqW> z*Y1q|u+gqtW%y*hqjuvI<9gODdUazuDWXk!(9%_Vyyqqd6OQ3=ZHTnh`ADn%b+Iu0MR$KlN^l+;1p1dz#9fQ(dO_FFHF01_$0PZ8(iyY;0{8$t|HG z+k^L&*57lxCRp3u_OF*CT`3C_*Glsdw=j`}auPAGm7R5My|O-?OEtAuT*@ zQ4?wD*PzxlV!HMBHFl~I7dc(${El(S+e)6zJTu`lCjr8z_Y}H4_Yk}O{%gJ2c){z4 zxct;%I~KZ3@5u9r_T&K)xD?n0Ui<1Powz?Jv;zoEsrcn*L(Zr3d3>~Bi6C@G(D+$p zxa?luzEI{WZa2bu=Lcxgx73~M%{rY#v{P2R%w5v3b>AbipN=V;j#I*Bx}(8t#3!?I zR%o^70jjKhVP?Xe1{T2I-Nv(j>J znX*_~;&!UNiRtz3IqplneJ)qU7?t5O=?@R+#ro~yqc>JOjgIC% zeycjIohyhav*f51`_@9{t;A#B-t!)^}v3oN5N86~y<(KHrdC7$o)H5osRV40Ch zL*cZ)5NdUYKY5B@0e8<+{h6TsNJM{lXTMJ9^q^DNYV`*v4Cd|Q{o3-aPc=LLv|}zk z`GO1Yb>VsU&K~@b`<$Cyct&tA}TpDQiz&M=C1aP+5V-+-IYRKGv56q{ojsV t!r;dq^TU7tKU3gx$Ncc$A5!24p@2$@woiX{e@i7^|AUPA!?uY6{{!j7SQ-ET literal 0 HcmV?d00001 From 7b28f7cfd1957a6cc4346f155c42f626b09c4d94 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Wed, 13 Oct 2021 12:40:07 +0000 Subject: [PATCH 19/29] Move fused_multi_head_attention from common.py. --- python/paddle/nn/functional/__init__.py | 2 +- python/paddle/nn/functional/common.py | 27 -------- .../paddle/nn/functional/fused_transformer.py | 67 +++++++++++++++++++ 3 files changed, 68 insertions(+), 28 deletions(-) create mode 100644 python/paddle/nn/functional/fused_transformer.py diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index d8bec647f2c54c..830cf5dbc489b6 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -60,7 +60,7 @@ from .conv import conv1d # noqa: F401 from .conv import conv1d_transpose # noqa: F401 from .common import linear # noqa: F401 -from .common import fused_multihead_attention # noqa: F401 +from .fused_transformer import fused_multihead_attention # noqa: F401 from .conv import conv2d # noqa: F401 from .conv import conv2d_transpose # noqa: F401 from .conv import conv3d # noqa: F401 diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 319f36ac94384f..fdd370d7f81e72 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1503,33 +1503,6 @@ def linear(x, weight, bias=None, name=None): return res -def fused_multihead_attention(x, - qkv_weight, - out_linear_weight, - pre_layer_norm=False, - ln_scale=None, - ln_bias=None, - ln_2_scale=None, - ln_2_bias=None, - epsilon=1e-05, - qkv_bias=None, - out_linear_bias=None, - src_mask=None, - dropout=0., - attn_dropout=0., - ln2_epsilon=1e-05, - name=None): - r""" - """ - if in_dygraph_mode(): - ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( - x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, - out_linear_weight, out_linear_bias, ln_2_scale, ln_2_bias, - 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, - 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) - return final_out - - def label_smooth(label, prior_dist=None, epsilon=0.1, name=None): r""" Label smoothing is a mechanism to regularize the classifier layer and is called diff --git a/python/paddle/nn/functional/fused_transformer.py b/python/paddle/nn/functional/fused_transformer.py new file mode 100644 index 00000000000000..53f9a992dbb6e3 --- /dev/null +++ b/python/paddle/nn/functional/fused_transformer.py @@ -0,0 +1,67 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +import paddle +from ...fluid.framework import in_dygraph_mode, default_main_program +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.layers.tensor import fill_constant +from ...tensor import concat +from ...tensor.creation import zeros +from paddle.static import Variable +from ...fluid.layers import core +from ...fluid import dygraph_utils +from ...fluid.layers import unfold # noqa: F401 +from ...tensor.manipulation import squeeze +from ...tensor.manipulation import unsqueeze +from ...tensor import clip +from ...tensor import sum +from ...tensor import sqrt +from ...fluid.data_feeder import check_variable_and_dtype, check_dtype +from ...fluid.framework import in_dygraph_mode, _varbase_creator + +from ...fluid.framework import in_dygraph_mode +from ...fluid import core, dygraph_utils +from ...fluid import core, layers +from ...fluid.data_feeder import check_variable_and_dtype +from paddle import _C_ops + +__all__ = [] + + +def fused_multihead_attention(x, + qkv_weight, + out_linear_weight, + pre_layer_norm=False, + ln_scale=None, + ln_bias=None, + ln_2_scale=None, + ln_2_bias=None, + epsilon=1e-05, + qkv_bias=None, + out_linear_bias=None, + src_mask=None, + dropout=0., + attn_dropout=0., + ln2_epsilon=1e-05, + name=None): + r""" + """ + if in_dygraph_mode(): + ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( + x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, + out_linear_weight, out_linear_bias, ln_2_scale, ln_2_bias, + 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, + 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) + return final_out From 30fef54352604663efab3bb6bada4ede4cc2a1b3 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 14 Oct 2021 04:42:02 +0000 Subject: [PATCH 20/29] Modify copyright and names with number. --- .../operators/fused/fused_attention_op.cc | 27 +++++++------------ .../operators/fused/fused_attention_op.cu | 11 +++++--- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 8e5263091e48e7..9247ca0334a44a 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -1,8 +1,11 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -115,7 +118,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel { // the same as QKOut's shape. ctx->SetOutputDim("AttnDropoutOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); - if (ctx->Attrs().Get("is_test1") == false) { + if (ctx->Attrs().Get("attn_dropout_is_test") == false) { ctx->SetOutputDim("AttnDropoutMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); } @@ -220,20 +223,20 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { platform::errors::InvalidArgument( "'attn_dropout_prob' must be between 0.0 and 1.0.")); }); - AddAttr("is_test1", + AddAttr("attn_dropout_is_test", "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") .SetDefault(false); - AddAttr("fix_seed1", + AddAttr("attn_dropout_fix_seed", "A flag indicating whether to use a fixed seed to generate " "random mask. NOTE: DO NOT set this flag to true in " "training. Setting this flag to true is only useful in " "unittest or for debug that always the same output units " "will be dropped.") .SetDefault(true); - AddAttr("seed1", "Dropout random seed.").SetDefault(0); + AddAttr("attn_dropout_seed", "Dropout random seed.").SetDefault(0); AddAttr( - "dropout_implementation1", + "attn_dropout_implementation", "[\"downgrade_in_infer\"|\"upscale_in_train\"]" "There are two kinds of ways to implement dropout" "(the mask below is a tensor have the same shape with input" @@ -280,19 +283,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr( "dropout_implementation", "[\"downgrade_in_infer\"|\"upscale_in_train\"]" - "There are two kinds of ways to implement dropout" - "(the mask below is a tensor have the same shape with input" - "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)" - "1. downgrade_in_infer(default), downgrade the outcome at inference " - "time" - " train: out = input * mask" - " inference: out = input * (1.0 - dropout_prob)" - "2. upscale_in_train, upscale the outcome at training time, do nothing " - "in inference" - " train: out = input * mask / ( 1.0 - dropout_prob )" - " inference: out = input" - " dropout op can be removed from the program. the program will be " - "efficient") + "The meaning is the same as 'attn_dropout_implementation'.") .SetDefault("downgrade_in_infer") .AddCustomChecker([](const std::string &type) { PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index e99fc1c7b94af4..ccb76a547d93e3 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -1,8 +1,11 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -88,14 +91,14 @@ class FusedAttentionOpKernel : public framework::OpKernel { const float ln2epsilon = ctx.Attr("ln2epsilon"); float attn_dropout_prob = ctx.Attr("attn_dropout_prob"); - bool is_test_1 = ctx.Attr("is_test1"); + bool is_test_1 = ctx.Attr("attn_dropout_is_test"); auto &dropout_implementation_1 = - ctx.Attr("dropout_implementation1"); + ctx.Attr("attn_dropout_implementation"); bool is_upscale_in_train_1 = (dropout_implementation_1 == "upscale_in_train"); auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; - bool is_fix_seed_1 = ctx.Attr("fix_seed1"); - int seed_val_1 = ctx.Attr("seed1"); + bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed"); + int seed_val_1 = ctx.Attr("attn_dropout_seed"); // final output. auto *out = ctx.Output("Y"); From 766ef85f7337a92822e9b82d94aad89d76aa0868 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 14 Oct 2021 04:58:41 +0000 Subject: [PATCH 21/29] Remove HIP and use OpTest and remove print. --- .../fluid/operators/fused/fused_attention_op.cu | 15 +-------------- python/paddle/fluid/framework.py | 1 - .../tests/unittests/test_fused_attention_op.py | 8 +++----- 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index ccb76a547d93e3..b251aedee1d05e 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -12,26 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef __NVCC__ +#include #include -#endif -#ifdef __HIPCC__ -#include -namespace cub = hipcub; -#endif - #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/cuda_device_function.h" - -#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" -#endif -#ifdef PADDLE_WITH_HIP -#include "paddle/fluid/platform/miopen_helper.h" -#endif -#include #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/math/math_function.h" diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index b6241f6e5299df..4aa5e680ee0d89 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -6063,7 +6063,6 @@ def __deepcopy__(self, memo): return new_param def _copy_to(self, device, blocking): - print("in ParamBase copy_to func") state = copy.deepcopy(self.__dict__) new_param = ParamBase(self.shape, self.dtype, **state) core.varbase_copy(self, new_param, device, blocking) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index bf26e05c844e49..1f366763423e01 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -24,15 +24,15 @@ from paddle import tensor from paddle.fluid import layers import unittest +from op_test import OpTest -@unittest.skipIf(not core.is_compiled_with_cuda(), - "Paddle core is not compiled with CUDA") -class TestFusedAttentionOp(unittest.TestCase): +class TestFusedAttentionOp(OpTest): def setUp(self): self.config() self.generate_input_data() paddle.set_default_dtype(self.x_type) + self.__class__.op_type = "fused_attention" self.q_proj = Linear( self.embed_dim, self.embed_dim, @@ -206,8 +206,6 @@ def test_fused_attention_op(self): final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) -@unittest.skipIf(not core.is_compiled_with_cuda(), - "Paddle core is not compiled with CUDA") class TestFusedAttentionOpFp16(TestFusedAttentionOp): def config(self): self.x_type = np.float16 From 0bc03a6b302d00a120b014760a793521df161c95 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 14 Oct 2021 05:11:39 +0000 Subject: [PATCH 22/29] Minors. --- .../operators/fused/fused_attention_op.cu | 18 ++-- .../unittests/test_fused_attention_op.py | 90 +++++++++---------- 2 files changed, 50 insertions(+), 58 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index b251aedee1d05e..a8db5b25729e73 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -107,7 +107,6 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *qkv_bias_out_data = qkv_bias_out->mutable_data(ctx.GetPlace()); // get data ptr for FMHA. - auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data()); auto *transpose_out_2_data = transpose_out_2->mutable_data(ctx.GetPlace()); auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace()); @@ -150,14 +149,11 @@ class FusedAttentionOpKernel : public framework::OpKernel { int output_size = 3 * hidden_size; int input_size = dim_embed; - bool transA = false; - bool transB = true; - bool compute_bias = true; auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed); - auto qkv_compute = - AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, - output_size, input_size, compute_bias); + // (transA, transB, compute_bias) = (false, true, true) + auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), false, true, + bsz_seq, output_size, input_size, true); AttnDropoutParam attn_dropout_param( is_test_1, dropout_implementation_1, attn_dropout_prob, @@ -167,12 +163,10 @@ class FusedAttentionOpKernel : public framework::OpKernel { dim_head, attn_dropout_param); output_size = hidden_size; - transA = false; - transB = false; - compute_bias = false; + // (transA, transB, compute_bias) = (false, false, false) auto out_linear_compute = - AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, - output_size, input_size, compute_bias); + AttnMatMul(ctx.cuda_device_context(), false, false, bsz_seq, + output_size, input_size, false); DropoutParam dropout_param2(ctx, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 1f366763423e01..1d0bd46e7c7988 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -102,52 +102,50 @@ def GetBaselineOut(self): attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) residual = tensor_query - for i in range(1): - ln1_out = tensor_query - if self.pre_layer_norm: - ln1_out = self.norm1(tensor_query) - - q = self.q_proj(ln1_out) - q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) - q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) - k = self.k_proj(ln1_out) - v = self.v_proj(ln1_out) - k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) - k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) - v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) - v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) - - qk_out = layers.matmul( - x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5) - - if attn_mask is not None: - attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) - attn_mask_out = qk_out + attn_mask - softmax_out = F.softmax(attn_mask_out) - else: - softmax_out = F.softmax(qk_out) - - if self.dropout_prob: - dropout_out = F.dropout( - softmax_out, - self.dropout_prob, - training=self.training, - mode="upscale_in_train") - qktv_out = tensor.matmul(dropout_out, v_out) - else: - qktv_out = tensor.matmul(softmax_out, v_out) - - fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) - out_linear_in = tensor.reshape( - x=fmha_out, - shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) - out = self.out_proj(out_linear_in) - - residual_out = residual + self.dropout(out) - if not self.pre_layer_norm: - final_out = self.norm1(residual_out) - if self.pre_layer_norm: - final_out = self.norm2(residual_out) + ln1_out = tensor_query + if self.pre_layer_norm: + ln1_out = self.norm1(tensor_query) + + q = self.q_proj(ln1_out) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + k = self.k_proj(ln1_out) + v = self.v_proj(ln1_out) + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + qk_out = layers.matmul( + x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5) + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) + attn_mask_out = qk_out + attn_mask + softmax_out = F.softmax(attn_mask_out) + else: + softmax_out = F.softmax(qk_out) + + if self.dropout_prob: + dropout_out = F.dropout( + softmax_out, + self.dropout_prob, + training=self.training, + mode="upscale_in_train") + qktv_out = tensor.matmul(dropout_out, v_out) + else: + qktv_out = tensor.matmul(softmax_out, v_out) + + fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) + out_linear_in = tensor.reshape( + x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) + out = self.out_proj(out_linear_in) + + residual_out = residual + self.dropout(out) + if not self.pre_layer_norm: + final_out = self.norm1(residual_out) + if self.pre_layer_norm: + final_out = self.norm2(residual_out) return final_out def GetFusedAttentionOut(self): From 99e36f9ef32d400f7f4cb21456efb53fbd3f8d56 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 14 Oct 2021 05:47:16 +0000 Subject: [PATCH 23/29] Polish functional.fused_attention_op. --- .../paddle/nn/functional/fused_transformer.py | 71 ++++++++++--------- 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/python/paddle/nn/functional/fused_transformer.py b/python/paddle/nn/functional/fused_transformer.py index 53f9a992dbb6e3..05cfe1cf2c126e 100644 --- a/python/paddle/nn/functional/fused_transformer.py +++ b/python/paddle/nn/functional/fused_transformer.py @@ -12,29 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings +#import warnings import paddle -from ...fluid.framework import in_dygraph_mode, default_main_program -from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.layers.tensor import fill_constant -from ...tensor import concat -from ...tensor.creation import zeros -from paddle.static import Variable -from ...fluid.layers import core -from ...fluid import dygraph_utils -from ...fluid.layers import unfold # noqa: F401 -from ...tensor.manipulation import squeeze -from ...tensor.manipulation import unsqueeze -from ...tensor import clip -from ...tensor import sum -from ...tensor import sqrt -from ...fluid.data_feeder import check_variable_and_dtype, check_dtype -from ...fluid.framework import in_dygraph_mode, _varbase_creator +#from ...fluid.framework import in_dygraph_mode, default_main_program +#from paddle.fluid.layer_helper import LayerHelper +#from paddle.fluid.layers.tensor import fill_constant +#from ...tensor import concat +#from ...tensor.creation import zeros +#from paddle.static import Variable +#from ...fluid.layers import core +#from ...fluid import dygraph_utils +#from ...fluid.layers import unfold # noqa: F401 +#from ...tensor.manipulation import squeeze +#from ...tensor.manipulation import unsqueeze +#from ...tensor import clip +#from ...tensor import sum +#from ...tensor import sqrt +#from ...fluid.data_feeder import check_variable_and_dtype, check_dtype +#from ...fluid.framework import in_dygraph_mode, _varbase_creator from ...fluid.framework import in_dygraph_mode -from ...fluid import core, dygraph_utils -from ...fluid import core, layers -from ...fluid.data_feeder import check_variable_and_dtype +#from ...fluid import core, dygraph_utils +#from ...fluid import core, layers +#from ...fluid.data_feeder import check_variable_and_dtype from paddle import _C_ops __all__ = [] @@ -42,26 +42,29 @@ def fused_multihead_attention(x, qkv_weight, - out_linear_weight, + linear_weight, pre_layer_norm=False, + pre_ln_scale=None, + pre_ln_bias=None, ln_scale=None, ln_bias=None, - ln_2_scale=None, - ln_2_bias=None, - epsilon=1e-05, + pre_ln_epsilon=1e-05, qkv_bias=None, - out_linear_bias=None, - src_mask=None, - dropout=0., - attn_dropout=0., - ln2_epsilon=1e-05, + linear_bias=None, + attn_mask=None, + dropout_rate=0.5, + attn_dropout_rate=0.5, + ln_epsilon=1e-05, name=None): r""" """ if in_dygraph_mode(): - ln_mean, ln_variance, ln_out, qkv_out, qkv_bias_out, transpose_out_2, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, src_mask_out, fmha_out, out_linear_out, dropout_mask_out, ln2_mean_out, ln2_var_out, bias_dropout_residual_out, final_out = _C_ops.fused_attention( - x, ln_scale, ln_bias, qkv_weight, qkv_bias, src_mask, - out_linear_weight, out_linear_bias, ln_2_scale, ln_2_bias, - 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, - 'dropout_prob', dropout, 'attn_dropout_prob', attn_dropout) + # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, + # attn_dropout_out, attn_mask_out, fmha_out, linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out + _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention( + x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask, + linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', + pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_prob', + dropout_rate, 'attn_dropout_prob', attn_dropout_rate, 'ln2epsilon', + ln_epsilon) return final_out From 2d9f7278c035b4fb0a1f37bb398eef061988b249 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 14 Oct 2021 05:54:11 +0000 Subject: [PATCH 24/29] Minors. --- .../paddle/nn/functional/fused_transformer.py | 26 +++---------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/python/paddle/nn/functional/fused_transformer.py b/python/paddle/nn/functional/fused_transformer.py index 05cfe1cf2c126e..e569dd02f8fe8e 100644 --- a/python/paddle/nn/functional/fused_transformer.py +++ b/python/paddle/nn/functional/fused_transformer.py @@ -12,29 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -#import warnings import paddle -#from ...fluid.framework import in_dygraph_mode, default_main_program -#from paddle.fluid.layer_helper import LayerHelper -#from paddle.fluid.layers.tensor import fill_constant -#from ...tensor import concat -#from ...tensor.creation import zeros -#from paddle.static import Variable -#from ...fluid.layers import core -#from ...fluid import dygraph_utils -#from ...fluid.layers import unfold # noqa: F401 -#from ...tensor.manipulation import squeeze -#from ...tensor.manipulation import unsqueeze -#from ...tensor import clip -#from ...tensor import sum -#from ...tensor import sqrt -#from ...fluid.data_feeder import check_variable_and_dtype, check_dtype -#from ...fluid.framework import in_dygraph_mode, _varbase_creator - from ...fluid.framework import in_dygraph_mode -#from ...fluid import core, dygraph_utils -#from ...fluid import core, layers -#from ...fluid.data_feeder import check_variable_and_dtype from paddle import _C_ops __all__ = [] @@ -59,8 +38,9 @@ def fused_multihead_attention(x, r""" """ if in_dygraph_mode(): - # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, qktv_out, softmax_out, attn_dropout_mask_out, - # attn_dropout_out, attn_mask_out, fmha_out, linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out + # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, + # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, + # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention( x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask, linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', From f35b3c7af3278eccfc251013fa2dc597041806f9 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 14 Oct 2021 06:04:18 +0000 Subject: [PATCH 25/29] Remove commits of tools/__pycache__/. --- .../static_mode_white_list.cpython-37.pyc | Bin 21041 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tools/__pycache__/static_mode_white_list.cpython-37.pyc diff --git a/tools/__pycache__/static_mode_white_list.cpython-37.pyc b/tools/__pycache__/static_mode_white_list.cpython-37.pyc deleted file mode 100644 index 04711804109d2b5acb2ddf6acefee91dfd28780b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 21041 zcmeI4b(}0$mB(MQND>GH4;CB}APMii5ZnR*f&@vhL5i9_)jh>@cXg_|M}oUUaCdii zUuWHS$@;SH&MqtL_sFfP?w;^j_TMG>{PJF1zvmvmw{y;kC!TPKe;)Ro3-WJ0^3b7g z>qq+Mo5vnHbWc8-Q;t7$NDj#{a;*G_Tuv@8SCA{padNzzAXkzr%ZYLoxvHEbSCgyD zHRPIdExEQ_M^2XO%Jt;>as#=coFX@p8_P}PrgAg6x!gi-DYueS=xAJ%L_wo<&kMd9Q&+;$w zuks!FH~Dw@uKb7mr~H@vw|wuCW8@Gx1{@211Y8bW9$W!j5gZ4O2Pc3lfh&U(!BxOj z!Aan1;OgKS;F{oC;M(9i;AC)Ja6NE+a075da0<8)xG}g1xGA_9xH-55xFxt1I2D`* zZVhe&ZVOHaXMi)o?Z9Dh1RMotf!l+#!5zRI!JWXJ!Ck;z!QH^!!9Bn|!M(sa;NIXq z;9PKDa6fQ=a2|L7cp!KXcrZ90JOn%xJPbS>TmUWvj{uJZj{=Vdj{zItvEXsw@!$y{ z15X4w*aQP0pa4TK0$ZR2PXgOu3?^U-Dlh{Vff_Vm4i;bs?1DY84=x5z22TM`1y2J{ z2hRY{1kVD`2G0S{1%VfcJv;f%k(CfDeKXfe(X^fRBQYfscbvfKP%? zflq_afX{-@fzN|4fG>hCfiHuvfUknDfvmGB^=j1zZ)J1g-|I4z2;N39bdM z4Xy)DR$IcK*9F(Zo@k%zgBySwVqJ8cQ^1YX%JAon!A)$moAS4tft!O{fLnrFfm6Y0 z;MU+a;I`m&a0WON+zuQDN5D~V7Pvh)8{7fh5!?yf8QcZj72FNn9oz%l6Wj}&1MUs( z1I`8a1@{B@2j_tYfCqvHfd_;0!9&19!Nb7A!3E$#@Cfio@F?(T@EEWG9t$1^9uJ-X zGVnx@gH1310tzq$Bd`TZ@FcJe#$W=bpaL^+5vV}}=3oJKz%JMW``}{mWbhR5RPZ$L zbnp!DOzbeDDJBLhvH+V(=31Qt&eHa&QTF1$ZTR6?ipx4R|eh9e6!> z19&5N6L>Rt3wSGd8+bc-2Y4rV7kD>#4|p$lA9z3b0QeyI5cn|o2>2-Y82C8&1o$NQ z6!%E-vb- z$;Rb&aipsmGh<%soaHnd)I+Vgn)Ml1lGR01*u>Y2qH23F%jWs!xX^~zT`%1n9zI*E zZm?dpb0kZ(CtB?+$8XcdSGCWEP0ov23w*k!xM)#K2T;TauyJ*7)+^mvEt^%I}h*w+$gg?unCck+2zO|$7V8`Zg_rZ}y| zc%ExhCvPsf)@#9)eKg4@`gOzp*SwsRFiJL@9X-N-&OWQ_J<0d7d{zciGMpCB-}k5A zyZY$_;&aB*V9=dWrl)dPR%^6-)TAW{1x;z&;ovJNrz0Kmda*)fNVZig(Tb_iWWSwV zCbpStcvF|8QC-@8qi{SPj4NbjobMO)#!bEDWCrsWgY692E9a~iR^d>c!DBDe!B$mg zqhear`JBJ8q$_#vI-ib;jTpZkWfP<*$P!eqM@?0m48>q|o$uQ|vfWxIWH1-=ylDz= z*?yO!sn)wVP_N%R`FLTyLTMl~Q={0pnHvdWX^Vk&4L0GUftc)SrrbrvjeG+;6(mBH zz-GGtG9Q;08$ShMH}k>vZeC05aWikqK~}?S(|l5d^lj$z!B#e{YMcIAHt-&7F$S5{Mn>JGYehJAVDewW4VNBoqYUz<2wCT^ zwo>mtpbgX6K}5+`H7=TM(=0LI)?&Jy6>ytQ;xaz4GCrLlZ5x~W?a8}ZNt<(-9#rFT z0hNZcsZtE_&c}_8=;t}8CJ?Pa9-yTizO#>(mt#1UK4`7UM(n7XO51ZPrYTt9=C~Sc zr$trlZEwyQO!E<9?li`N>>Ng`n&i8tjiNas+_0>q7P|J>CU5WJsO{Zcb^ZCi?EXjB z`%}c$yZ(Hy_IC!*$(AaWwT!8|^|czm8?QH7G6!**jF&ZA<$YyiO#=Eii5jbsXo+Y3goVS_k?SlF^+m4o8#0Y*g;|M3Tl`xwBy$`%(qt z1m9UPpK%iNL57|hP1E;c(kyzaZ?K~uAq0}o^Psr?bzGSNLFJ593*+8|==UnOi-Mii zNMyJ%y*h-VtM%1MvAGy2tGTX)Qg9jmMrUtb^D7A#Gh=Eu0Cc&T3F-kgna=5!DJ`%V zqG8STQtrrh4lDIt%A~8CHFZ^*UFEbwJ7OuZs{062iQ-ZXM9*8t)EERDFBUv0CRME( zzD7%ZGuqB%QZxhfT{SI28A>tF2V3pO4wffm zR0{pbMmRtcUWb(VC@UtLg>2l|my1tSmzxXNbWNkUB~Y))TVTf#u6Z%q4;|T7;Vw04 z1yxTkO0_$kRg1bANpPOb_`6aocB2|>VLjYn$UfR>wv|m5<9V5JKxQjlKxBWamU6=- zny-8kwzHWdEb5G@vQEdlKUK4uO5^B*bG#0BaBjA5wwQ0+!st<`jzc)dg}DGQ#)?vO z+qKfn+7~L}O^C)ygHgT6rig}n@hg?)XK3w#W8)^9+on7BGh68nyDjRPjVoCXoM3kg z)vvX#6biPqmmNK*Eb~_uEF{Wf#b@E3LL>NEuGV zH0;O5ZG6vZ^tzZY>QqTP7-y%b)p=G8=k~CT@pO~!opeR+#kuW>Qd%u_^d|1E75hdi zKayb%;d;uAj!qM+X!Rk*C)lco=5@Pd*uFYFj{VJN#C-d+BBX9usN<}Gc393_gM+ZB zUb&b~ih2aieED9era4xjfqg#h4!Bp0GSC3h#38OUOEpvO!W|_VqR-3GqFOYW#Y7#mr4}8DjD4Lujg>i5HQ9&Ly0kuUG`Q@+{oY|QFUP_CT|pCm<`4%x5rbtE zyH&jn7hvbrri6$L8}omHwDa0T+_`GF;8>E(MFTB@25v5giz$(NHO0?TV}2yrx3%)l zqdU0g(aY%BcOJtF7^;ZotRkH48Bs;rq*f)vPiUhx9v@q$x)?9mK8Xs;Iq8bM%}t!= z=`ES6HBU)(ISKabc2b9-$z+pg?;`U~lolZ-sV)>5B9ubeV3I|n2@W+%baH=_0kteO zY)8_l|Rr&qfhZNM6A51Jq$f0o*~T8fcrqsfWLJ?X^lW@8AihH)7y@^fzL z(+H8uAsWmg8|P~uO?N6A*^&AM@nw1Y;K%{CFqX5j&p%U@fY zf~#H&C;eIr3%}f1>QWk=iyO|;NM&%W-ky$vl3O*O1pW5byb)OD<@U2Ep&)tXfe z1%vr471QJs{NF4wd5Cqd%0l2Y2+ndFTDn>r)xENeq=GB#U<_AL`k_Lvt9Hp#4WeH@ z?iNPFa}ZkQ;<|;;Xklir*5DEW zGff+5F3hqgA;-O9A)*)zgHXNlVNyNq{Ow4ie0o43E|)WDZx9W5P+I0W5=0V8yLnw| z(i$KVl1qoWzNl79gqrP*Gn;Rk!+R^!3hNb_X2Rf>#2+mM+ii8e;i7@8;VAS*s*7@p zmp!lq`1($k)~?yr&bDrxvAQaVLQ)Tj+bzk!rbT6B7W46Ze|>O^)HMc!F;;it*14Px zYG0YVASF4SKAH|0cUs|hL!!;aP(Pb-HE$B3wSlF>CnSr;Sy_eo zYV&rO#9-8^NiI>EIjWu?C{v+`o^lF};;^E!sfKg)dyOPZVR{%(Mj`%bX2l>p5OWbU zOP%t6?nA5=tF*fFSx)iZ%-n?Zm8>R5I+RnBa8wITO^~_6%d@0Z@O8Z&r4G@Drbk^} zmg`N-oVj)=Nk^DN4Bb;nBz`Gz4Y57efI*XEdOEH{wrk7=U%rL$;dZxvq^6a#W>Qt2 z%lA&>9DCOeYq_MhUtFE2p762>(-IBP+>J=Dx!sy{^{!TW^#qepcWp$qxo)PNGGme+ zQ%&J$P)v?{i=5jSo{8~17uqFO{Tw4jx|9#~v(1e%WjaWa37!$G3Di|R$^FcVi%yabZP4ap~ zK#AGMx(CxaBjPxy;o=Ml6JmLk5rFB?Zo1;zr*6B|G+Z*##dgbX^J|{iQr)p%G51xld_>;C*R~fh+0h#k~REvUYQYp_e5{ zTt(#P0X+&c*j2Az_40Q~!hDD1SS;``Wh@pq`ktm5UBD_4hod%`o|V672e_@GO-#=1!hh0+ynxMk{paUH@LrQC*S z+T0n#=Rc?k+2X}C{E_u~6I*uH=ykDP%&Q0=F=uDJG2T-+5#CKNtl$~-{zKob_nAx? zdE@lU7?GuMo>0YxPAQ-bcPrKhvr#<3hX+zlH;yik+cdelVx&~R-%{P(($OacsfzY& zd^cS@)~>>Kxsr(N`E%hnF`zY!+XPp7>OJ6(m1DgRaE_dndmLi3g_AZeg6n71cz=YG zpxO0y_p9@hfyj=J0QD1d{z288g-4Je?2>q84ZBSRUSV>jpcQPY9|Q^O6D zHF~&G(>#A^L9-_OJsIHwNDc8l4c(dX&(7aSjnGh?-C4RyNG&QdM@4<9Ylw7n7gY3n zJY+{w_A6xzXRwgQMsd;XsyS7x+(fT-$ebcd8;AU=??r_-BWV4)F;1|AISEnAzxY#jnLC=RS z?&@dyUcd}(I266A1;%l9zWBZBH0kT$^sO!_CCcrVx1k%Y%L ziz2qKNr^arKWr^9569^oNejCLWXaB>T8Lt@T!vGb5~4ci2>#wVJW=G<)34=HQSY|E z#eVg!IKDfE^#h9H73bN#Nz=Yi#l?%=SwVUW0r-g&q_LwrY*oCq?1wZ53nALYEP(xh1qqR4wx8%A|rx2<}XOv}#P7#F56` zqL#RrV;{|U2lcqMK>hBt&HYoG$BK9dK-{gFB+Z!k`S-mK|yHoS%i*29z_`&#R$40(jFJ0jxr{oZ<2=y}4)*FM?Fp68^z8An%Wya%&>*RGACeb8dT09{kX>$c5RI)y<8y z=7ruUk`J2`s3I(he^MafG?q)N>P|6i;MT z717HL)Iy^q=dqZ3hI~fuU3I5=BD8i=qn6d>ue#6DDD7H!wjlN#O$lhST8dvGji}RS4>>O?Nj%$$ad33iZFJmNMg}!m6GG2Skw=~)hs#}@YVq31sPR4kvOZ>ZML33`xyHeN2ls4PEL+<+ZuKmq|t^6LP8^A*rYqMGA zyae>Si>`sIK5w9TM^4uUKh@aJ%KNXNxP(7q0?@Ve%*{b9R~oveIj~_jntK~dLG zH&@WdxA7paKgQsFL#tYRjnr#&30ucws5HDiU2^8&S^=M*|LRRb_)o2Tws#Qj1Zbi> zOsyJmI|?4OxXJAO4xn>wUCY~>^sdkiQ>+CmddI|8%)3;#;tV@2qTd0->7(tHF70qy zw=z{boYr;2bu3SDcS2nQcT{!rOAS_L?nWn>*L8MU^`4n^Req)`($eed{J&biGVS^x z`XJ@|O``WCUH$qbTE8+$D`Z{2R+?6tx~5!ljk<3Bv}L6lN4RU@8(MX`x_bAQSL+WL zB7{MezzQR@!lc5Y^(%YuGwb#5W_SLIDtn^T)j1>e>sIDnZ?H14FKjge=w@&`qZg1} zne$x;ZddFJjpCJQ+!jaO66HLyR1R2KdUUb3&`PmxqSYGqzjQZkmuN<1)KX8;`2r(5 zzB1cV-OA#9ztOK-hOZp`i$hEHpsVo1th)#9U6u22@(Q|gSCpxIW&ZF=xUcLwL_kj2 zN@JIXPN}TybG3S9DEFI%GPP@gDDRqkm=eXEDTqW> z*Y1q|u+gqtW%y*hqjuvI<9gODdUazuDWXk!(9%_Vyyqqd6OQ3=ZHTnh`ADn%b+Iu0MR$KlN^l+;1p1dz#9fQ(dO_FFHF01_$0PZ8(iyY;0{8$t|HG z+k^L&*57lxCRp3u_OF*CT`3C_*Glsdw=j`}auPAGm7R5My|O-?OEtAuT*@ zQ4?wD*PzxlV!HMBHFl~I7dc(${El(S+e)6zJTu`lCjr8z_Y}H4_Yk}O{%gJ2c){z4 zxct;%I~KZ3@5u9r_T&K)xD?n0Ui<1Powz?Jv;zoEsrcn*L(Zr3d3>~Bi6C@G(D+$p zxa?luzEI{WZa2bu=Lcxgx73~M%{rY#v{P2R%w5v3b>AbipN=V;j#I*Bx}(8t#3!?I zR%o^70jjKhVP?Xe1{T2I-Nv(j>J znX*_~;&!UNiRtz3IqplneJ)qU7?t5O=?@R+#ro~yqc>JOjgIC% zeycjIohyhav*f51`_@9{t;A#B-t!)^}v3oN5N86~y<(KHrdC7$o)H5osRV40Ch zL*cZ)5NdUYKY5B@0e8<+{h6TsNJM{lXTMJ9^q^DNYV`*v4Cd|Q{o3-aPc=LLv|}zk z`GO1Yb>VsU&K~@b`<$Cyct&tA}TpDQiz&M=C1aP+5V-+-IYRKGv56q{ojsV t!r;dq^TU7tKU3gx$Ncc$A5!24p@2$@woiX{e@i7^|AUPA!?uY6{{!j7SQ-ET From 1433ba645649bbfe83be3790136342cf0afe12cf Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 14 Oct 2021 06:24:27 +0000 Subject: [PATCH 26/29] Minors. --- .../operators/fused/fused_attention_op.cc | 22 ++++++------ .../operators/fused/fused_attention_op.cu | 8 ++--- .../operators/fused/fused_dropout_helper.h | 2 +- .../unittests/test_fused_attention_op.py | 2 +- python/paddle/nn/functional/__init__.py | 3 +- .../paddle/nn/functional/fused_transformer.py | 36 +++++++++---------- 6 files changed, 37 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 9247ca0334a44a..a286c39f7f8db5 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -215,13 +215,13 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { }); // for dropout in fmha. - AddAttr("attn_dropout_prob", "Probability of setting units to zero.") + AddAttr("attn_dropout_rate", "Probability of setting units to zero.") .SetDefault(.5f) .AddCustomChecker([](const float &drop_p) { PADDLE_ENFORCE_EQ( drop_p >= 0.0f && drop_p <= 1.0f, true, platform::errors::InvalidArgument( - "'attn_dropout_prob' must be between 0.0 and 1.0.")); + "'attn_dropout_rate' must be between 0.0 and 1.0.")); }); AddAttr("attn_dropout_is_test", "(bool, default false) Set to true for inference only, false " @@ -240,14 +240,14 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "[\"downgrade_in_infer\"|\"upscale_in_train\"]" "There are two kinds of ways to implement dropout" "(the mask below is a tensor have the same shape with input" - "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)" + "the value of mask is 0 or 1, the ratio of 0 is dropout_rate)" "1. downgrade_in_infer(default), downgrade the outcome at inference " "time" " train: out = input * mask" - " inference: out = input * (1.0 - dropout_prob)" + " inference: out = input * (1.0 - dropout_rate)" "2. upscale_in_train, upscale the outcome at training time, do nothing " "in inference" - " train: out = input * mask / ( 1.0 - dropout_prob )" + " train: out = input * mask / ( 1.0 - dropout_rate )" " inference: out = input" " dropout op can be removed from the program. the program will be " "efficient") @@ -260,12 +260,12 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "upscale_in_train")); }); - AddAttr("dropout_prob", "Probability of setting units to zero.") + AddAttr("dropout_rate", "Probability of setting units to zero.") .SetDefault(.5f) .AddCustomChecker([](const float &drop_p) { PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true, platform::errors::InvalidArgument( - "'dropout_prob' must be between 0.0 and 1.0.")); + "'dropout_rate' must be between 0.0 and 1.0.")); }); AddAttr("dropout_is_test", @@ -292,16 +292,16 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "dropout_implementation can only be downgrade_in_infer or " "upscale_in_train")); }); - AddAttr("ln2epsilon", + AddAttr("ln_epsilon", "Constant for numerical stability [default 1e-5].") .SetDefault(1e-5) - .AddCustomChecker([](const float &ln2epsilon) { - PADDLE_ENFORCE_EQ(ln2epsilon >= 0.0f && ln2epsilon <= 0.001f, true, + .AddCustomChecker([](const float &ln_epsilon) { + PADDLE_ENFORCE_EQ(ln_epsilon >= 0.0f && ln_epsilon <= 0.001f, true, platform::errors::InvalidArgument( "'epsilon' of the second LayerNorm in Fused " "attention op should be between" "0.0 and 0.001, But received [%s].", - ln2epsilon)); + ln_epsilon)); }); AddComment(R"DOC( diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index a8db5b25729e73..18a42b5c2cee29 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -75,9 +75,9 @@ class FusedAttentionOpKernel : public framework::OpKernel { ctx.Output("BiasDropoutResidualOut"); auto *ln_mean_2 = ctx.Output("Ln2Mean"); auto *ln_var_2 = ctx.Output("Ln2Variance"); - const float ln2epsilon = ctx.Attr("ln2epsilon"); + const float ln_epsilon = ctx.Attr("ln_epsilon"); - float attn_dropout_prob = ctx.Attr("attn_dropout_prob"); + float attn_dropout_rate = ctx.Attr("attn_dropout_rate"); bool is_test_1 = ctx.Attr("attn_dropout_is_test"); auto &dropout_implementation_1 = ctx.Attr("attn_dropout_implementation"); @@ -156,7 +156,7 @@ class FusedAttentionOpKernel : public framework::OpKernel { bsz_seq, output_size, input_size, true); AttnDropoutParam attn_dropout_param( - is_test_1, dropout_implementation_1, attn_dropout_prob, + is_test_1, dropout_implementation_1, attn_dropout_rate, is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); auto fmha_ref_compute = FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head, @@ -170,7 +170,7 @@ class FusedAttentionOpKernel : public framework::OpKernel { DropoutParam dropout_param2(ctx, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, - ln2epsilon); + ln_epsilon); if (pre_layer_norm) { layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h index fcfa405a52f9b1..33fde64164d129 100644 --- a/paddle/fluid/operators/fused/fused_dropout_helper.h +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -66,7 +66,7 @@ struct DropoutParam { } else { pre_fix = pre_fix + "_"; } - dropout_prob = context.Attr(pre_fix + "prob"); + dropout_prob = context.Attr(pre_fix + "rate"); auto& dropout_implementation = context.Attr(pre_fix + "implementation"); is_upscale_in_train = (dropout_implementation == "upscale_in_train"); diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 1d0bd46e7c7988..a5578d71c5cd06 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -190,7 +190,7 @@ def GetFusedAttentionOut(self): if attn_mask is not None: attn_mask = _convert_attention_mask(attn_mask, x.dtype) - final_out = F.fused_multihead_attention( + final_out = F.fused_multi_head_attention( x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, out_linear_bias, attn_mask, self.dropout_prob, diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 830cf5dbc489b6..9ab1cf609d7234 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -60,7 +60,7 @@ from .conv import conv1d # noqa: F401 from .conv import conv1d_transpose # noqa: F401 from .common import linear # noqa: F401 -from .fused_transformer import fused_multihead_attention # noqa: F401 +from .fused_transformer import fused_multi_head_attention # noqa: F401 from .conv import conv2d # noqa: F401 from .conv import conv2d_transpose # noqa: F401 from .conv import conv3d # noqa: F401 @@ -208,4 +208,5 @@ 'layer_norm', 'instance_norm', 'class_center_sample', + 'fused_multi_head_attention', ] diff --git a/python/paddle/nn/functional/fused_transformer.py b/python/paddle/nn/functional/fused_transformer.py index e569dd02f8fe8e..401c2b8783af87 100644 --- a/python/paddle/nn/functional/fused_transformer.py +++ b/python/paddle/nn/functional/fused_transformer.py @@ -19,22 +19,22 @@ __all__ = [] -def fused_multihead_attention(x, - qkv_weight, - linear_weight, - pre_layer_norm=False, - pre_ln_scale=None, - pre_ln_bias=None, - ln_scale=None, - ln_bias=None, - pre_ln_epsilon=1e-05, - qkv_bias=None, - linear_bias=None, - attn_mask=None, - dropout_rate=0.5, - attn_dropout_rate=0.5, - ln_epsilon=1e-05, - name=None): +def fused_multi_head_attention(x, + qkv_weight, + linear_weight, + pre_layer_norm=False, + pre_ln_scale=None, + pre_ln_bias=None, + ln_scale=None, + ln_bias=None, + pre_ln_epsilon=1e-05, + qkv_bias=None, + linear_bias=None, + attn_mask=None, + dropout_rate=0.5, + attn_dropout_rate=0.5, + ln_epsilon=1e-05, + name=None): r""" """ if in_dygraph_mode(): @@ -44,7 +44,7 @@ def fused_multihead_attention(x, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention( x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask, linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', - pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_prob', - dropout_rate, 'attn_dropout_prob', attn_dropout_rate, 'ln2epsilon', + pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_rate', + dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon', ln_epsilon) return final_out From cf7be139f3f15055f6b4a836366ed0b1c73b7619 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 14 Oct 2021 08:48:08 +0000 Subject: [PATCH 27/29] Add english doc for functional.fused_multi_head_attention --- .../paddle/nn/functional/fused_transformer.py | 78 ++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/fused_transformer.py b/python/paddle/nn/functional/fused_transformer.py index 401c2b8783af87..605ef54a8c0457 100644 --- a/python/paddle/nn/functional/fused_transformer.py +++ b/python/paddle/nn/functional/fused_transformer.py @@ -35,7 +35,83 @@ def fused_multi_head_attention(x, attn_dropout_rate=0.5, ln_epsilon=1e-05, name=None): - r""" + """ + Attention mapps queries and a set of key-value pairs to outputs, and + Multi-Head Attention performs multiple parallel attention to jointly attending + to information from different representation subspaces. This API only + support self_attention. The pseudo code is as follows: + if pre_layer_norm: + out = layer_norm(x); + out = linear(out) + qkv)bias + else: + out = linear(x) + bias; + out = transpose(out, perm=[2, 0, 3, 1, 4]); + # extract q, k and v from out. + q = out[0:1,::] + k = out[1:2,::] + v = out[2:3,::] + out = q * k^t; + out = attn_mask + out; + out = softmax(out); + out = dropout(out); + out = out * v; + out = transpose(out, perm=[0, 2, 1, 3]); + out = out_linear(out); + out = layer_norm(x + dropout(linear_bias + out)); + + Parameters: + x (Tensor): The input tensor of fused_multi_head_attention. The shape is + `[batch\_size, sequence\_len, embed\_dim]`. + qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`. + linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`. + pre_layer_norm (bool, optional): whether it is pre_layer_norm or post_layer_norm architecture. + Default False. + pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None. + pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None. + ln_scale (Tensor, optional): The weight tensor of layernorm. Default None. + ln_bias (Tensor, optional): The bias tensor of layernorm. Default None. + pre_ln_epsilon (float, optional): Small float value added to denominator of the pre layer_norm + to avoid dividing by zero. Default is 1e-5. + qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`. + Default None. + linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None. + attn_mask (Tensor, optional): + dropout_rate (float, optional): The dropout probability used on attention + weights to drop some attention targets for the dropout after attention. + 0 for no dropout. Default 0. + attn_dropout_rate (float, optional): The dropout probability used on attention + weights to drop some attention targets for the dropout in attention. + 0 for no dropout. Default 0. + ln_epsilon (float, optional): Small float value added to denominator of layer_norm + to avoid dividing by zero. Default is 1e-5. + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + + # input: [batch_size, seq_len, embed_dim] + x = paddle.rand(shape=(2, 4, 128), dtype="float32") + # qkv_weight: [3, num_head, dim_head, dim_embed] + qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") + # qkv_bias: [3, num_head, dim_head] + qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") + # linear_weight: [embed_dim, embed_dim] + linear_weight = paddle.rand(shape=(128, 128), dtype="float32") + # linear_bias: [embed_dim] + linear_bias = paddle.rand(shape=[128], dtype="float32") + # self attention mask: [batch_size, num_heads, seq_len, seq_len] + attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32") + + # output: [batch_size, seq_len, embed_dim] + output = F.fused_multi_head_attention( + x, qkv_weight, linear_weight, False, + None, None, None, None, 1e-5, qkv_bias, + linear_bias, attn_mask) + # [2, 4, 128] + print(output) """ if in_dygraph_mode(): # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, From 10687a6512722d4f02e44db9f180c72ebe4deddf Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 21 Oct 2021 06:01:22 +0000 Subject: [PATCH 28/29] Add "#require gpu" for sample code in english doc. --- python/paddle/nn/functional/fused_transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/fused_transformer.py b/python/paddle/nn/functional/fused_transformer.py index 605ef54a8c0457..078b7a69513345 100644 --- a/python/paddle/nn/functional/fused_transformer.py +++ b/python/paddle/nn/functional/fused_transformer.py @@ -88,7 +88,8 @@ def fused_multi_head_attention(x, Examples: .. code-block:: python - + + # required: gpu import paddle import paddle.nn.functional as F From 0f93775d57e33c68ae1ed09040b02c553f601700 Mon Sep 17 00:00:00 2001 From: limin2021 <1121099234@qq.com> Date: Thu, 21 Oct 2021 11:08:15 +0000 Subject: [PATCH 29/29] Improve format of sample code. --- python/paddle/nn/functional/fused_transformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/nn/functional/fused_transformer.py b/python/paddle/nn/functional/fused_transformer.py index 078b7a69513345..565ef223a96cbb 100644 --- a/python/paddle/nn/functional/fused_transformer.py +++ b/python/paddle/nn/functional/fused_transformer.py @@ -44,7 +44,7 @@ def fused_multi_head_attention(x, out = layer_norm(x); out = linear(out) + qkv)bias else: - out = linear(x) + bias; + out = linear(x) + bias; out = transpose(out, perm=[2, 0, 3, 1, 4]); # extract q, k and v from out. q = out[0:1,::] @@ -56,8 +56,8 @@ def fused_multi_head_attention(x, out = dropout(out); out = out * v; out = transpose(out, perm=[0, 2, 1, 3]); - out = out_linear(out); - out = layer_norm(x + dropout(linear_bias + out)); + out = out_linear(out); + out = layer_norm(x + dropout(linear_bias + out)); Parameters: x (Tensor): The input tensor of fused_multi_head_attention. The shape is @@ -112,7 +112,7 @@ def fused_multi_head_attention(x, None, None, None, None, 1e-5, qkv_bias, linear_bias, attn_mask) # [2, 4, 128] - print(output) + print(output.shape) """ if in_dygraph_mode(): # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,