@@ -60,14 +60,34 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
6060 input_y, nullptr , m, n, n, false };
6161 baidu::xpu::xblas::FcFusionTensor<XPUType> tensor_y{
6262 input_y, nullptr , m, n, n, false };
63+ DenseTensor weight_scale_fp32;
64+ if (weight_scale.dtype () != phi::DataType::FLOAT32 &&
65+ weight_scale.dims ().size () != 0 ) {
66+ weight_scale_fp32.Resize (weight_scale.dims ());
67+ dev_ctx.template Alloc <float >(&weight_scale_fp32);
68+ int r = baidu::xpu::api::cast<XPUType, float >(
69+ dev_ctx.x_context (),
70+ reinterpret_cast <const XPUType*>(weight_scale.data <T>()),
71+ weight_scale_fp32.data <float >(),
72+ weight_scale.numel ());
73+ PADDLE_ENFORCE_XDNN_SUCCESS (r, " cast" );
74+ }
75+ const float * weight_scale_ptr = nullptr ;
76+ if (weight_scale.dims ().size () != 0 ) {
77+ if (weight_scale.dtype () == phi::DataType::FLOAT32) {
78+ weight_scale_ptr = weight_scale.data <float >();
79+ } else {
80+ weight_scale_ptr = weight_scale_fp32.data <float >();
81+ }
82+ }
6383 baidu::xpu::xblas::FcFusionEpilogue<float , float > epilogue{
6484 api::Activation_t::LINEAR,
6585 bias.is_initialized () ? (bias.get ().dtype () == phi::DataType::FLOAT16
6686 ? bias_fp32.data <float >()
6787 : bias.get ().data <float >())
6888 : nullptr ,
6989 nullptr ,
70- weight_scale. dims (). size () != 0 ? weight_scale. data < float >() : nullptr ,
90+ weight_scale_ptr ,
7191 0 ,
7292 1 ,
7393 nullptr };
0 commit comments