Skip to content

Commit 336c0e8

Browse files
committed
add xpu kernel for fused_bias_act
1 parent a8691f8 commit 336c0e8

2 files changed

Lines changed: 140 additions & 0 deletions

File tree

paddle/phi/backends/xpu/xpu2_op_list.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,6 +1178,8 @@ XPUOpMap& get_kl2_ops() {
11781178
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
11791179
{"fused_attention_grad",
11801180
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
1181+
{"fused_bias_act",
1182+
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
11811183
{"fused_feedforward",
11821184
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
11831185
{"fused_feedforward_grad",
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/backends/xpu/enforce_xpu.h"
16+
#include "paddle/phi/backends/xpu/xpu_context.h"
17+
#include "paddle/phi/core/dense_tensor.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
20+
namespace phi {
21+
namespace fusion {
22+
23+
template <typename T>
24+
static void DispatchComputeImpl(const phi::XPUContext *xpu_ctx,
25+
const DenseTensor &x,
26+
const DenseTensor *bias,
27+
const DenseTensor &dequant_scales,
28+
const DenseTensor &shift,
29+
const DenseTensor &smooth,
30+
const std::string &act_method,
31+
const float quant_scale,
32+
const int quant_round_type,
33+
const float quant_max_bound,
34+
const float quant_min_bound,
35+
DenseTensor *out) {
36+
return;
37+
}
38+
39+
template <typename T>
40+
static void ComputeImpl(const phi::XPUContext *xpu_ctx,
41+
const DenseTensor &x,
42+
const paddle::optional<DenseTensor> &bias,
43+
const std::string &act_method,
44+
DenseTensor *out) {
45+
using XPUType = typename XPUTypeTrait<T>::Type;
46+
int rows = x.dims()[0];
47+
int cols = x.dims()[1];
48+
int r = 0;
49+
if (bias) {
50+
r = baidu::xpu::api::broadcast_add<XPUType>(
51+
xpu_ctx->x_context(),
52+
reinterpret_cast<const XPUType *>(x.data<T>()),
53+
reinterpret_cast<const XPUType *>(bias.get().data<T>()),
54+
reinterpret_cast<XPUType *>(const_cast<T *>(x.data<T>())),
55+
{rows, cols},
56+
{1, cols});
57+
PD_CHECK(r == 0, "baidu::xpu::api::broadcast_add failed.");
58+
}
59+
cols = act_method == "swiglu" ? cols / 2 : cols;
60+
if (act_method == "geglu") {
61+
PD_THROW(
62+
"NOT supported GeGLU. "
63+
"Currently Only Support SwiGLU, GeLU, ReLU");
64+
} else if (act_method == "swiglu") {
65+
r = baidu::xpu::api::swiglu<XPUType>(
66+
xpu_ctx->x_context(),
67+
reinterpret_cast<const XPUType *>(x.data<T>()),
68+
reinterpret_cast<XPUType *>(const_cast<T *>(out->data<T>())),
69+
{rows, 1, cols},
70+
2,
71+
true);
72+
PD_CHECK(r == 0, "baidu::xpu::api::swiglu failed.");
73+
} else if (act_method == "gelu") {
74+
r = baidu::xpu::api::gelu<XPUType>(
75+
xpu_ctx->x_context(),
76+
reinterpret_cast<const XPUType *>(x.data<T>()),
77+
reinterpret_cast<XPUType *>(const_cast<T *>(out->data<T>())),
78+
rows * cols);
79+
PD_CHECK(r == 0, "baidu::xpu::api::gelu failed.");
80+
} else if (act_method == "relu") {
81+
r = baidu::xpu::api::relu<XPUType>(
82+
xpu_ctx->x_context(),
83+
reinterpret_cast<const XPUType *>(x.data<T>()),
84+
reinterpret_cast<XPUType *>(const_cast<T *>(out->data<T>())),
85+
rows * cols);
86+
PD_CHECK(r == 0, "baidu::xpu::api::relu failed.");
87+
} else {
88+
PD_THROW(
89+
"NOT supported. "
90+
"Currently Only Support SwiGLU, GeLU, ReLU");
91+
}
92+
return;
93+
}
94+
95+
template <typename T, typename Context>
96+
void FusedBiasActKernel(const Context &dev_ctx,
97+
const DenseTensor &x,
98+
const paddle::optional<DenseTensor> &bias,
99+
const paddle::optional<DenseTensor> &dequant_scales,
100+
const paddle::optional<DenseTensor> &shift,
101+
const paddle::optional<DenseTensor> &smooth,
102+
const std::string &act_method,
103+
const std::string &compute_dtype,
104+
float quant_scale,
105+
int quant_round_type,
106+
float quant_max_bound,
107+
float quant_min_bound,
108+
DenseTensor *out) {
109+
auto xpu_ctx = static_cast<const phi::XPUContext *>(&dev_ctx);
110+
dev_ctx.template Alloc<T>(out);
111+
112+
if (dequant_scales && dequant_scales.get().numel() > 0) {
113+
return DispatchComputeImpl<T>(xpu_ctx,
114+
x,
115+
bias ? &(bias.get()) : nullptr,
116+
dequant_scales.get(),
117+
shift.get(),
118+
smooth.get(),
119+
act_method,
120+
quant_scale,
121+
quant_round_type,
122+
quant_max_bound,
123+
quant_min_bound,
124+
out);
125+
} else {
126+
return ComputeImpl<T>(xpu_ctx, x, bias, act_method, out);
127+
}
128+
}
129+
130+
} // namespace fusion
131+
} // namespace phi
132+
133+
PD_REGISTER_KERNEL(fused_bias_act,
134+
XPU,
135+
ALL_LAYOUT,
136+
phi::fusion::FusedBiasActKernel,
137+
float,
138+
phi::dtype::float16) {}

0 commit comments

Comments
 (0)