Skip to content

Commit 8468d87

Browse files
committed
add GemmNBias support to gemm_prepack_oth_int8 test=develop
1 parent 8fa178a commit 8468d87

File tree

5 files changed

+258
-67
lines changed

5 files changed

+258
-67
lines changed

lite/backends/arm/math/gemm_prepacked_int8.cc

Lines changed: 168 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
258258
const float32_t* alpha,
259259
int is_relu,
260260
int k,
261-
int rem);
261+
int rem,
262+
int bias_direction);
262263
// clang-format off
263264
#ifdef __aarch64__
264265
#define GEMM_INT8_KERNEL \
@@ -802,9 +803,70 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
802803
"fmla v28.4s, v4.4s, v15.s[3]\n" /* 30, mul scale */ \
803804
"fmla v29.4s, v5.4s, v15.s[3]\n" /* 31, mul scale */ \
804805
"fmla v30.4s, v6.4s, v15.s[3]\n" /* 32, mul scale */ \
805-
"fmla v31.4s, v7.4s, v15.s[3]\n" /* 33, mul scale */
806+
"fmla v31.4s, v7.4s, v15.s[3]\n" /* 33, mul scale */ \
807+
"8: \n"
808+
809+
#define GEMM_TRANS_INT32_TO_FP32_N_Direction \
810+
"cmp %w[bias_direction], #2\n" /* skip N_Direction */ \
811+
"bne 7f\n" /* skip N_Direction */ \
812+
"ldp q8, q9, [%[bias]]\n" /* load bias */ \
813+
"ldp q10, q11, [%[bias], #32]\n" /* load bias */ \
814+
"ldp q12, q13, [%[scale]]\n" /* load scale */ \
815+
"ldp q14, q15, [%[scale], #32]\n" /* load scale */ \
816+
"scvtf v0.4s , v16.4s\n" /* 00, convert to fp32 */ \
817+
"scvtf v1.4s , v17.4s\n" /* 01, convert to fp32 */ \
818+
"scvtf v2.4s , v18.4s\n" /* 02, convert to fp32 */ \
819+
"scvtf v3.4s , v19.4s\n" /* 03, convert to fp32 */ \
820+
"scvtf v4.4s , v20.4s\n" /* 10, convert to fp32 */ \
821+
"scvtf v5.4s , v21.4s\n" /* 11, convert to fp32 */ \
822+
"scvtf v6.4s , v22.4s\n" /* 12, convert to fp32 */ \
823+
"scvtf v7.4s , v23.4s\n" /* 13, convert to fp32 */ \
824+
/* add bias */ \
825+
"mov v16.4s, v8.4s\n" \
826+
"mov v17.4s, v9.4s\n" \
827+
"mov v18.4s, v10.4s\n" \
828+
"mov v19.4s, v11.4s\n" \
829+
"mov v20.4s, v8.4s\n" \
830+
"mov v21.4s, v9.4s\n" \
831+
"mov v22.4s, v10.4s\n" \
832+
"mov v23.4s, v11.4s\n" \
833+
"fmla v16.4s, v0.4s, v12.4s\n" /* 00, mul scale */ \
834+
"fmla v17.4s, v1.4s, v13.4s\n" /* 01, mul scale */ \
835+
"fmla v18.4s, v2.4s, v14.4s\n" /* 02, mul scale */ \
836+
"fmla v19.4s, v3.4s, v15.4s\n" /* 03, mul scale */ \
837+
"fmla v20.4s, v4.4s, v12.4s\n" /* 10, mul scale */ \
838+
"fmla v21.4s, v5.4s, v13.4s\n" /* 11, mul scale */ \
839+
"fmla v22.4s, v6.4s, v14.4s\n" /* 12, mul scale */ \
840+
"fmla v23.4s, v7.4s, v15.4s\n" /* 13, mul scale */ \
841+
"scvtf v0.4s , v24.4s\n" /* 20, convert to fp32 */ \
842+
"scvtf v1.4s , v25.4s\n" /* 21, convert to fp32 */ \
843+
"scvtf v2.4s , v26.4s\n" /* 22, convert to fp32 */ \
844+
"scvtf v3.4s , v27.4s\n" /* 23, convert to fp32 */ \
845+
"scvtf v4.4s , v28.4s\n" /* 30, convert to fp32 */ \
846+
"scvtf v5.4s , v29.4s\n" /* 31, convert to fp32 */ \
847+
"scvtf v6.4s , v30.4s\n" /* 32, convert to fp32 */ \
848+
"scvtf v7.4s , v31.4s\n" /* 33, convert to fp32 */ \
849+
"mov v24.4s, v8.4s\n" \
850+
"mov v25.4s, v9.4s\n" \
851+
"mov v26.4s, v10.4s\n" \
852+
"mov v27.4s, v11.4s\n" \
853+
"mov v28.4s, v8.4s\n" \
854+
"mov v29.4s, v9.4s\n" \
855+
"mov v30.4s, v10.4s\n" \
856+
"mov v31.4s, v11.4s\n" \
857+
"fmla v24.4s, v0.4s, v12.4s\n" /* 20, mul scale */ \
858+
"fmla v25.4s, v1.4s, v13.4s\n" /* 21, mul scale */ \
859+
"fmla v26.4s, v2.4s, v14.4s\n" /* 22, mul scale */ \
860+
"fmla v27.4s, v3.4s, v15.4s\n" /* 23, mul scale */ \
861+
"fmla v28.4s, v4.4s, v12.4s\n" /* 30, mul scale */ \
862+
"fmla v29.4s, v5.4s, v13.4s\n" /* 31, mul scale */ \
863+
"fmla v30.4s, v6.4s, v14.4s\n" /* 32, mul scale */ \
864+
"fmla v31.4s, v7.4s, v15.4s\n" /* 33, mul scale */ \
865+
"b 8f \n" \
866+
"7: \n"
806867

