111 changes: 67 additions & 44 deletions paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_kernel.cc
@@ -13,9 +13,7 @@
// limitations under the License.

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/scope_guard.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"

namespace phi {
@@ -32,60 +30,85 @@ void FusedGemmEpilogueKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* reserve_space) {
using XPUType = typename XPUTypeTrait<T>::Type;
xpu::Context* xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);

auto x_mat_dims =
common::flatten_to_2d(x.dims(), trans_x ? 1 : x.dims().size() - 1);

// (M * K) * (K * N) for new api use
// int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
// int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
// int64_t N = trans_y ? y->dims()[0] : y->dims()[1];

// Call the new API. For now the steps are invoked separately, pending qingpen's new fused interface.
int r = 0;
int r = xpu::SUCCESS;
xpu::Activation_t act = xpu::Activation_t::LINEAR;
if (activation == "relu") {
act = xpu::Activation_t::RELU;
} else if (activation == "gelu") {
act = xpu::Activation_t::GELU;
}
// fc + bias + act
// 1. fc
phi::XpuFcInfo fc_info;

phi::GetFCInfo(x_mat_dims, y.dims(), trans_x, trans_y, &fc_info);
xpu::Context* xpu_ctx = dev_ctx.x_context();

const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
auto* out_tmp_ptr = dev_ctx.template Alloc<T>(out);
XPUType* out_ptr = reinterpret_cast<XPUType*>(out_tmp_ptr);
xpu::ctx_guard RAII_GUARD(xpu_ctx);
XPUType* fc_out_ptr = RAII_GUARD.alloc_l3_or_gm<XPUType>(out->numel());
phi::MatMulXPUFunction<XPUType>(
xpu_ctx, x_ptr, y_ptr, fc_out_ptr, fc_info, 1.0f);
XPUType* bias_out_ptr = out_ptr;
if (activation != "none" && reserve_space) {
auto* bias_out_temp_ptr = dev_ctx.template Alloc<T>(reserve_space);
bias_out_ptr = reinterpret_cast<XPUType*>(bias_out_temp_ptr);
}
// 2 bias
const XPUType* bias_ptr = reinterpret_cast<const XPUType*>(bias.data<T>());
r = xpu::broadcast_add(xpu_ctx,
fc_out_ptr,
bias_ptr,
bias_out_ptr,
{fc_info.m, fc_info.n},
{fc_info.n});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
// 3 act
if (activation == "relu") {
r = xpu::relu(xpu_ctx, bias_out_ptr, out_ptr, out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu");
} else if (activation == "gelu") {
r = xpu::gelu(xpu_ctx, bias_out_ptr, out_ptr, out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu");
XPUType* out_ptr = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(out));

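// NOTE: fc_api_list is indexed by FCCalcType<XPUType>() below, so its order
// must match the FCCalcType enum (int16, int32, float, int_with_ll, tfloat32).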
decltype(&xpu_fc_wrapper<XPUType, int16_t>) fc_api_list[5] = {
Contributor:

The fc_api_list here looks identical to what is defined inside MatMulXPUFunction in phi/kernels/xpu/xpu_api_wrapper.h, doesn't it? It's fine as written, but is there a better or more elegant way to reduce the duplicated code? And if one copy is updated later without the other, I wonder whether the mismatch could cause strange problems.

Contributor:

My change to xpu_api_wrapper.h is still waiting on the xhpc update to be released; I'll mark this for now and update this spot as well when I update xpu_api_wrapper.h in my next PR.

Contributor Author:

> My change to xpu_api_wrapper.h is still waiting on the xhpc update to be released; I'll mark this for now and update this spot as well when I update xpu_api_wrapper.h in my next PR.

It seems the logic that selects and runs fc_fusion or fc_batched based on fccal_type could be extracted into a standalone function.

Contributor:

> It seems the logic that selects and runs fc_fusion or fc_batched based on fccal_type could be extracted into a standalone function.

xpu_api_wrapper.h has a MatMulXPUFunction that implements exactly that.

Contributor:

That said, looking at it more closely, that function only chooses between fc_fusion and fc_batched based on batch_size, so I'm not sure it covers the case here.
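A rough sketch of the helper being discussed (hypothetical only: the name DispatchFcFusion is made up, and the real MatMulXPUFunction in xpu_api_wrapper.h has a different shape): a single dispatch function owns the FCCalcType-indexed table, so callers such as this kernel would not need their own copy.

// Hypothetical sketch, not the actual xpu_api_wrapper.h API.
// Centralizes the FCCalcType dispatch so the table lives in one place.
template <typename XPUType>
void DispatchFcFusion(xpu::Context* xpu_ctx,
                      const XPUType* x_ptr,
                      const XPUType* y_ptr,
                      XPUType* out_ptr,
                      int m, int n, int k,
                      bool trans_x, bool trans_y,
                      float* max_x, float* max_y, float* max_out,
                      int ldx, int ldy, int ldout,
                      float alpha, float beta,
                      const float* bias_fp32,
                      const xpu::Activation_t& act) {
  // Same table as in this kernel; kept here so all callers stay in sync.
  static decltype(&xpu_fc_wrapper<XPUType, int16_t>) fc_api_list[5] = {
      &xpu_fc_wrapper<XPUType, int16_t>,
      &xpu_fc_wrapper<XPUType, int32_t>,
      &xpu_fc_wrapper<XPUType, float>,
      &xpu_fc_wrapper<XPUType, int_with_ll_t>,
      &xpu_fc_wrapper<XPUType, tfloat32>,
  };
  auto fc_api = fc_api_list[FCCalcType<XPUType>()];
  fc_api(xpu_ctx, x_ptr, y_ptr, out_ptr, m, n, k, trans_x, trans_y,
         max_x, max_y, max_out, ldx, ldy, ldout, alpha, beta, bias_fp32, act);
}

With something like this in xpu_api_wrapper.h, the kernel body below would reduce to a single dispatch call, and future updates to the table would happen in one place.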

&xpu_fc_wrapper<XPUType, int16_t>,
&xpu_fc_wrapper<XPUType, int32_t>,
&xpu_fc_wrapper<XPUType, float>,
&xpu_fc_wrapper<XPUType, int_with_ll_t>,
&xpu_fc_wrapper<XPUType, tfloat32>,
};

auto fccal_type = FCCalcType<XPUType>();

auto fc_api = fc_api_list[fccal_type];

// fc + bias + act
phi::XpuFcInfo fc_info;
auto mat_x_dims =
common::flatten_to_2d(x.dims(), trans_x ? 1 : x.dims().size() - 1);
auto mat_y_dims = y.dims();
phi::GetFCInfo(mat_x_dims, mat_y_dims, trans_x, trans_y, &fc_info);
int batch_size = fc_info.bs;
int m = fc_info.m;
int n = fc_info.n;
int k = fc_info.k;
int ldx = fc_info.stride_x;
int ldy = fc_info.stride_y;
int ldout = fc_info.stride_out;
float* max_x = fc_info.max_x;
float* max_y = fc_info.max_y;
float* max_out = fc_info.max_out;

PADDLE_ENFORCE_LE(
batch_size,
1,
errors::InvalidArgument(
"FusedGemm does not support batched fc yet, but got batch size %d.",
batch_size));
Contributor Author (@lj970926, Jan 30, 2024):

The decision not to support batched fc here is based on the following considerations:

  1. Neither the GPU kernel nor the unit tests currently support batched fc.
  2. fc_batched does not yet support fusing bias and activation.
  3. This kernel is only called from FusedLinear, and since the weights are 2-D, a batched fc never arises.

const float* bias_fp32 = reinterpret_cast<const float*>(bias.data<T>());
if (!std::is_same<T, float>::value) {
// TODO(lijin23): xblas and xdnn currently support fp32 bias only; this cast
// may be removed in the future.
bias_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(bias.numel());
r = xpu::cast<XPUType, float>(
xpu_ctx, bias_ptr, const_cast<float*>(bias_fp32), bias.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
}
fc_api(xpu_ctx,
x_ptr,
y_ptr,
out_ptr,
m,
n,
k,
trans_x,
trans_y,
max_x,
max_y,
max_out,
ldx,
ldy,
ldout,
1.0f,
0,
bias_fp32,
act);
}

} // namespace fusion