diff --git a/paddle/phi/kernels/funcs/math_function.cc b/paddle/phi/kernels/funcs/math_function.cc index 84df3c4b139aa4..f4c6d2db94e90c 100644 --- a/paddle/phi/kernels/funcs/math_function.cc +++ b/paddle/phi/kernels/funcs/math_function.cc @@ -71,7 +71,6 @@ template struct SetConstant; template struct SetConstant; template struct SetConstant>; template struct SetConstant>; - #endif #define DEFINE_CPU_TRANS(RANK) \ @@ -246,7 +245,12 @@ void set_constant(const phi::DeviceContext& context, // tensor->place().apply_visitor(func); phi::VisitPlace(tensor->place(), func); #elif defined(PADDLE_WITH_XPU) - func(phi::XPUPlace()); + if (context.GetPlace().GetType() == phi::AllocationType::XPU) { + func(phi::XPUPlace()); + return; + } else { + func(phi::CPUPlace()); + } #else func(phi::CPUPlace()); #endif diff --git a/paddle/phi/kernels/funcs/math_function.h b/paddle/phi/kernels/funcs/math_function.h index 5e5834bf91e307..2e1013ec7fc1be 100644 --- a/paddle/phi/kernels/funcs/math_function.h +++ b/paddle/phi/kernels/funcs/math_function.h @@ -143,9 +143,9 @@ struct TensorSetConstantXPU { auto* ctx = phi::DeviceContextPool::Instance().Get(place_); auto begin = ctx->Alloc(tensor_); int numel = tensor_->numel(); - if (((std::is_same::value) || - (std::is_same::value)) && - (place_ == phi::XPUPlace())) { + if ((std::is_same::value) || + (std::is_same::value) || + (std::is_same::value)) { using XPUType = typename XPUTypeTrait::Type; auto* dev_ctx = static_cast(ctx); int r = xpu::constant(dev_ctx->x_context(),