@@ -125,26 +125,26 @@ void XPUFcCompute::Run() {
125125 }
126126 // TODO(weihaoji): remove fc_int31 and fc_int16 after xpu fc wrapper refactor
127127 if (param.precision == " int31" ) {
128- int r = xdnn::fc_int31 (
129- ctx.GetRawContext (), /* context */
130- false , /* TransA */
131- true , /* TransB */
132- m , /* m */
133- n , /* n */
134- k , /* k */
135- 1 . 0f , /* alpha */
136- param. input -> data < float >(), /* A */
137- nullptr , /* max_a ptr */
138- reinterpret_cast < const float *>(quant_weight_guard_-> addr_ ), /* B */
139- w_max , /* max_b */
140- 0 . 0f , /* beta */
141- param. output -> mutable_data < float >( TARGET ( kXPU )) , /* C */
142- nullptr , /* max_c ptr */
143- bias , /* bias */
144- act /* act_type */ );
145- CHECK_EQ (r, 0 );
146- r = xdnn::findmax< float >(
147- ctx. GetRawContext (), param. output -> data < float >(), m * n, output_max );
128+ int r = xdnn::fc_fusion< float , float , float , int > (
129+ ctx.GetRawContext (), // ctx
130+ param. input -> data < float >() , // x
131+ reinterpret_cast < const float *>(quant_weight_guard_-> addr_ ), // w
132+ param. output -> mutable_data < float >( TARGET ( kXPU )) , // y
133+ m , // m
134+ n , // n
135+ k , // k
136+ false , // x_trans
137+ true , // w_trans
138+ input_max, // x_maxptr
139+ reinterpret_cast < const float *>(weight_max_guard_-> addr_ ) , // w_maxptr
140+ output_max , // y_maxptr
141+ k , // ldx
142+ k, // ldw
143+ n , // ldy
144+ 1 . 0f , // alpha
145+ 0 . 0f , // beta
146+ bias, // bias
147+ act );
148148 CHECK_EQ (r, 0 );
149149 } else if (param.precision == " int16" ) {
150150 int r = 0 ;
0 commit comments