Commit 0819d0e

Fused attention op forward (#35905)
Purpose: the goal of this PR is to improve the compute performance of the attention module. To reduce framework-level op scheduling overhead, the attention module is implemented by hand at the C++ level and exposed as a single fused attention op. To reduce memory-access overhead, two optimizations are applied: (1) when computing q, k and v, the input X is shared, so the GEMM, transpose and bias-add there drop from three calls to one; (2) kernel fusion is used so that data is passed between the fused CUDA kernels through registers.
1 parent: 6840cf5 · commit: 0819d0e
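The first optimization is easiest to see on the QKV projection. Below is a minimal CPU sketch, not code from this PR: it contrasts three separate projections of the shared input X with a single GEMM over the row-concatenated [Wq; Wk; Wv] weight, which is the access pattern implied by the fused op's QKVW layout ([3, num_head, dim_head, dim_embed]). All sizes and names in the sketch are made up for illustration.

// Sketch only: shows why one GEMM over concatenated weights replaces three.
#include <cstdio>
#include <vector>

// C[m x n] = A[m x k] * B^T, with B stored as [n x k] (mirrors transB=true).
void MatMulBT(const std::vector<float>& A, const std::vector<float>& B,
              std::vector<float>* C, int m, int n, int k) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[j * k + p];
      (*C)[i * n + j] = acc;
    }
}

int main() {
  const int bsz_seq = 4;    // batch_size * seq_len
  const int dim_embed = 8;  // width of the shared input X
  const int hidden = 8;     // num_head * dim_head
  std::vector<float> x(bsz_seq * dim_embed, 1.f);

  // Unfused: three weights, three GEMMs, X read from memory three times.
  std::vector<float> wq(hidden * dim_embed, .1f), wk(wq), wv(wq);
  std::vector<float> q(bsz_seq * hidden), k(q), v(q);
  MatMulBT(x, wq, &q, bsz_seq, hidden, dim_embed);
  MatMulBT(x, wk, &k, bsz_seq, hidden, dim_embed);
  MatMulBT(x, wv, &v, bsz_seq, hidden, dim_embed);

  // Fused: one weight of shape [3 * hidden, dim_embed], one GEMM, X read once.
  std::vector<float> wqkv;
  wqkv.insert(wqkv.end(), wq.begin(), wq.end());
  wqkv.insert(wqkv.end(), wk.begin(), wk.end());
  wqkv.insert(wqkv.end(), wv.begin(), wv.end());
  std::vector<float> qkv(bsz_seq * 3 * hidden);
  MatMulBT(x, wqkv, &qkv, bsz_seq, 3 * hidden, dim_embed);

  // Column 0 of the fused output matches q[0] from the unfused path.
  printf("q[0]=%f fused_q[0]=%f\n", q[0], qkv[0]);
  return 0;
}

The actual op additionally folds the bias-add and the subsequent transpose into the same fused path; this sketch only illustrates the single-GEMM part.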

12 files changed: +1261 −2 lines changed

cmake/operators.cmake

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ function(op_library TARGET)
       "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
       "sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
       "skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
-      "fused_bn_add_activation_op")
+      "fused_bn_add_activation_op" "fused_attention_op")
       if ("${TARGET}" STREQUAL "${manual_pybind_op}")
         set(pybind_flag 1)
       endif()
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <random>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"

namespace paddle {
namespace operators {

inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
                                    const framework::Tensor* seed,
                                    const bool is_fix_seed, const int seed_val,
                                    const int offset, uint64_t* seed_data,
                                    uint64_t* increment) {
  int device_id =
      BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
  auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

  if (seed && platform::is_gpu_place(seed->place())) {
    // Explicit seed tensor on the GPU: copy it to the CPU to read its value.
    framework::Tensor seed_cpu_tensor;
    TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
    *seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
    *increment = offset;
  } else if (seed && platform::is_cpu_place(seed->place())) {
    // Explicit seed tensor already on the CPU: read it directly.
    *seed_data = *(seed->data<int>());
    *increment = offset;
  } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
    // No explicit seed: use the default CUDA generator initialized from Python.
    auto seed_offset = gen_cuda->IncrementOffset(offset);
    *seed_data = seed_offset.first;
    *increment = seed_offset.second;
  } else {
    // Fall back to the fixed seed attribute or a fresh random seed.
    std::random_device rnd;
    *seed_data = is_fix_seed ? seed_val : rnd();
    *increment = offset;
  }
}

}  // namespace operators
}  // namespace paddle
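GetSeedDataAndIncrement needs a live CUDADeviceContext and the framework's generator, so it cannot be exercised outside Paddle as-is. The following is a minimal standalone sketch of the same selection order (explicit seed tensor first, then the Python-initialized generator, then the fixed-seed or random fallback); FakeGenerator and PickSeed are hypothetical stand-ins, not Paddle APIs.

// Standalone sketch of the seed/increment selection order used above.
#include <cstdint>
#include <random>
#include <utility>

struct FakeGenerator {  // stand-in for the default CUDA generator
  bool inited_from_python = true;
  uint64_t state = 0;
  std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset) {
    uint64_t cur = state;
    state += offset;
    return {/*seed=*/42, /*offset=*/cur + offset};
  }
};

void PickSeed(const int* seed_tensor, bool is_fix_seed, int seed_val,
              int offset, FakeGenerator* gen, uint64_t* seed_data,
              uint64_t* increment) {
  if (seed_tensor) {  // an explicit seed tensor always wins
    *seed_data = static_cast<uint64_t>(*seed_tensor);
    *increment = offset;
  } else if (gen->inited_from_python && !is_fix_seed) {
    auto seed_offset = gen->IncrementOffset(offset);
    *seed_data = seed_offset.first;
    *increment = seed_offset.second;
  } else {  // fixed seed attribute or a fresh random seed
    std::random_device rnd;
    *seed_data = is_fix_seed ? static_cast<uint64_t>(seed_val) : rnd();
    *increment = offset;
  }
}

int main() {
  FakeGenerator gen;
  uint64_t seed = 0, inc = 0;
  int explicit_seed = 7;
  PickSeed(&explicit_seed, /*is_fix_seed=*/false, /*seed_val=*/0,
           /*offset=*/16, &gen, &seed, &inc);
  return static_cast<int>(seed == 7 ? 0 : 1);
}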

paddle/fluid/operators/fused/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,7 @@ register_operators(EXCLUDES
     fusion_gru_op
     fusion_lstm_op
     fused_bn_add_activation_op
+    fused_attention_op
     fused_transformer_op)

 # fusion_gru_op does not have CUDA kernel
@@ -77,5 +78,8 @@ if (WITH_GPU OR WITH_ROCM)
     nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
     nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
     nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
+    # fused_attention_op
+    op_library(fused_attention_op)
+    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n")
   endif()
 endif()

paddle/fluid/operators/fused/fused_attention_op.cc

Lines changed: 336 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cuda_fp16.h>
#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"

#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/math/math_function.h"

#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class FusedAttentionOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    auto *input_x = ctx.Input<Tensor>("X");

    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *ln_scale = ctx.Input<Tensor>("LnScale");
    auto *ln_bias = ctx.Input<Tensor>("LnBias");
    auto *ln_mean = ctx.Output<Tensor>("LnMean");
    auto *ln_var = ctx.Output<Tensor>("LnVariance");
    auto *ln_out = ctx.Output<Tensor>("LnOut");

    // x: qkv's input [batch_size, seq_len, dim_embed]
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
    auto *qkv_weight = ctx.Input<Tensor>("QKVW");
    auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
    auto *qkv_out = ctx.Output<Tensor>("QKVOut");
    auto *qkv_bias_out = ctx.Output<Tensor>("QKVBiasOut");

    auto *src_mask = ctx.Input<Tensor>("SrcMask");
    auto *transpose_out_2 = ctx.Output<Tensor>("TransposeOut2");
    auto *qk_out = ctx.Output<Tensor>("QKOut");
    auto *qktv_out = ctx.Output<Tensor>("QKTVOut");
    auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
    auto *attn_dropout_mask_out = ctx.Output<Tensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Output<Tensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Output<Tensor>("SrcMaskOut");
    auto *fmha_out = ctx.Output<Tensor>("FMHAOut");

    auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
    auto *out_linear_out = ctx.Output<Tensor>("OutLinearOut");

    auto *ln_scale_2 = ctx.Input<Tensor>("Ln2Scale");
    auto *ln_bias_2 = ctx.Input<Tensor>("Ln2Bias");
    auto *dropout_mask_out = ctx.Output<Tensor>("DropoutMaskOut");
    auto *bias_dropout_residual_out =
        ctx.Output<Tensor>("BiasDropoutResidualOut");
    auto *ln_mean_2 = ctx.Output<Tensor>("Ln2Mean");
    auto *ln_var_2 = ctx.Output<Tensor>("Ln2Variance");
    const float ln_epsilon = ctx.Attr<float>("ln_epsilon");

    float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
    bool is_test_1 = ctx.Attr<bool>("attn_dropout_is_test");
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");

    // final output.
    auto *out = ctx.Output<Tensor>("Y");

    // get data ptr for qkv part.
    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    auto *x_data = input_x->data<T>();
    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
    auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
    auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
    auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
    auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());

    auto *qkv_weight_data = qkv_weight->data<T>();
    auto *qkv_bias_data = qkv_bias->data<T>();
    auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
    auto *qkv_bias_out_data = qkv_bias_out->mutable_data<T>(ctx.GetPlace());

    // get data ptr for FMHA.
    auto *transpose_out_2_data =
        transpose_out_2->mutable_data<T>(ctx.GetPlace());
    auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
    auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
    auto *src_mask_out_data = src_mask_out->mutable_data<T>(ctx.GetPlace());
    auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
    auto *attn_dropout_mask_out_data =
        attn_dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
    auto *attn_dropout_out_data =
        attn_dropout_out->mutable_data<T>(ctx.GetPlace());
    auto *fmha_out_data = fmha_out->mutable_data<T>(ctx.GetPlace());

    // get data ptr for out_linear.
    auto *out_linear_weight_data = out_linear_weight->data<T>();
    auto *out_linear_bias_data = out_linear_bias->data<T>();
    auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());

    // get data ptr for bias+dropout+residual+layernorm
    auto *ln_scale_2_data =
        (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
    auto *ln_bias_2_data =
        (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
    auto *dropout_mask_out_data =
        dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
    auto *bias_dropout_residual_out_data =
        bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
    auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
    auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
    auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];

    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
                                               epsilon, bsz_seq, dim_embed);
    // (transA, transB, compute_bias) = (false, true, true)
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, true,
                                     bsz_seq, output_size, input_size, true);

    AttnDropoutParam attn_dropout_param(
        is_test_1, dropout_implementation_1, attn_dropout_rate,
        is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
    auto fmha_ref_compute =
        FMHARef<T>(ctx.cuda_device_context(), batch_size, max_seq_len, num_head,
                   dim_head, attn_dropout_param);

    output_size = hidden_size;
    // (transA, transB, compute_bias) = (false, false, false)
    auto out_linear_compute =
        AttnMatMul<T>(ctx.cuda_device_context(), false, false, bsz_seq,
                      output_size, input_size, false);
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
        ln_epsilon);

    if (pre_layer_norm) {
      layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
                                        ln_out_data, ln_mean_data, ln_var_data);
      qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
                                 qkv_out_data, qkv_bias_out_data);
    } else {
      qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
                                 qkv_out_data, qkv_bias_out_data);
    }
    fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2,
                                    qk_out, src_mask_out, softmax_out,
                                    attn_dropout_mask_out, attn_dropout_out,
                                    qktv_out, fmha_out);
    // fmha_out: [batch_size, seq_len, num_head, head_dim]
    // weight:   [embed_dim, embed_dim]
    // out_linear_out: [batch_size, seq_len, embed_dim]
    out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
                                      nullptr, out_linear_out_data, nullptr);
    // output = layernorm(residual + dropout(input + bias))
    fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
        ctx.cuda_device_context(), out_linear_out_data, x_data,
        out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
        bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
        ln_mean_2_data, ln_var_2_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel<float>,
                        ops::FusedAttentionOpKernel<double>,
                        ops::FusedAttentionOpKernel<plat::float16>);
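To follow the shape bookkeeping at the start of Compute, here is a small standalone check with made-up sizes showing how bsz_seq, hidden_size and the fused QKV output width are derived from X ([batch_size, seq_len, dim_embed]) and QKVW ([3, num_head, dim_head, dim_embed]). The numbers are illustrative only; dim_embed is set equal to num_head * dim_head, matching the [embed_dim, embed_dim] out-linear weight noted in the kernel's comments.

// Standalone shape check mirroring the dimension setup in
// FusedAttentionOpKernel::Compute; all concrete sizes below are made up.
#include <cstdio>

int main() {
  // X: [batch_size, seq_len, dim_embed]
  const int batch_size = 2, max_seq_len = 4, dim_embed = 16;
  // QKVW: [3, num_head, dim_head, dim_embed]
  const int num_head = 2, dim_head = 8;

  const int bsz_seq = batch_size * max_seq_len;  // rows fed to the QKV GEMM
  const int hidden_size = num_head * dim_head;   // per-token attention width
  const int qkv_output_size = 3 * hidden_size;   // fused Q, K, V columns

  // fmha_out:       [batch_size, seq_len, num_head, dim_head]
  // out_linear W:   [embed_dim, embed_dim]  (hidden_size == dim_embed here)
  // final output Y: [batch_size, seq_len, dim_embed]
  printf("bsz_seq=%d hidden_size=%d qkv_output_size=%d\n",
         bsz_seq, hidden_size, qkv_output_size);
  printf("fmha_out elems=%d, Y elems=%d\n",
         batch_size * max_seq_len * num_head * dim_head,
         batch_size * max_seq_len * dim_embed);
  return 0;
}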