807868
#define GEMM_INT8_FP32_OUT \
869+
GEMM_TRANS_INT32_TO_FP32_N_Direction \
808870
GEMM_TRANS_INT32_TO_FP32 \
809871
GEMM_INT8_RELU \
810872
GEMM_INT8_RELU6 \
@@ -821,6 +883,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
821883
"stp q30, q31, [%[c_ptr3]], #32\n"
822884

823885
#define GEMM_INT8_INT8_OUT \
886+
GEMM_TRANS_INT32_TO_FP32_N_Direction \
824887
GEMM_TRANS_INT32_TO_FP32 \
825888
GEMM_INT8_RELU \
826889
GEMM_INT8_RELU6 \
@@ -933,7 +996,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
933996
const float32_t* alpha,
934997
int is_relu,
935998
int k,
936-
int rem) {
999+
int rem,
1000+
int bias_direction) {
9371001
// clang-format off
9381002
asm volatile(GEMM_INT8_KERNEL GEMM_INT8_FP32_OUT
9391003
: [a_ptr] "+r"(a_ptr),
@@ -947,7 +1011,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
9471011
[alpha] "r"(alpha),
9481012
[bias] "r"(bias),
9491013
[rem] "r"(rem),
950-
[scale] "r"(scale)
1014+
[scale] "r"(scale),
1015+
[bias_direction] "r"(bias_direction)
9511016
: "v0","v1","v2","v3","v4","v5","v6","v7","v8",
9521017
"v9","v10","v11","v12","v13","v14",
9531018
"v15","v16","v17","v18","v19","v20",
@@ -968,7 +1033,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
9681033
const float32_t* alpha,
9691034
int is_relu,
9701035
int k,
971-
int rem) {
1036+
int rem,
1037+
int bias_direction) {
9721038
// clang-format off
9731039
float vmax[4] = {-127.0, -127.0, -127.0, -127.0};
9741040
asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT8_OUT
@@ -984,7 +1050,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
9841050
[bias] "r"(bias),
9851051
[rem] "r"(rem),
9861052
[scale] "r"(scale),
987-
[vmax] "r"(vmax)
1053+
[vmax] "r"(vmax),
1054+
[bias_direction] "r"(bias_direction)
9881055
: "v0","v1","v2","v3","v4","v5","v6","v7",
9891056
"v8","v9","v10","v11","v12",
9901057
"v13","v14","v15","v16","v17",
@@ -1006,7 +1073,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
10061073
const float32_t* alpha,
10071074
int is_relu,
10081075
int k,
1009-
int rem) {
1076+
int rem,
1077+
int bias_direction) {
10101078
// clang-format off
10111079
float vmax[4] = {-127.0, -127.0, -127.0, -127.0};
10121080
asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT32_OUT
@@ -1022,7 +1090,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
10221090
[bias] "r"(bias),
10231091
[rem] "r"(rem),
10241092
[scale] "r"(scale),
1025-
[vmax] "r"(vmax)
1093+
[vmax] "r"(vmax),
1094+
[bias_direction] "r"(bias_direction)
10261095
: "v0","v1","v2","v3","v4","v5","v6","v7",
10271096
"v8","v9","v10","v11","v12",
10281097
"v13","v14","v15","v16","v17",
@@ -4016,8 +4085,43 @@ inline void gemm_dot_int8_kernel(const int8_t* a_ptr,
40164085
"vmla.f32 q2, q10, d13[0]\n" /* r20, mul scale */ \
40174086
"vmla.f32 q3, q11, d13[0]\n" /* r21, mul scale */ \
40184087
"vmla.f32 q4, q12, d13[1]\n" /* r30, mul scale */ \
4019-
"vmla.f32 q5, q13, d13[1]\n" /* r31, mul scale */
4088+
"vmla.f32 q5, q13, d13[1]\n" /* r31, mul scale */ \
4089+
"8: \n"
40204090

4091+
#define GEMM_INT8_TRANS_INT32_TO_FP32_N_Direction \
4092+
"cmp %[bias_direction], #2\n" /* skip N_Direction */ \
4093+
"bne 7f\n" /* skip N_Direction */ \
4094+
/* write output */ \
4095+
"vld1.32 {d12-d13}, [%[scale]]!\n" /* load scale */ \
4096+
"vld1.32 {d14-d15}, [%[bias]]!\n" /* load bias */ \
4097+
"vcvt.f32.s32 q10, q8\n" /* r00, cvt int32 to fp32*/ \
4098+
"vcvt.f32.s32 q12, q0\n" /* r10, cvt int32 to fp32*/ \
4099+
"vmov.32 q8, q6\n" \
4100+
"vmov.32 q0, q6\n" \
4101+
"vmla.f32 q8, q10, q7\n" /* r00, mul scale */ \
4102+
"vmla.f32 q0, q12, q7\n" /* r10, mul scale */ \
4103+
"vcvt.f32.s32 q10, q2\n" /* r20, cvt int32 to fp32*/ \
4104+
"vcvt.f32.s32 q12, q4\n" /* r30, cvt int32 to fp32*/ \
4105+
"vdup.32 q2, d15[0]\n" \
4106+
"vdup.32 q4, d15[1]\n" \
4107+
"vmla.f32 q2, q10, d13[0]\n" /* r20, mul scale */ \
4108+
"vmla.f32 q4, q12, d13[1]\n" /* r30, mul scale */ \
4109+
"vld1.32 {d12-d13}, [%[scale]]\n" /* load scale */ \
4110+
"vld1.32 {d14-d15}, [%[bias]]\n" /* load bias */ \
4111+
"vcvt.f32.s32 q11, q9\n" /* r01, cvt int32 to fp32*/ \
4112+
"vcvt.f32.s32 q13, q1\n" /* r11, cvt int32 to fp32*/ \
4113+
"vmov.32 q9, q6\n" \
4114+
"vmov.32 q1, q6\n" \
4115+
"vmla.f32 q9, q11, q7\n" /* r01, mul scale */ \
4116+
"vmla.f32 q1, q13, q7\n" /* r11, mul scale */ \
4117+
"vcvt.f32.s32 q11, q3\n" /* r21, cvt int32 to fp32*/ \
4118+
"vcvt.f32.s32 q13, q5\n" /* r31, cvt int32 to fp32*/ \
4119+
"vdup.32 q3, d15[0]\n" \
4120+
"vdup.32 q5, d15[1]\n" \
4121+
"vmla.f32 q3, q11, d13[0]\n" /* r21, mul scale */ \
4122+
"vmla.f32 q5, q13, d13[1]\n" /* r31, mul scale */ \
4123+
"b 8f \n" \
4124+
"7: \n"
40214125

40224126
#define GEMM_INT8_RELU \
40234127
/* do relu */ \
@@ -4141,19 +4245,21 @@ inline void gemm_dot_int8_kernel(const int8_t* a_ptr,
41414245
"vmul.f32 q5, q5, q11 \n" \
41424246
"9: \n"
41434247

4144-
#define GEMM_INT8_FP32_OUT \
4145-
GEMM_INT8_TRANS_INT32_TO_FP32 \
4146-
GEMM_INT8_RELU \
4147-
GEMM_INT8_RELU6 \
4148-
GEMM_INT8_LEAKY_RELU \
4149-
GEMM_INT8_HARD_SWISH \
4248+
#define GEMM_INT8_FP32_OUT \
4249+
GEMM_INT8_TRANS_INT32_TO_FP32_N_Direction \
4250+
GEMM_INT8_TRANS_INT32_TO_FP32 \
4251+
GEMM_INT8_RELU \
4252+
GEMM_INT8_RELU6 \
4253+
GEMM_INT8_LEAKY_RELU \
4254+
GEMM_INT8_HARD_SWISH \
41504255
"vst1.32 {d16-d19}, [%[c_ptr0]]!\n" /* write r0, float32x4 x2 */ \
41514256
"vst1.32 {d0-d3}, [%[c_ptr1]]!\n" /* write r1, float32x4 x2 */ \
41524257
"vst1.32 {d4-d7}, [%[c_ptr2]]!\n" /* write r2, float32x4 x2 */ \
41534258
"vst1.32 {d8-d11}, [%[c_ptr3]]!\n" /* write r3, float32x4 x2 */
41544259

41554260

41564261
#define GEMM_INT8_INT8_OUT \
4262+
GEMM_INT8_TRANS_INT32_TO_FP32_N_Direction \
41574263
GEMM_INT8_TRANS_INT32_TO_FP32 \
41584264
GEMM_INT8_RELU \
41594265
GEMM_INT8_RELU6 \
@@ -4257,7 +4363,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
42574363
const float32_t* alpha,
42584364
int is_relu,
42594365
int k,
4260-
int rem) {
4366+
int rem,
4367+
int bias_direction) {
42614368
float new_ptr[16] = {alpha[0],
42624369
alpha[1],
42634370
alpha[2],
@@ -4287,7 +4394,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
42874394
[bias] "r"(bias),
42884395
[alpha] "r"(new_ptr),
42894396
[rem] "r"(rem),
4290-
[scale] "r"(scale)
4397+
[scale] "r"(scale),
4398+
[bias_direction] "r"(bias_direction)
42914399
: "q0",
42924400
"q1",
42934401
"q2",
@@ -4320,7 +4428,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
43204428
const float32_t* alpha,
43214429
int is_relu,
43224430
int k,
4323-
int rem) {
4431+
int rem,
4432+
int bias_direction) {
43244433
float new_ptr[16] = {alpha[0],
43254434
alpha[1],
43264435
alpha[2],
@@ -4350,7 +4459,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
43504459
[alpha] "r"(new_ptr),
43514460
[bias] "r"(bias),
43524461
[rem] "r"(rem),
4353-
[scale] "r"(scale)
4462+
[scale] "r"(scale),
4463+
[bias_direction] "r"(bias_direction)
43544464
: "q0",
43554465
"q1",
43564466
"q2",
@@ -4384,7 +4494,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
43844494
const float32_t* alpha,
43854495
int is_relu,
43864496
int k,
4387-
int rem) {
4497+
int rem,
4498+
int bias_direction) {
43884499
float new_ptr[16] = {alpha[0],
43894500
alpha[1],
43904501
alpha[2],
@@ -4511,24 +4622,27 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
45114622
Dtype* tmp1 = nullptr;
45124623
Dtype* tmp2 = nullptr;
45134624
Dtype* tmp3 = nullptr;
4514-
float32_t scale_local[4] = {0, 0, 0, 0};
4515-
float32_t bias_local[4] = {0, 0, 0, 0};
4516-
if (is_bias) {
4517-
if (y + 4 <= M) {
4518-
bias_local[0] = bias[y];
4519-
bias_local[1] = bias[y + 1];
4520-
bias_local[2] = bias[y + 2];
4521-
bias_local[3] = bias[y + 3];
4522-
} else {
4523-
switch (M - y) {
4524-
case 3:
4525-
bias_local[2] = bias[y + 2];
4526-
case 2:
4527-
bias_local[1] = bias[y + 1];
4528-
case 1:
4529-
bias_local[0] = bias[y + 0];
4530-
default:
4531-
break;
4625+
float32_t scale_local[16] = {
4626+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
4627+
float32_t bias_local[16] = {0};
4628+
if (bias_direction != GemmNBias) {
4629+
if (is_bias) {
4630+
if (y + 4 <= M) {
4631+
bias_local[0] = bias[y];
4632+
bias_local[1] = bias[y + 1];
4633+
bias_local[2] = bias[y + 2];
4634+
bias_local[3] = bias[y + 3];
4635+
} else {
4636+
switch (M - y) {
4637+
case 3:
4638+
bias_local[2] = bias[y + 2];
4639+
case 2:
4640+
bias_local[1] = bias[y + 1];
4641+
case 1:
4642+
bias_local[0] = bias[y + 0];
4643+
default:
4644+
break;
4645+
}
45324646
}
45334647
}
45344648
}
@@ -4566,6 +4680,18 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
45664680
const int8_t* a_ptr_l = A_packed + y * KUP;
45674681
const int8_t* b_ptr = b_pannel;
45684682
for (int xb = 0; xb < bblocks; xb++) {
4683+
if (bias_direction == GemmNBias) {
4684+
if (scale) {
4685+
for (int j = 0; j < NBLOCK_INT8_OTH; j++) {
4686+
scale_local[j] = scale[xb * NBLOCK_INT8_OTH + j + x0];
4687+
}
4688+
}
4689+
if (bias) {
4690+
for (int j = 0; j < NBLOCK_INT8_OTH; j++) {
4691+
bias_local[j] = bias[xb * NBLOCK_INT8_OTH + j + x0];
4692+
}
4693+
}
4694+
}
45694695
if (flag_rem && (xb == bblocks - 1)) {
45704696
tmp0 = c_ptr0;
45714697
tmp1 = c_ptr1;
@@ -4587,7 +4713,8 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
45874713
alpha,
45884714
flag_act,
45894715
k,
4590-
k_rem);
4716+
k_rem,
4717+
bias_direction);
45914718
if (flag_rem && (xb == bblocks - 1)) {
45924719
for (int i = 0; i < n_rem; ++i) {
45934720
*(tmp0++) = out0[i];
@@ -7994,9 +8121,8 @@ GEMM_PREPACK_INT8(float_t);
79948121
GEMM_PREPACK_INT8(int32_t);
79958122

79968123
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
7997-
#define IN_PARAMS_NO_BIAS_DIRECTION \
7998-
A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, \
7999-
scale, alpha, ctx
8124+
#define IN_PARAMS_NO_BIAS_DIRECTION \
8125+
A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx
80008126
template <typename dtype>
80018127
void gemm_prepack_int8_nopack(const int8_t* A_packed,
80028128
const int8_t* B,

0 commit comments

Comments (0)