Skip to content

Commit 5e23e64

Browse files
authored
[XPU] Add multi-head self/cross attention fused ops. (#10037)
1 parent 082b78e commit 5e23e64

19 files changed

+1578
-21
lines changed

lite/api/paddle_use_passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ USE_MIR_PASS(assign_value_calc_offline_pass);
8080
USE_MIR_PASS(__xpu__graph_dedup_pass);
8181
USE_MIR_PASS(__xpu__resnet_fuse_pass);
8282
USE_MIR_PASS(__xpu__gn_silu_fuse_pass);
83+
USE_MIR_PASS(__xpu__multihead_cross_attn_fuse_pass);
84+
USE_MIR_PASS(__xpu__multihead_self_attn_fuse_pass);
8385
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
8486
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
8587
USE_MIR_PASS(__xpu__fc_fuse_pass);

lite/core/optimizer/mir/fusion/__xpu__multihead_cross_attn_fuse_pass.cc

Lines changed: 435 additions & 0 deletions
Large diffs are not rendered by default.

lite/core/optimizer/mir/fusion/__xpu__multihead_self_attn_fuse_pass.cc

Lines changed: 427 additions & 0 deletions
Large diffs are not rendered by default.

lite/core/optimizer/optimizer.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
202202
"__xpu__mmdnn_fuse_pass",
203203
"__xpu__bigru_fuse_pass",
204204
"__xpu__roformer_relative_pos_fuse_pass",
205+
"__xpu__multihead_self_attn_fuse_pass",
206+
"__xpu__multihead_cross_attn_fuse_pass",
205207
"__xpu__quick_gelu_fuse_pass",
206208
"__xpu__gn_silu_fuse_pass",
207209
"__xpu__multi_encoder_fuse_pass",

lite/kernels/xpu/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ add_kernel(__xpu__bigru_compute_xpu XPU extra SRCS __xpu__bigru_compute.cc)
133133
add_kernel(__xpu__dynamic_lstm_compute_xpu XPU extra SRCS __xpu__dynamic_lstm_compute.cc)
134134
add_kernel(__xpu__multi_softmax_compute_xpu XPU extra SRCS __xpu__multi_softmax_compute.cc)
135135
add_kernel(__xpu__gn_silu_compute_xpu XPU extra SRCS __xpu__gn_silu_compute.cc)
136+
add_kernel(__xpu__multihead_self_attn_compute_xpu XPU extra SRCS __xpu__multihead_self_attn_compute.cc)
137+
add_kernel(__xpu__multihead_cross_attn_compute_xpu XPU extra SRCS __xpu__multihead_cross_attn_compute.cc)
138+
136139
if(XPU_WITH_XFT)
137140
add_kernel(fusion_decoding_compute_xpu XPU extra SRCS fusion_decoding_compute.cc)
138141
add_kernel(fusion_unified_decoding_compute_xpu XPU extra SRCS fusion_unified_decoding_compute.cc)
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/xpu/__xpu__multihead_cross_attn_compute.h"
16+
#include <vector>
17+
#include "lite/backends/xpu/xpu_header_sitter.h"
18+
#include "lite/core/op_registry.h"
19+
20+
namespace paddle {
21+
namespace lite {
22+
namespace kernels {
23+
namespace xpu {
24+
25+
template <typename T>
26+
static std::vector<const T*> prepare_weight(
27+
const std::vector<lite::Tensor*>& fc_weight) {
28+
std::vector<const T*> result;
29+
for (auto* weight : fc_weight) {
30+
result.push_back(reinterpret_cast<const T*>(weight->data<float>()));
31+
}
32+
return result;
33+
}
34+
35+
template <typename InType, PrecisionType PType>
36+
void XPUMhcaCompute<InType, PType>::PrepareWeightMax(
37+
const std::vector<lite::Tensor*>& weight_max,
38+
int max_ptr_len,
39+
std::vector<const float*>* max_xpu_ptrs) {
40+
int max_value_num = 0;
41+
for (auto max_tensor : weight_max) {
42+
max_value_num += max_tensor->numel();
43+
}
44+
VLOG(3) << "Total weight max value number: " << max_value_num;
45+
weight_max_guard_ =
46+
TargetWrapperXPU::MallocScratchPad(max_value_num * sizeof(float));
47+
float* weight_max_ptr = reinterpret_cast<float*>(weight_max_guard_->addr_);
48+
49+
int offset = 0;
50+
for (auto max_tensor : weight_max) {
51+
float* cur_weight_max_ptr = weight_max_ptr + offset;
52+
auto len = max_tensor->numel();
53+
VLOG(6) << "weight max value: " << max_tensor->data<float>()[0] << " "
54+
<< max_tensor->data<float>()[len - 1];
55+
std::vector<float> cpu_max(max_ptr_len, max_tensor->data<float>()[0]);
56+
lite::TargetWrapperXPU::MemcpySync(cur_weight_max_ptr,
57+
cpu_max.data(),
58+
sizeof(float) * max_ptr_len,
59+
IoDirection::HtoD);
60+
max_xpu_ptrs->push_back(cur_weight_max_ptr);
61+
offset += max_ptr_len;
62+
}
63+
}
64+
65+
template <typename InType, PrecisionType PType>
66+
void XPUMhcaCompute<InType, PType>::PrepareForRun() {
67+
auto& ctx = this->ctx_->template As<XPUContext>();
68+
auto& param = this->template Param<param_t>();
69+
// prepare bias
70+
for (auto* fc_bias : param.fc_bias) {
71+
arg_fc_bias_.push_back(fc_bias->template data<float>());
72+
}
73+
// prepare scale
74+
for (auto* ln_scale : param.ln_scale) {
75+
arg_ln_scale_.push_back(ln_scale->template data<float>());
76+
}
77+
// prepare ln_bias
78+
for (auto* ln_bias : param.ln_bias) {
79+
arg_ln_bias_.push_back(ln_bias->template data<float>());
80+
}
81+
arg_fc_weight_int16_ = prepare_weight<int16_t>(param.fc_weight);
82+
const int XPU_QUANT_SCALE_NUM = ctx.GetRawContext()->max_ptr_size();
83+
PrepareWeightMax(param.weight_max, XPU_QUANT_SCALE_NUM, &fc_weight_max_);
84+
}
85+
86+
template <typename InType, PrecisionType PType>
void XPUMhcaCompute<InType, PType>::Run() {
  // TODO(shenyijun): The compute of this op will be adapted to XFT interface
  // later on.
  // NOTE(review): Run() is currently a deliberate no-op -- the kernel is
  // registered but the fused xdnn call below is disabled, so Output is never
  // written. The commented code is the intended implementation.
  //
  // auto& param = this->template Param<param_t>();
  // auto& ctx = this->ctx_->template As<XPUContext>();
  // const InType* in = param.input->template data<InType>();
  // const InType* embedding = param.embedding->template data<InType>();
  // InType* out = param.output->template mutable_data<InType>(TARGET(kXPU));
  // int batch = static_cast<int>(param.input->dims()[0]);
  // int seqlen = static_cast<int>(param.input->dims()[1]);
  // int embedding_seq = static_cast<int>(param.embedding->dims()[1]);
  // int r = xdnn::unet_mhca_fusion<InType, int16_t, InType, int16_t>(
  //     ctx.GetRawContext(),
  //     in,
  //     embedding,
  //     *(XPUMhcaCompute::GetWeight<int16_t>()),
  //     out,
  //     arg_fc_bias_,
  //     arg_ln_scale_,
  //     arg_ln_bias_,
  //     fc_weight_max_,
  //     batch,
  //     param.head_num,
  //     param.size_per_head,
  //     seqlen,
  //     param.hidden_dim,
  //     embedding_seq,
  //     param.embedding_dim);
  // CHECK_EQ(r, 0);
}
118+
119+
} // namespace xpu
120+
} // namespace kernels
121+
} // namespace lite
122+
} // namespace paddle
123+
124+
namespace xpu = paddle::lite::kernels::xpu;

// Kernel instantiations for FP32 and FP16 activations.
using XPUMhca_FP32 = xpu::XPUMhcaCompute<float, PRECISION(kFloat)>;
using XPUMhca_FP16 = xpu::XPUMhcaCompute<float16, PRECISION(kFP16)>;

// FP32 variant: all bound tensors use the default (float) precision.
REGISTER_LITE_KERNEL(
    __xpu__multihead_cross_attn, kXPU, kFloat, kNCHW, XPUMhca_FP32, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("Embedding", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("FCWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("FCBias", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("LNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("LNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
    .Finalize();
// FP16 variant: activations (Input/Embedding/Output) are FP16; FC weights,
// biases, and LayerNorm parameters keep the default precision.
REGISTER_LITE_KERNEL(
    __xpu__multihead_cross_attn, kXPU, kFP16, kNCHW, XPUMhca_FP16, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
    .BindInput("Embedding",
               {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
    .BindInput("FCWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("FCBias", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("LNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("LNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindOutput("Output",
                {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
    .Finalize();
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <vector>
18+
#include "lite/backends/xpu/xpu_header_sitter.h"
19+
#include "lite/core/kernel.h"
20+
#include "lite/core/op_registry.h"
21+
22+
namespace paddle {
23+
namespace lite {
24+
namespace kernels {
25+
namespace xpu {
26+
27+
template <typename InType, PrecisionType PType>
28+
class XPUMhcaCompute : public KernelLite<TARGET(kXPU), PType> {
29+
public:
30+
using param_t = operators::XPUMhcaParam;
31+
32+
virtual void PrepareForRun();
33+
34+
virtual void Run();
35+
36+
virtual ~XPUMhcaCompute() = default;
37+
38+
private:
39+
std::vector<const int16_t *> arg_fc_weight_int16_;
40+
std::vector<const float *> arg_fc_bias_;
41+
std::vector<const float *> arg_ln_scale_;
42+
std::vector<const float *> arg_ln_bias_;
43+
std::vector<const float *> fc_weight_max_;
44+
XPUScratchPadGuard weight_max_guard_;
45+
46+
template <typename T>
47+
std::vector<const T *> *GetWeight() {
48+
LOG(FATAL) << "Invalid Weight Type";
49+
return nullptr;
50+
}
51+
52+
std::vector<const int16_t *> *GetWeight() { return &arg_fc_weight_int16_; }
53+
54+
void PrepareWeightMax(const std::vector<lite::Tensor *> &weight_max,
55+
int max_ptr_len,
56+
std::vector<const float *> *max_xpu_ptrs);
57+
};
58+
59+
} // namespace xpu
60+
} // namespace kernels
61+
} // namespace lite
62+
} // namespace paddle
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/xpu/__xpu__multihead_self_attn_compute.h"
16+
#include <vector>
17+
#include "lite/backends/xpu/xpu_header_sitter.h"
18+
#include "lite/core/op_registry.h"
19+
20+
namespace paddle {
21+
namespace lite {
22+
namespace kernels {
23+
namespace xpu {
24+
25+
template <typename T>
26+
static std::vector<const T*> prepare_weight(
27+
const std::vector<lite::Tensor*>& fc_weight) {
28+
std::vector<const T*> result;
29+
for (auto* weight : fc_weight) {
30+
result.push_back(reinterpret_cast<const T*>(weight->data<float>()));
31+
}
32+
return result;
33+
}
34+
35+
template <typename InType, PrecisionType PType>
36+
void XPUMhsaCompute<InType, PType>::PrepareWeightMax(
37+
const std::vector<lite::Tensor*>& weight_max,
38+
int max_ptr_len,
39+
std::vector<const float*>* max_xpu_ptrs) {
40+
int max_value_num = 0;
41+
for (auto max_tensor : weight_max) {
42+
max_value_num += max_tensor->numel();
43+
}
44+
VLOG(3) << "Total weight max value number: " << max_value_num;
45+
weight_max_guard_ =
46+
TargetWrapperXPU::MallocScratchPad(max_value_num * sizeof(float));
47+
float* weight_max_ptr = reinterpret_cast<float*>(weight_max_guard_->addr_);
48+
49+
int offset = 0;
50+
for (auto max_tensor : weight_max) {
51+
float* cur_weight_max_ptr = weight_max_ptr + offset;
52+
auto len = max_tensor->numel();
53+
VLOG(6) << "weight max value: " << max_tensor->data<float>()[0] << " "
54+
<< max_tensor->data<float>()[len - 1];
55+
std::vector<float> cpu_max(max_ptr_len, max_tensor->data<float>()[0]);
56+
lite::TargetWrapperXPU::MemcpySync(cur_weight_max_ptr,
57+
cpu_max.data(),
58+
sizeof(float) * max_ptr_len,
59+
IoDirection::HtoD);
60+
max_xpu_ptrs->push_back(cur_weight_max_ptr);
61+
offset += max_ptr_len;
62+
}
63+
}
64+
65+
template <typename InType, PrecisionType PType>
66+
void XPUMhsaCompute<InType, PType>::PrepareForRun() {
67+
auto& ctx = this->ctx_->template As<XPUContext>();
68+
auto& param = this->template Param<param_t>();
69+
// prepare bias
70+
for (auto* fc_bias : param.fc_bias) {
71+
arg_fc_bias_.push_back(fc_bias->template data<float>());
72+
}
73+
// prepare scale
74+
for (auto* ln_scale : param.ln_scale) {
75+
arg_ln_scale_.push_back(ln_scale->template data<float>());
76+
}
77+
// prepare ln_bias
78+
for (auto* ln_bias : param.ln_bias) {
79+
arg_ln_bias_.push_back(ln_bias->template data<float>());
80+
}
81+
arg_fc_weight_int16_ = prepare_weight<int16_t>(param.fc_weight);
82+
const int XPU_QUANT_SCALE_NUM = ctx.GetRawContext()->max_ptr_size();
83+
PrepareWeightMax(param.weight_max, XPU_QUANT_SCALE_NUM, &fc_weight_max_);
84+
}
85+
86+
template <typename InType, PrecisionType PType>
void XPUMhsaCompute<InType, PType>::Run() {
  // TODO(shenyijun): The compute of this op will be adapted to XFT interface
  // later on.
  // NOTE(review): Run() is currently a deliberate no-op -- the kernel is
  // registered but the fused xdnn call below is disabled, so Output is never
  // written. The commented code is the intended implementation.
  //
  // auto& param = this->template Param<param_t>();
  // auto& ctx = this->ctx_->template As<XPUContext>();
  // const InType* in = param.input->template data<InType>();
  // InType* out = param.output->template mutable_data<InType>(TARGET(kXPU));
  // int batch = static_cast<int>(param.input->dims()[0]);
  // int seqlen = static_cast<int>(param.input->dims()[1]);
  // int r = xdnn::unet_mhsa_fusion<InType, int16_t, InType, int16_t>(
  //     ctx.GetRawContext(),
  //     in,
  //     *(XPUMhsaCompute::GetWeight<int16_t>()),
  //     out,
  //     arg_fc_bias_,
  //     arg_ln_scale_,
  //     arg_ln_bias_,
  //     fc_weight_max_,
  //     batch,
  //     param.head_num,
  //     param.size_per_head,
  //     seqlen,
  //     param.hidden_dim);
  // CHECK_EQ(r, 0);
}
113+
114+
} // namespace xpu
115+
} // namespace kernels
116+
} // namespace lite
117+
} // namespace paddle
118+
119+
namespace xpu = paddle::lite::kernels::xpu;

// Kernel instantiations for FP32 and FP16 activations.
using XPUMhsa_FP32 = xpu::XPUMhsaCompute<float, PRECISION(kFloat)>;
using XPUMhsa_FP16 = xpu::XPUMhsaCompute<float16, PRECISION(kFP16)>;
// FP32 variant: all bound tensors use the default (float) precision.
REGISTER_LITE_KERNEL(
    __xpu__multihead_self_attn, kXPU, kFloat, kNCHW, XPUMhsa_FP32, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("FCWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("FCBias", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("LNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("LNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
    .Finalize();
// FP16 variant: activations (Input/Output) are FP16; FC weights, biases,
// and LayerNorm parameters keep the default precision.
REGISTER_LITE_KERNEL(
    __xpu__multihead_self_attn, kXPU, kFP16, kNCHW, XPUMhsa_FP16, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
    .BindInput("FCWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("FCBias", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("LNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("LNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindOutput("Output",
                {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
    .Finalize();

0 commit comments

Comments
 (0)