Skip to content

Commit 37c12d1

Browse files
committed
[xpu] refactor fc int31 for KL2; test=develop
1 parent 88558c7 commit 37c12d1

1 file changed

Lines changed: 20 additions & 20 deletions

File tree

lite/kernels/xpu/__xpu__fc_compute.cc

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -125,26 +125,26 @@ void XPUFcCompute::Run() {
125125
}
126126
// TODO(weihaoji): remove fc_int31 and fc_int16 after xpu fc wrapper refactor
127127
if (param.precision == "int31") {
128-
int r = xdnn::fc_int31(
129-
ctx.GetRawContext(), /* context */
130-
false, /* TransA */
131-
true, /* TransB */
132-
m, /* m */
133-
n, /* n */
134-
k, /* k */
135-
1.0f, /* alpha */
136-
param.input->data<float>(), /* A */
137-
nullptr, /* max_a ptr */
138-
reinterpret_cast<const float*>(quant_weight_guard_->addr_), /* B */
139-
w_max, /* max_b */
140-
0.0f, /* beta */
141-
param.output->mutable_data<float>(TARGET(kXPU)), /* C */
142-
nullptr, /* max_c ptr */
143-
bias, /* bias */
144-
act /* act_type */);
145-
CHECK_EQ(r, 0);
146-
r = xdnn::findmax<float>(
147-
ctx.GetRawContext(), param.output->data<float>(), m * n, output_max);
128+
int r = xdnn::fc_fusion<float, float, float, int>(
129+
ctx.GetRawContext(), // ctx
130+
param.input->data<float>(), // x
131+
reinterpret_cast<const float*>(quant_weight_guard_->addr_), // w
132+
param.output->mutable_data<float>(TARGET(kXPU)), // y
133+
m, // m
134+
n, // n
135+
k, // k
136+
false, // x_trans
137+
true, // w_trans
138+
input_max, // x_maxptr
139+
reinterpret_cast<const float*>(weight_max_guard_->addr_), // w_maxptr
140+
output_max, // y_maxptr
141+
k, // ldx
142+
k, // ldw
143+
n, // ldy
144+
1.0f, // alpha
145+
0.0f, // beta
146+
bias, // bias
147+
act);
148148
CHECK_EQ(r, 0);
149149
} else if (param.precision == "int16") {
150150
int r = 0;

0 commit comments

Comments
 (0)