Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
f5eee9f
Add fused_attention_op: add impl wrappers.
limin2021 Sep 22, 2021
e16e3b3
Add fused_attention_op: forward.
limin2021 Sep 22, 2021
42f0372
Add fused_attention_op: forward impl.
limin2021 Sep 22, 2021
c6aebef
Remove useless code.
limin2021 Sep 22, 2021
2c0ab6c
Remove useless code.
limin2021 Sep 22, 2021
ece3c08
Remove docs.
limin2021 Sep 22, 2021
b18b405
Minors.
limin2021 Sep 22, 2021
b939159
Minors.
limin2021 Sep 23, 2021
07fd753
Update test_fused_attention_op.py
limin2021 Sep 23, 2021
ef89a94
Merge branch 'PaddlePaddle:develop' into fused_attention_op_2_fw
limin2021 Sep 23, 2021
b44d882
Remove static construction of python api.
limin2021 Sep 23, 2021
ff3df46
Modifications accordding to reviews.
limin2021 Sep 23, 2021
8a4c2a8
Modifications accordding to Xreki's review.
limin2021 Sep 26, 2021
739d9ca
Modifications unittest/cmakefile.txt.
limin2021 Sep 27, 2021
1d9e125
Fetch new fused_dropout_helper.h from #35843.
limin2021 Sep 27, 2021
4dd4260
Remove include fused_attention_op.h.
limin2021 Sep 27, 2021
2e3f4f2
Polish names of variants.
limin2021 Sep 27, 2021
f17c444
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
limin2021 Sep 29, 2021
13d4ff3
Revert "Polish names of variants."
limin2021 Oct 13, 2021
300ec35
Revert "Modifications accordding to Xreki's review."
limin2021 Oct 13, 2021
7b28f7c
Move fused_multi_head_attention from common.py.
limin2021 Oct 13, 2021
30fef54
Modify copyright and names with number.
limin2021 Oct 14, 2021
766ef85
Remove HIP and use OpTest and remove print.
limin2021 Oct 14, 2021
0bc03a6
Minors.
limin2021 Oct 14, 2021
99e36f9
Polish functional.fused_attention_op.
limin2021 Oct 14, 2021
2d9f727
Minors.
limin2021 Oct 14, 2021
f35b3c7
Remove commits of tools/__pycache__/.
limin2021 Oct 14, 2021
1433ba6
Minors.
limin2021 Oct 14, 2021
cf7be13
Add english doc for functional.fused_multi_head_attention
limin2021 Oct 14, 2021
ae875ca
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
limin2021 Oct 14, 2021
10687a6
Add "#require gpu" for sample code in english doc.
limin2021 Oct 21, 2021
0f93775
Improve format of sample code.
limin2021 Oct 21, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ function(op_library TARGET)
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
"fused_bn_add_activation_op")
"fused_bn_add_activation_op" "fused_attention_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/operators/fused/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ register_operators(EXCLUDES
fusion_gru_op
fusion_lstm_op
fused_bn_add_activation_op
fused_attention_op
fused_transformer_op)

# fusion_gru_op does not have CUDA kernel
Expand Down Expand Up @@ -77,6 +78,9 @@ if (WITH_GPU OR WITH_ROCM)
nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
# fused_attention_op
op_library(fused_attention_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n")
endif()
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory)
Expand Down
344 changes: 344 additions & 0 deletions paddle/fluid/operators/fused/fused_attention_op.cc

Large diffs are not rendered by default.

227 changes: 227 additions & 0 deletions paddle/fluid/operators/fused/fused_attention_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cmake里面加了限制,不支持ROCM。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/cuda_device_function.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif

#include <cuda_fp16.h>
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/math/math_function.h"

#include "paddle/fluid/operators/fused/fused_attention_op.h"

#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class FusedAttentionOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
auto *input_x = ctx.Input<Tensor>("X");

const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
const float epsilon = ctx.Attr<float>("epsilon");
auto *ln_scale = ctx.Input<Tensor>("LnScale");
auto *ln_bias = ctx.Input<Tensor>("LnBias");
auto *ln_mean = ctx.Output<Tensor>("LnMean");
auto *ln_var = ctx.Output<Tensor>("LnVariance");
auto *ln_out = ctx.Output<Tensor>("LnOut");

// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto *qkv_weight = ctx.Input<Tensor>("QKVW");
auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
auto *qkv_out = ctx.Output<Tensor>("QKVOut");
auto *qkv_bias_out = ctx.Output<Tensor>("QKVBiasOut");

