Skip to content

Commit 229f75c

Browse files
committed
[XPU] remove cast for elementwise_add(float, bf16/fp16)
1 parent 73695a8 commit 229f75c

3 files changed

Lines changed: 45 additions & 15 deletions

File tree

cmake/external/xpu.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ if(NOT DEFINED XPU_XDNN_BASE_DATE)
3232
set(XPU_XDNN_BASE_DATE "20240327")
3333
endif()
3434
if(NOT DEFINED XPU_XHPC_BASE_DATE)
35-
set(XPU_XHPC_BASE_DATE "20240511")
35+
set(XPU_XHPC_BASE_DATE "20240514")
3636
endif()
3737
set(XPU_XCCL_BASE_VERSION "1.2.0.5")
3838
if(NOT DEFINED XPU_XFT_BASE_VERSION)

paddle/phi/kernels/xpu/elementwise_add_kernel.cc

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,44 @@ void AddKernel(const Context& dev_ctx,
3737
if (x.dtype() == phi::DataType::FLOAT32 &&
3838
(y.dtype() == phi::DataType::BFLOAT16 ||
3939
y.dtype() == phi::DataType::FLOAT16)) {
40-
using Type = DataTypeToCppType<phi::DataType::FLOAT32>::type;
41-
using XPUType = typename XPUTypeTrait<Type>::Type;
42-
auto f = [](xpu::Context* ctx,
43-
const XPUType* x,
44-
const XPUType* y,
45-
XPUType* z,
46-
const std::vector<int>& xshape,
47-
const std::vector<int>& yshape) {
48-
return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
49-
};
50-
auto casted_y = phi::Cast<T>(dev_ctx, y, phi::DataType::FLOAT32);
51-
XPUElementwise<Type, XPUType>(dev_ctx, x, casted_y, -1, out, f);
40+
auto dev_version =
41+
phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId());
42+
if (dev_version >= phi::backends::xpu::XPUVersion::XPU3 &&
43+
x.dims() == y.dims()) {
44+
dev_ctx.template Alloc<float>(out);
45+
46+
const float* x_data = x.data<float>();
47+
float* z_data = out->data<float>();
48+
49+
int ret = xpu::SUCCESS;
50+
if (y.dtype() == phi::DataType::BFLOAT16) {
51+
using YType = DataTypeToCppType<phi::DataType::BFLOAT16>::type;
52+
using XPUYType = typename XPUTypeTrait<YType>::Type;
53+
auto y_data = reinterpret_cast<const XPUYType*>(y.data<YType>());
54+
ret = xpu::add_mul_type<float, XPUYType, float>(
55+
dev_ctx.x_context(), x_data, y_data, z_data, x.numel());
56+
} else {
57+
using YType = DataTypeToCppType<phi::DataType::FLOAT16>::type;
58+
using XPUYType = typename XPUTypeTrait<YType>::Type;
59+
auto y_data = reinterpret_cast<const XPUYType*>(y.data<YType>());
60+
ret = xpu::add_mul_type<float, XPUYType, float>(
61+
dev_ctx.x_context(), x_data, y_data, z_data, x.numel());
62+
}
63+
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "add_mul_type");
64+
} else {
65+
using Type = DataTypeToCppType<phi::DataType::FLOAT32>::type;
66+
using XPUType = typename XPUTypeTrait<Type>::Type;
67+
auto f = [](xpu::Context* ctx,
68+
const XPUType* x,
69+
const XPUType* y,
70+
XPUType* z,
71+
const std::vector<int>& xshape,
72+
const std::vector<int>& yshape) {
73+
return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
74+
};
75+
auto casted_y = phi::Cast<T>(dev_ctx, y, phi::DataType::FLOAT32);
76+
XPUElementwise<Type, XPUType>(dev_ctx, x, casted_y, -1, out, f);
77+
}
5278
} else {
5379
using XPUType = typename XPUTypeTrait<T>::Type;
5480

python/paddle/base/framework.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8246,8 +8246,8 @@ def add_cast_for_type_promotion(op, block, idx, var_name, out_dtype):
82468246

82478247

82488248
def can_skip_promote(op, device):
8249-
# Only GPU elementwise_add kernel supports the pattern "float + half".
8250-
if device != 'GPU':
8249+
# Only GPU/XPU elementwise_add kernel supports the pattern "float + half".
8250+
if device not in ['GPU', 'XPU']:
82518251
return False
82528252
if op.type != "elementwise_add":
82538253
return False
@@ -8268,6 +8268,10 @@ def process_type_promotion(program):
82688268
_current_expected_place(), core.CUDAPlace
82698269
):
82708270
device = 'GPU'
8271+
elif core.is_compiled_with_xpu() and isinstance(
8272+
_current_expected_place(), core.XPUPlace
8273+
):
8274+
device = 'XPU'
82718275
org_program = program
82728276
if program is None:
82738277
program = default_main_program()

0 commit comments

Comments
 (0)