Skip to content

Commit 2098ee0

Browse files
authored
[xpu] Fc int31 (#7514)
* [xpu] fix continuous encoder fuse and fc max size
* [xpu] refactor fc int31 for KL2
1 parent b80b8b2 commit 2098ee0

File tree

3 files changed

+34
-31
lines changed

3 files changed

+34
-31
lines changed

lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -665,21 +665,24 @@ class XPUMultiEncoderFuser {
665665

666666
void operator()(SSAGraph* graph) {
667667
std::vector<Node*> all_encoders;
668-
for (auto* node : graph->StmtTopologicalOrder()) {
669-
CHECK(node->IsStmt());
670-
if (node->stmt()->op_info()->Type() == "single_encoder") {
671-
if (all_encoders.empty() ||
672-
IsDirectPredecessorOf(all_encoders.back(), node)) {
673-
all_encoders.push_back(node);
674-
} else {
675-
break;
668+
// if no node linked from all_encoders.back(), search is over
669+
int encoder_num = 0;
670+
do {
671+
encoder_num = all_encoders.size();
672+
for (auto* node : graph->StmtTopologicalOrder()) {
673+
CHECK(node->IsStmt());
674+
if (node->stmt()->op_info()->Type() == "single_encoder") {
675+
if (all_encoders.empty() ||
676+
IsDirectPredecessorOf(all_encoders.back(), node)) {
677+
all_encoders.push_back(node);
678+
}
676679
}
677680
}
678-
}
679-
VLOG(3) << "Found continuous " << all_encoders.size() << " single_encoder";
681+
} while (encoder_num != all_encoders.size());
680682
if (all_encoders.size() == 0) {
681683
return;
682684
}
685+
VLOG(3) << "Found continuous " << all_encoders.size() << " single_encoder";
683686

684687
const bool enable_int8 =
685688
all_encoders[0]->stmt()->op_info()->HasAttr("enable_int8") &&

lite/kernels/xpu/__xpu__fc_compute.cc

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ void XPUFcCompute::Run() {
109109
int n = param.w->dims()[1];
110110
bool quant_int8 = param.quant_w_max > 0.f;
111111

112+
param.output_max->Resize({lite::XPU_QUANT_SCALE_NUM});
112113
float* output_max = quant_int8
113114
? nullptr
114115
: param.output_max->mutable_data<float>(TARGET(kXPU));
@@ -125,26 +126,26 @@ void XPUFcCompute::Run() {
125126
}
126127
// TODO(weihaoji): remove fc_int31 and fc_int16 after xpu fc wrapper refactor
127128
if (param.precision == "int31") {
128-
int r = xdnn::fc_int31(
129-
ctx.GetRawContext(), /* context */
130-
false, /* TransA */
131-
true, /* TransB */
132-
m, /* m */
133-
n, /* n */
134-
k, /* k */
135-
1.0f, /* alpha */
136-
param.input->data<float>(), /* A */
137-
nullptr, /* max_a ptr */
138-
reinterpret_cast<const float*>(quant_weight_guard_->addr_), /* B */
139-
w_max, /* max_b */
140-
0.0f, /* beta */
141-
param.output->mutable_data<float>(TARGET(kXPU)), /* C */
142-
nullptr, /* max_c ptr */
143-
bias, /* bias */
144-
act /* act_type */);
145-
CHECK_EQ(r, 0);
146-
r = xdnn::findmax<float>(
147-
ctx.GetRawContext(), param.output->data<float>(), m * n, output_max);
129+
int r = xdnn::fc_fusion<float, float, float, int>(
130+
ctx.GetRawContext(), // ctx
131+
param.input->data<float>(), // x
132+
reinterpret_cast<const float*>(quant_weight_guard_->addr_), // w
133+
param.output->mutable_data<float>(TARGET(kXPU)), // y
134+
m, // m
135+
n, // n
136+
k, // k
137+
false, // x_trans
138+
true, // w_trans
139+
input_max, // x_maxptr
140+
reinterpret_cast<const float*>(weight_max_guard_->addr_), // w_maxptr
141+
output_max, // y_maxptr
142+
k, // ldx
143+
k, // ldw
144+
n, // ldy
145+
1.0f, // alpha
146+
0.0f, // beta
147+
bias, // bias
148+
act);
148149
CHECK_EQ(r, 0);
149150
} else if (param.precision == "int16") {
150151
int r = 0;

lite/operators/__xpu__fc_op.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ bool XPUFcOp::InferShapeImpl() const {
6262
}
6363
output_dims[in_num_col_dims] = w_dims_1;
6464
param_.output->Resize(output_dims);
65-
param_.output_max->Resize({4});
6665

6766
// share LoD
6867
param_.output->set_lod(param_.input->lod());

0 commit comments

Comments (0)