Skip to content

Commit 29d0104

Browse files
authored
[XPU] support fp16 weight_scale in op weight_only_linear (#73963)
1 parent 10c7198 commit 29d0104

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

paddle/phi/kernels/xpu/weight_only_linear_kernel.cc

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,14 +60,34 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
6060
input_y, nullptr, m, n, n, false};
6161
baidu::xpu::xblas::FcFusionTensor<XPUType> tensor_y{
6262
input_y, nullptr, m, n, n, false};
63+
DenseTensor weight_scale_fp32;
64+
if (weight_scale.dtype() != phi::DataType::FLOAT32 &&
65+
weight_scale.dims().size() != 0) {
66+
weight_scale_fp32.Resize(weight_scale.dims());
67+
dev_ctx.template Alloc<float>(&weight_scale_fp32);
68+
int r = baidu::xpu::api::cast<XPUType, float>(
69+
dev_ctx.x_context(),
70+
reinterpret_cast<const XPUType*>(weight_scale.data<T>()),
71+
weight_scale_fp32.data<float>(),
72+
weight_scale.numel());
73+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
74+
}
75+
const float* weight_scale_ptr = nullptr;
76+
if (weight_scale.dims().size() != 0) {
77+
if (weight_scale.dtype() == phi::DataType::FLOAT32) {
78+
weight_scale_ptr = weight_scale.data<float>();
79+
} else {
80+
weight_scale_ptr = weight_scale_fp32.data<float>();
81+
}
82+
}
6383
baidu::xpu::xblas::FcFusionEpilogue<float, float> epilogue{
6484
api::Activation_t::LINEAR,
6585
bias.is_initialized() ? (bias.get().dtype() == phi::DataType::FLOAT16
6686
? bias_fp32.data<float>()
6787
: bias.get().data<float>())
6888
: nullptr,
6989
nullptr,
70-
weight_scale.dims().size() != 0 ? weight_scale.data<float>() : nullptr,
90+
weight_scale_ptr,
7191
0,
7292
1,
7393
nullptr};

0 commit comments

Comments
 (0)