@@ -37,18 +37,44 @@ void AddKernel(const Context& dev_ctx,
3737 if (x.dtype () == phi::DataType::FLOAT32 &&
3838 (y.dtype () == phi::DataType::BFLOAT16 ||
3939 y.dtype () == phi::DataType::FLOAT16)) {
40- using Type = DataTypeToCppType<phi::DataType::FLOAT32>::type;
41- using XPUType = typename XPUTypeTrait<Type>::Type;
42- auto f = [](xpu::Context* ctx,
43- const XPUType* x,
44- const XPUType* y,
45- XPUType* z,
46- const std::vector<int >& xshape,
47- const std::vector<int >& yshape) {
48- return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
49- };
50- auto casted_y = phi::Cast<T>(dev_ctx, y, phi::DataType::FLOAT32);
51- XPUElementwise<Type, XPUType>(dev_ctx, x, casted_y, -1 , out, f);
40+ auto dev_version =
41+ phi::backends::xpu::get_xpu_version (dev_ctx.GetPlace ().GetDeviceId ());
42+ if (dev_version >= phi::backends::xpu::XPUVersion::XPU3 &&
43+ x.dims () == y.dims ()) {
44+ dev_ctx.template Alloc <float >(out);
45+
46+ const float * x_data = x.data <float >();
47+ float * z_data = out->data <float >();
48+
49+ int ret = xpu::SUCCESS;
50+ if (y.dtype () == phi::DataType::BFLOAT16) {
51+ using YType = DataTypeToCppType<phi::DataType::BFLOAT16>::type;
52+ using XPUYType = typename XPUTypeTrait<YType>::Type;
53+ auto y_data = reinterpret_cast <const XPUYType*>(y.data <YType>());
54+ ret = xpu::add_mul_type<float , XPUYType, float >(
55+ dev_ctx.x_context (), x_data, y_data, z_data, x.numel ());
56+ } else {
57+ using YType = DataTypeToCppType<phi::DataType::FLOAT16>::type;
58+ using XPUYType = typename XPUTypeTrait<YType>::Type;
59+ auto y_data = reinterpret_cast <const XPUYType*>(y.data <YType>());
60+ ret = xpu::add_mul_type<float , XPUYType, float >(
61+ dev_ctx.x_context (), x_data, y_data, z_data, x.numel ());
62+ }
63+ PADDLE_ENFORCE_XDNN_SUCCESS (ret, " add_mul_type" );
64+ } else {
65+ using Type = DataTypeToCppType<phi::DataType::FLOAT32>::type;
66+ using XPUType = typename XPUTypeTrait<Type>::Type;
67+ auto f = [](xpu::Context* ctx,
68+ const XPUType* x,
69+ const XPUType* y,
70+ XPUType* z,
71+ const std::vector<int >& xshape,
72+ const std::vector<int >& yshape) {
73+ return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
74+ };
75+ auto casted_y = phi::Cast<T>(dev_ctx, y, phi::DataType::FLOAT32);
76+ XPUElementwise<Type, XPUType>(dev_ctx, x, casted_y, -1 , out, f);
77+ }
5278 } else {
5379 using XPUType = typename XPUTypeTrait<T>::Type;
5480
0 commit comments