Merged
Changes from 1 commit
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
@@ -32,7 +32,7 @@ if(NOT DEFINED XPU_XDNN_BASE_DATE)
   set(XPU_XDNN_BASE_DATE "20240327")
 endif()
 if(NOT DEFINED XPU_XHPC_BASE_DATE)
-  set(XPU_XHPC_BASE_DATE "20240511")
+  set(XPU_XHPC_BASE_DATE "20240514")
 endif()
 set(XPU_XCCL_BASE_VERSION "1.2.0.5")
 if(NOT DEFINED XPU_XFT_BASE_VERSION)
50 changes: 38 additions & 12 deletions paddle/phi/kernels/xpu/elementwise_add_kernel.cc
@@ -37,18 +37,44 @@ void AddKernel(const Context& dev_ctx,
   if (x.dtype() == phi::DataType::FLOAT32 &&
       (y.dtype() == phi::DataType::BFLOAT16 ||
        y.dtype() == phi::DataType::FLOAT16)) {
-    using Type = DataTypeToCppType<phi::DataType::FLOAT32>::type;
-    using XPUType = typename XPUTypeTrait<Type>::Type;
-    auto f = [](xpu::Context* ctx,
-                const XPUType* x,
-                const XPUType* y,
-                XPUType* z,
-                const std::vector<int>& xshape,
-                const std::vector<int>& yshape) {
-      return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
-    };
-    auto casted_y = phi::Cast<T>(dev_ctx, y, phi::DataType::FLOAT32);
-    XPUElementwise<Type, XPUType>(dev_ctx, x, casted_y, -1, out, f);
+    auto dev_version =
+        phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId());
+    if (dev_version >= phi::backends::xpu::XPUVersion::XPU3 &&
+        x.dims() == y.dims()) {
+      dev_ctx.template Alloc<float>(out);
+
+      const float* x_data = x.data<float>();
+      float* z_data = out->data<float>();
+
+      int ret = xpu::SUCCESS;
+      if (y.dtype() == phi::DataType::BFLOAT16) {
+        using YType = DataTypeToCppType<phi::DataType::BFLOAT16>::type;
+        using XPUYType = typename XPUTypeTrait<YType>::Type;
+        auto y_data = reinterpret_cast<const XPUYType*>(y.data<YType>());
+        ret = xpu::add_mul_type<float, XPUYType, float>(
+            dev_ctx.x_context(), x_data, y_data, z_data, x.numel());
+      } else {
+        using YType = DataTypeToCppType<phi::DataType::FLOAT16>::type;
+        using XPUYType = typename XPUTypeTrait<YType>::Type;
+        auto y_data = reinterpret_cast<const XPUYType*>(y.data<YType>());
+        ret = xpu::add_mul_type<float, XPUYType, float>(
+            dev_ctx.x_context(), x_data, y_data, z_data, x.numel());
+      }
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "add_mul_type");
Contributor:
Does the error message here support common::errors::XX type hints? If it does, it would be good to add the error type.

Contributor Author:
OK, I'll look into it.

Contributor Author:
The existing XPUElementwise also calls PADDLE_ENFORCE_XDNN_SUCCESS, and the other XPU kernels use PADDLE_ENFORCE_XDNN_SUCCESS as well. As for the common::errors::XX suggestion, we can do a dedicated batch change later and see whether it can be built into the PADDLE_ENFORCE_XDNN_SUCCESS macro itself.

+    } else {
+      using Type = DataTypeToCppType<phi::DataType::FLOAT32>::type;
+      using XPUType = typename XPUTypeTrait<Type>::Type;
+      auto f = [](xpu::Context* ctx,
+                  const XPUType* x,
+                  const XPUType* y,
+                  XPUType* z,
+                  const std::vector<int>& xshape,
+                  const std::vector<int>& yshape) {
+        return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
+      };
+      auto casted_y = phi::Cast<T>(dev_ctx, y, phi::DataType::FLOAT32);
+      XPUElementwise<Type, XPUType>(dev_ctx, x, casted_y, -1, out, f);
+    }
   } else {
     using XPUType = typename XPUTypeTrait<T>::Type;

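The review thread above asks whether a common::errors::XX category can be attached to the failure raised here. PADDLE_ENFORCE_XDNN_SUCCESS is the standard way XPU kernels check the integer status an XDNN call returns. The block below is only a rough, self-contained sketch of that check pattern — the macro name, the kXdnnSuccess constant, and the plain std::runtime_error are stand-ins, not Paddle's real definitions — to show where an explicit error category could be injected if the macro were extended as the thread suggests.

// Rough sketch only (assumed shape, not Paddle's actual macro): check the
// status code returned by an XDNN call and raise with the failing API name.
#include <sstream>
#include <stdexcept>

constexpr int kXdnnSuccess = 0;  // stand-in for xpu::SUCCESS

#define CHECK_XDNN_SUCCESS(ret, api_name)                                \
  do {                                                                   \
    const int check_ret_ = (ret);                                        \
    if (check_ret_ != kXdnnSuccess) {                                    \
      std::ostringstream oss_;                                           \
      /* an error category could be attached to the message here */      \
      oss_ << "XDNN api " << (api_name) << " failed, return code "       \
           << check_ret_;                                                \
      throw std::runtime_error(oss_.str());                              \
    }                                                                    \
  } while (0)

int main() {
  int ret = kXdnnSuccess;  // pretend xpu::add_mul_type reported success
  CHECK_XDNN_SUCCESS(ret, "add_mul_type");
  return 0;
}

Folding the category into the macro once, as the author suggests, would update every XPU kernel call site in a single change rather than editing each kernel individually.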
8 changes: 6 additions & 2 deletions python/paddle/base/framework.py
@@ -8246,8 +8246,8 @@ def add_cast_for_type_promotion(op, block, idx, var_name, out_dtype):


 def can_skip_promote(op, device):
-    # Only GPU elementwise_add kernel supports the pattern "float + half".
-    if device != 'GPU':
+    # Only GPU/XPU elementwise_add kernel supports the pattern "float + half".
+    if device not in ['GPU', 'XPU']:
         return False
     if op.type != "elementwise_add":
         return False
@@ -8268,6 +8268,10 @@ def process_type_promotion(program):
         _current_expected_place(), core.CUDAPlace
     ):
         device = 'GPU'
+    elif core.is_compiled_with_xpu() and isinstance(
+        _current_expected_place(), core.XPUPlace
+    ):
+        device = 'XPU'
     org_program = program
     if program is None:
         program = default_main_program()
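To make the promotion-skip rule above concrete, here is a minimal, illustrative sketch — not the real paddle.base.framework helper; it takes plain strings instead of an op and a device place — of the decision the pass now applies on XPU as well as GPU.

# Illustrative sketch only: simplified stand-in for can_skip_promote.
def can_skip_promote_sketch(op_type, device, x_dtype, y_dtype):
    # Only the GPU and XPU elementwise_add kernels handle "float + half"
    # natively, so only there can the pass skip inserting a cast.
    if device not in ('GPU', 'XPU'):
        return False
    if op_type != 'elementwise_add':
        return False
    return x_dtype == 'float32' and y_dtype in ('float16', 'bfloat16')


# float32 + bfloat16 on XPU: no cast is inserted; the new add_mul_type path
# in elementwise_add_kernel.cc consumes the bfloat16 input directly.
assert can_skip_promote_sketch('elementwise_add', 'XPU', 'float32', 'bfloat16')
# On CPU the pass still promotes (casts) the half-precision input first.
assert not can_skip_promote_sketch('elementwise_add', 'CPU', 'float32', 'float16')

Together with the XPUPlace branch added to process_type_promotion, this keeps the half-precision operand uncast on XPU so the kernel-side fast path can handle the mixed dtypes itself.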