diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 55aae9f24c1a61..993cad302e2f21 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -1178,6 +1178,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"fused_attention_grad",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+    {"fused_bias_act",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"fused_feedforward",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"fused_feedforward_grad",
diff --git a/paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc
new file mode 100644
index 00000000000000..d36d7416a023ae
--- /dev/null
+++ b/paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc
@@ -0,0 +1,138 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T>
+static void DispatchComputeImpl(const phi::XPUContext *xpu_ctx,
+                                const DenseTensor &x,
+                                const DenseTensor *bias,
+                                const DenseTensor &dequant_scales,
+                                const DenseTensor &shift,
+                                const DenseTensor &smooth,
+                                const std::string &act_method,
+                                const float quant_scale,
+                                const int quant_round_type,
+                                const float quant_max_bound,
+                                const float quant_min_bound,
+                                DenseTensor *out) {
+  PADDLE_THROW(
+      phi::errors::Unimplemented("fused_bias_act with smooth "
+                                 "quant on xpu is not implemented yet."));
+}
+
+template <typename T>
+static void ComputeImpl(const phi::XPUContext *xpu_ctx,
+                        const DenseTensor &x,
+                        const paddle::optional<DenseTensor> &bias,
+                        const std::string &act_method,
+                        DenseTensor *out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  int rows = x.dims()[0];
+  int cols = x.dims()[1];
+  int r = 0;
+  if (bias) {
+    r = baidu::xpu::api::broadcast_add<XPUType>(
+        xpu_ctx->x_context(),
+        reinterpret_cast<const XPUType *>(x.data<T>()),
+        reinterpret_cast<const XPUType *>(bias.get().data<T>()),
+        reinterpret_cast<XPUType *>(const_cast<T *>(x.data<T>())),
+        {rows, cols},
+        {1, cols});
+    PD_CHECK(r == 0, "baidu::xpu::api::broadcast_add failed.");
+  }
+  if (act_method == "geglu") {
+    PD_THROW(
+        "NOT supported GeGLU. "
+        "Currently Only Support SwiGLU, GeLU, ReLU");
+  } else if (act_method == "swiglu") {
+    r = baidu::xpu::api::swiglu<XPUType>(
+        xpu_ctx->x_context(),
+        reinterpret_cast<const XPUType *>(x.data<T>()),
+        reinterpret_cast<XPUType *>(out->data<T>()),
+        {rows, cols},
+        1,
+        true);
+    PD_CHECK(r == 0, "baidu::xpu::api::swiglu failed.");
+  } else if (act_method == "gelu") {
+    r = baidu::xpu::api::gelu<XPUType>(
+        xpu_ctx->x_context(),
+        reinterpret_cast<const XPUType *>(x.data<T>()),
+        reinterpret_cast<XPUType *>(out->data<T>()),
+        rows * cols);
+    PD_CHECK(r == 0, "baidu::xpu::api::gelu failed.");
+  } else if (act_method == "relu") {
+    r = baidu::xpu::api::relu<XPUType>(
+        xpu_ctx->x_context(),
+        reinterpret_cast<const XPUType *>(x.data<T>()),
+        reinterpret_cast<XPUType *>(out->data<T>()),
+        rows * cols);
+    PD_CHECK(r == 0, "baidu::xpu::api::relu failed.");
+  } else {
+    PD_THROW(
+        "NOT supported. "
+        "Currently Only Support SwiGLU, GeLU, ReLU");
+  }
+}
+
+template <typename T, typename Context>
+void FusedBiasActKernel(const Context &dev_ctx,
+                        const DenseTensor &x,
+                        const paddle::optional<DenseTensor> &bias,
+                        const paddle::optional<DenseTensor> &dequant_scales,
+                        const paddle::optional<DenseTensor> &shift,
+                        const paddle::optional<DenseTensor> &smooth,
+                        const std::string &act_method,
+                        const std::string &compute_dtype,
+                        float quant_scale,
+                        int quant_round_type,
+                        float quant_max_bound,
+                        float quant_min_bound,
+                        DenseTensor *out) {
+  auto xpu_ctx = static_cast<const phi::XPUContext *>(&dev_ctx);
+  dev_ctx.template Alloc<T>(out);
+
+  if (dequant_scales && dequant_scales.get().numel() > 0) {
+    return DispatchComputeImpl<T>(xpu_ctx,
+                                  x,
+                                  bias ? &(bias.get()) : nullptr,
+                                  dequant_scales.get(),
+                                  shift.get(),
+                                  smooth.get(),
+                                  act_method,
+                                  quant_scale,
+                                  quant_round_type,
+                                  quant_max_bound,
+                                  quant_min_bound,
+                                  out);
+  } else {
+    return ComputeImpl<T>(xpu_ctx, x, bias, act_method, out);
+  }
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_bias_act,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedBiasActKernel,
+                   float,
+                   phi::dtype::float16) {}