auto *src_mask = ctx.Input<Tensor>("SrcMask");
auto *transpose_out_2 = ctx.Output<Tensor>("TransposeOut2");
auto *qk_out = ctx.Output<Tensor>("QKOut");
auto *qktv_out = ctx.Output<Tensor>("QKTVOut");
auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
auto *attn_dropout_mask_out = ctx.Output<Tensor>("AttnDropoutMaskOut");
auto *attn_dropout_out = ctx.Output<Tensor>("AttnDropoutOut");
auto *src_mask_out = ctx.Output<Tensor>("SrcMaskOut");
auto *fmha_out = ctx.Output<Tensor>("FMHAOut");

auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
auto *out_linear_out = ctx.Output<Tensor>("OutLinearOut");

auto *ln_scale_2 = ctx.Input<Tensor>("Ln2Scale");
auto *ln_bias_2 = ctx.Input<Tensor>("Ln2Bias");
auto *dropout_mask_out = ctx.Output<Tensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Output<Tensor>("BiasDropoutResidualOut");
auto *ln_mean_2 = ctx.Output<Tensor>("Ln2Mean");
auto *ln_var_2 = ctx.Output<Tensor>("Ln2Variance");
const float ln2epsilon = ctx.Attr<float>("ln2epsilon");

float attn_dropout_prob = ctx.Attr<float>("attn_dropout_prob");
bool is_test_1 = ctx.Attr<bool>("is_test1");
auto &dropout_implementation_1 =
ctx.Attr<std::string>("dropout_implementation1");
bool is_upscale_in_train_1 =
(dropout_implementation_1 == "upscale_in_train");
auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
bool is_fix_seed_1 = ctx.Attr<bool>("fix_seed1");
int seed_val_1 = ctx.Attr<int>("seed1");

// final output.
auto *out = ctx.Output<Tensor>("Y");

// get data ptr for qkv part.
const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();

auto *x_data = input_x->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());

auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
auto *qkv_bias_out_data = qkv_bias_out->mutable_data<T>(ctx.GetPlace());

// get data ptr for FMHA.
auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个变量好像没有使用的地方?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto *transpose_out_2_data =
transpose_out_2->mutable_data<T>(ctx.GetPlace());
auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
auto *src_mask_out_data = src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
auto *attn_dropout_mask_out_data =
attn_dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
auto *attn_dropout_out_data =
attn_dropout_out->mutable_data<T>(ctx.GetPlace());
auto *fmha_out_data = fmha_out->mutable_data<T>(ctx.GetPlace());

// get data ptr for out_linear.
auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data = out_linear_bias->data<T>();
auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());

// get data ptr for bias+dropout+residual+layernorm
auto *ln_scale_2_data =
(ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
auto *ln_bias_2_data =
(ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
auto *dropout_mask_out_data =
dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());

int batch_size = input_x_dims[0];
int max_seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];

int num_head = qkv_w_dims[1];
int dim_head = qkv_w_dims[2];

int bsz_seq = batch_size * max_seq_len;
int hidden_size = num_head * dim_head;
int output_size = 3 * hidden_size;
int input_size = dim_embed;

bool transA = false;
bool transB = true;
bool compute_bias = true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这3个变量被用了2次,但是第2次用之前重新赋值了,我觉得没有必要定义。

Copy link
Contributor Author

@limin2021 limin2021 Sep 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed);
auto qkv_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias);

AttnDropoutParam attn_dropout_param(
is_test_1, dropout_implementation_1, attn_dropout_prob,
is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
auto fmha_ref_compute =
FMHARef<T>(ctx.cuda_device_context(), batch_size, max_seq_len, num_head,
dim_head, attn_dropout_param);

output_size = hidden_size;
transA = false;
transB = false;
compute_bias = false;
auto out_linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias);
DropoutParam dropout_param2(ctx, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
ln2epsilon);

if (pre_layer_norm) {
layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
ln_out_data, ln_mean_data, ln_var_data);
qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
qkv_out_data, qkv_bias_out_data);
} else {
qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
qkv_out_data, qkv_bias_out_data);
}
fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
nullptr, out_linear_out_data, nullptr);
// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data,
out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
ln_mean_2_data, ln_var_2_data);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel<float>,
ops::FusedAttentionOpKernel<double>,
ops::FusedAttentionOpKernel<plat::float16>);
16 changes: 16 additions & 0 deletions paddle/fluid/operators/fused/fused_attention_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

namespace paddle {
namespace operators {} // namespace operators
} // namespace paddle
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件可以不需要吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是需要的。pre-commit要求op.cc和op.cu必须include op.h,否则过不了。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

应该不是强制要求必须有.h的,multi_dot就只有multi_dot_op.cc。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Loading