@@ -258,7 +258,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
258258 const float32_t * alpha,
259259 int is_relu,
260260 int k,
261- int rem);
261+ int rem,
262+ int bias_direction);
262263// clang-format off
263264#ifdef __aarch64__
264265#define GEMM_INT8_KERNEL \
@@ -802,9 +803,70 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
802803 " fmla v28.4s, v4.4s, v15.s[3]\n " /* 30, mul scale */ \
803804 " fmla v29.4s, v5.4s, v15.s[3]\n " /* 31, mul scale */ \
804805 " fmla v30.4s, v6.4s, v15.s[3]\n " /* 32, mul scale */ \
805- " fmla v31.4s, v7.4s, v15.s[3]\n " /* 33, mul scale */
806+ " fmla v31.4s, v7.4s, v15.s[3]\n " /* 33, mul scale */ \
807+ " 8: \n "
808+
809+ #define GEMM_TRANS_INT32_TO_FP32_N_Direction /* per-column (N-direction) dequantization: out = bias[col] + float(acc) * scale[col] */ \
810+ " cmp %w[bias_direction], #2\n " /* take this path only when bias_direction == 2 (presumably GemmNBias; confirm) */ \
811+ " bne 7f\n " /* otherwise jump to label 7 at the end of this macro */ \
812+ " ldp q8, q9, [%[bias]]\n " /* v8, v9 = bias for columns 0-7 */ \
813+ " ldp q10, q11, [%[bias], #32]\n " /* v10, v11 = bias for columns 8-15 */ \
814+ " ldp q12, q13, [%[scale]]\n " /* v12, v13 = scale for columns 0-7 */ \
815+ " ldp q14, q15, [%[scale], #32]\n " /* v14, v15 = scale for columns 8-15 */ \
816+ " scvtf v0.4s , v16.4s\n " /* 00, int32 acc -> fp32 */ \
817+ " scvtf v1.4s , v17.4s\n " /* 01, int32 acc -> fp32 */ \
818+ " scvtf v2.4s , v18.4s\n " /* 02, int32 acc -> fp32 */ \
819+ " scvtf v3.4s , v19.4s\n " /* 03, int32 acc -> fp32 */ \
820+ " scvtf v4.4s , v20.4s\n " /* 10, int32 acc -> fp32 */ \
821+ " scvtf v5.4s , v21.4s\n " /* 11, int32 acc -> fp32 */ \
822+ " scvtf v6.4s , v22.4s\n " /* 12, int32 acc -> fp32 */ \
823+ " scvtf v7.4s , v23.4s\n " /* 13, int32 acc -> fp32 */ \
824+ /* seed v16-v23 with the bias vectors so the fmla below yields bias + acc * scale; the same four bias vectors are reused for every row group */ \
825+ " mov v16.4s, v8.4s\n " /* 00, init with bias */ \
826+ " mov v17.4s, v9.4s\n " /* 01, init with bias */ \
827+ " mov v18.4s, v10.4s\n " /* 02, init with bias */ \
828+ " mov v19.4s, v11.4s\n " /* 03, init with bias */ \
829+ " mov v20.4s, v8.4s\n " /* 10, init with bias */ \
830+ " mov v21.4s, v9.4s\n " /* 11, init with bias */ \
831+ " mov v22.4s, v10.4s\n " /* 12, init with bias */ \
832+ " mov v23.4s, v11.4s\n " /* 13, init with bias */ \
833+ " fmla v16.4s, v0.4s, v12.4s\n " /* 00, += acc * scale (element-wise) */ \
834+ " fmla v17.4s, v1.4s, v13.4s\n " /* 01, += acc * scale (element-wise) */ \
835+ " fmla v18.4s, v2.4s, v14.4s\n " /* 02, += acc * scale (element-wise) */ \
836+ " fmla v19.4s, v3.4s, v15.4s\n " /* 03, += acc * scale (element-wise) */ \
837+ " fmla v20.4s, v4.4s, v12.4s\n " /* 10, += acc * scale (element-wise) */ \
838+ " fmla v21.4s, v5.4s, v13.4s\n " /* 11, += acc * scale (element-wise) */ \
839+ " fmla v22.4s, v6.4s, v14.4s\n " /* 12, += acc * scale (element-wise) */ \
840+ " fmla v23.4s, v7.4s, v15.4s\n " /* 13, += acc * scale (element-wise) */ \
841+ " scvtf v0.4s , v24.4s\n " /* 20, int32 acc -> fp32 */ \
842+ " scvtf v1.4s , v25.4s\n " /* 21, int32 acc -> fp32 */ \
843+ " scvtf v2.4s , v26.4s\n " /* 22, int32 acc -> fp32 */ \
844+ " scvtf v3.4s , v27.4s\n " /* 23, int32 acc -> fp32 */ \
845+ " scvtf v4.4s , v28.4s\n " /* 30, int32 acc -> fp32 */ \
846+ " scvtf v5.4s , v29.4s\n " /* 31, int32 acc -> fp32 */ \
847+ " scvtf v6.4s , v30.4s\n " /* 32, int32 acc -> fp32 */ \
848+ " scvtf v7.4s , v31.4s\n " /* 33, int32 acc -> fp32 */ \
849+ " mov v24.4s, v8.4s\n " /* 20, init with bias */ \
850+ " mov v25.4s, v9.4s\n " /* 21, init with bias */ \
851+ " mov v26.4s, v10.4s\n " /* 22, init with bias */ \
852+ " mov v27.4s, v11.4s\n " /* 23, init with bias */ \
853+ " mov v28.4s, v8.4s\n " /* 30, init with bias */ \
854+ " mov v29.4s, v9.4s\n " /* 31, init with bias */ \
855+ " mov v30.4s, v10.4s\n " /* 32, init with bias */ \
856+ " mov v31.4s, v11.4s\n " /* 33, init with bias */ \
857+ " fmla v24.4s, v0.4s, v12.4s\n " /* 20, += acc * scale (element-wise) */ \
858+ " fmla v25.4s, v1.4s, v13.4s\n " /* 21, += acc * scale (element-wise) */ \
859+ " fmla v26.4s, v2.4s, v14.4s\n " /* 22, += acc * scale (element-wise) */ \
860+ " fmla v27.4s, v3.4s, v15.4s\n " /* 23, += acc * scale (element-wise) */ \
861+ " fmla v28.4s, v4.4s, v12.4s\n " /* 30, += acc * scale (element-wise) */ \
862+ " fmla v29.4s, v5.4s, v13.4s\n " /* 31, += acc * scale (element-wise) */ \
863+ " fmla v30.4s, v6.4s, v14.4s\n " /* 32, += acc * scale (element-wise) */ \
864+ " fmla v31.4s, v7.4s, v15.4s\n " /* 33, += acc * scale (element-wise) */ \
865+ " b 8f \n " /* conversion done: branch forward to the next label 8, skipping the code that follows this macro */ \
866+ " 7: \n "
806867
807868#define GEMM_INT8_FP32_OUT \
869+ GEMM_TRANS_INT32_TO_FP32_N_Direction \
808870 GEMM_TRANS_INT32_TO_FP32 \
809871 GEMM_INT8_RELU \
810872 GEMM_INT8_RELU6 \
@@ -821,6 +883,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
821883 " stp q30, q31, [%[c_ptr3]], #32\n "
822884
823885#define GEMM_INT8_INT8_OUT \
886+ GEMM_TRANS_INT32_TO_FP32_N_Direction \
824887 GEMM_TRANS_INT32_TO_FP32 \
825888 GEMM_INT8_RELU \
826889 GEMM_INT8_RELU6 \
@@ -933,7 +996,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
933996 const float32_t * alpha,
934997 int is_relu,
935998 int k,
936- int rem) {
999+ int rem,
1000+ int bias_direction) {
9371001 // clang-format off
9381002 asm volatile (GEMM_INT8_KERNEL GEMM_INT8_FP32_OUT
9391003 : [a_ptr] " +r" (a_ptr),
@@ -947,7 +1011,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
9471011 [alpha] " r" (alpha),
9481012 [bias] " r" (bias),
9491013 [rem] " r" (rem),
950- [scale] " r" (scale)
1014+ [scale] " r" (scale),
1015+ [bias_direction] " r" (bias_direction)
9511016 : " v0" ," v1" ," v2" ," v3" ," v4" ," v5" ," v6" ," v7" ," v8" ,
9521017 " v9" ," v10" ," v11" ," v12" ," v13" ," v14" ,
9531018 " v15" ," v16" ," v17" ," v18" ," v19" ," v20" ,
@@ -968,7 +1033,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
9681033 const float32_t * alpha,
9691034 int is_relu,
9701035 int k,
971- int rem) {
1036+ int rem,
1037+ int bias_direction) {
9721038 // clang-format off
9731039 float vmax[4 ] = {-127.0 , -127.0 , -127.0 , -127.0 };
9741040 asm volatile (GEMM_INT8_KERNEL GEMM_INT8_INT8_OUT
@@ -984,7 +1050,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
9841050 [bias] " r" (bias),
9851051 [rem] " r" (rem),
9861052 [scale] " r" (scale),
987- [vmax] " r" (vmax)
1053+ [vmax] " r" (vmax),
1054+ [bias_direction] " r" (bias_direction)
9881055 : " v0" ," v1" ," v2" ," v3" ," v4" ," v5" ," v6" ," v7" ,
9891056 " v8" ," v9" ," v10" ," v11" ," v12" ,
9901057 " v13" ," v14" ," v15" ," v16" ," v17" ,
@@ -1006,7 +1073,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
10061073 const float32_t * alpha,
10071074 int is_relu,
10081075 int k,
1009- int rem) {
1076+ int rem,
1077+ int bias_direction) {
10101078 // clang-format off
10111079 float vmax[4 ] = {-127.0 , -127.0 , -127.0 , -127.0 };
10121080 asm volatile (GEMM_INT8_KERNEL GEMM_INT8_INT32_OUT
@@ -1022,7 +1090,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
10221090 [bias] " r" (bias),
10231091 [rem] " r" (rem),
10241092 [scale] " r" (scale),
1025- [vmax] " r" (vmax)
1093+ [vmax] " r" (vmax),
1094+ [bias_direction] " r" (bias_direction)
10261095 : " v0" ," v1" ," v2" ," v3" ," v4" ," v5" ," v6" ," v7" ,
10271096 " v8" ," v9" ," v10" ," v11" ," v12" ,
10281097 " v13" ," v14" ," v15" ," v16" ," v17" ,
@@ -4016,8 +4085,43 @@ inline void gemm_dot_int8_kernel(const int8_t* a_ptr,
40164085 " vmla.f32 q2, q10, d13[0]\n " /* r20, mul scale */ \
40174086 " vmla.f32 q3, q11, d13[0]\n " /* r21, mul scale */ \
40184087 " vmla.f32 q4, q12, d13[1]\n " /* r30, mul scale */ \
4019- " vmla.f32 q5, q13, d13[1]\n " /* r31, mul scale */
4088+ " vmla.f32 q5, q13, d13[1]\n " /* r31, mul scale */ \
4089+ " 8: \n "
40204090
4091+ #define GEMM_INT8_TRANS_INT32_TO_FP32_N_Direction /* per-column (N-direction) dequantization: out = bias[col] + float(acc) * scale[col] */ \
4092+ " cmp %[bias_direction], #2\n " /* take this path only when bias_direction == 2 (presumably GemmNBias; confirm) */ \
4093+ " bne 7f\n " /* otherwise jump to label 7 at the end of this macro */ \
4094+ /* Tile layout, per the stores in GEMM_INT8_FP32_OUT: row0 = q8,q9; row1 = q0,q1; */ \
4095+ /* row2 = q2,q3; row3 = q4,q5. The first register of each pair is columns 0-3, the second 4-7. */ \
4096+ /* scale and bias are input-only asm operands, so they are read without post-increment */ \
4097+ /* and leave this macro unchanged. Bias is loaded straight into the freed accumulators. */ \
4098+ " vld1.32 {d12-d15}, [%[scale]]\n " /* q6 = scale[0..3], q7 = scale[4..7] */ \
4099+ /* rows 0 and 1 */ \
4100+ " vcvt.f32.s32 q10, q8\n " /* r00, cvt int32 to fp32 */ \
4101+ " vcvt.f32.s32 q11, q9\n " /* r01, cvt int32 to fp32 */ \
4102+ " vcvt.f32.s32 q12, q0\n " /* r10, cvt int32 to fp32 */ \
4103+ " vcvt.f32.s32 q13, q1\n " /* r11, cvt int32 to fp32 */ \
4104+ " vld1.32 {d16-d19}, [%[bias]]\n " /* q8 = bias[0..3], q9 = bias[4..7] */ \
4105+ " vmov.32 q0, q8\n " /* r10, init with bias */ \
4106+ " vmov.32 q1, q9\n " /* r11, init with bias */ \
4107+ " vmla.f32 q8, q10, q6\n " /* r00, += acc * scale */ \
4108+ " vmla.f32 q9, q11, q7\n " /* r01, += acc * scale */ \
4109+ " vmla.f32 q0, q12, q6\n " /* r10, += acc * scale */ \
4110+ " vmla.f32 q1, q13, q7\n " /* r11, += acc * scale */ \
4111+ /* rows 2 and 3: same per-column vectors, not the per-row lane broadcast of the M-direction path */ \
4112+ " vcvt.f32.s32 q10, q2\n " /* r20, cvt int32 to fp32 */ \
4113+ " vcvt.f32.s32 q11, q3\n " /* r21, cvt int32 to fp32 */ \
4114+ " vcvt.f32.s32 q12, q4\n " /* r30, cvt int32 to fp32 */ \
4115+ " vcvt.f32.s32 q13, q5\n " /* r31, cvt int32 to fp32 */ \
4116+ " vld1.32 {d4-d7}, [%[bias]]\n " /* q2 = bias[0..3], q3 = bias[4..7] */ \
4117+ " vmov.32 q4, q2\n " /* r30, init with bias */ \
4118+ " vmov.32 q5, q3\n " /* r31, init with bias */ \
4119+ " vmla.f32 q2, q10, q6\n " /* r20, += acc * scale */ \
4120+ " vmla.f32 q3, q11, q7\n " /* r21, += acc * scale */ \
4121+ " vmla.f32 q4, q12, q6\n " /* r30, += acc * scale */ \
4122+ " vmla.f32 q5, q13, q7\n " /* r31, += acc * scale */ \
4123+ " b 8f \n " /* conversion done: branch forward to the next label 8, skipping the code that follows this macro */ \
4124+ " 7: \n "
40214125
40224126#define GEMM_INT8_RELU \
40234127 /* do relu */ \
@@ -4141,19 +4245,21 @@ inline void gemm_dot_int8_kernel(const int8_t* a_ptr,
41414245 " vmul.f32 q5, q5, q11 \n " \
41424246 " 9: \n "
41434247
4144- #define GEMM_INT8_FP32_OUT \
4145- GEMM_INT8_TRANS_INT32_TO_FP32 \
4146- GEMM_INT8_RELU \
4147- GEMM_INT8_RELU6 \
4148- GEMM_INT8_LEAKY_RELU \
4149- GEMM_INT8_HARD_SWISH \
4248+ #define GEMM_INT8_FP32_OUT \
4249+ GEMM_INT8_TRANS_INT32_TO_FP32_N_Direction \
4250+ GEMM_INT8_TRANS_INT32_TO_FP32 \
4251+ GEMM_INT8_RELU \
4252+ GEMM_INT8_RELU6 \
4253+ GEMM_INT8_LEAKY_RELU \
4254+ GEMM_INT8_HARD_SWISH \
41504255 " vst1.32 {d16-d19}, [%[c_ptr0]]!\n " /* store row 0 (q8, q9: float32x4 x2), post-increment c_ptr0 */ \
41514256 " vst1.32 {d0-d3}, [%[c_ptr1]]!\n " /* store row 1 (q0, q1: float32x4 x2), post-increment c_ptr1 */ \
41524257 " vst1.32 {d4-d7}, [%[c_ptr2]]!\n " /* store row 2 (q2, q3: float32x4 x2), post-increment c_ptr2 */ \
41534258 " vst1.32 {d8-d11}, [%[c_ptr3]]!\n " /* store row 3 (q4, q5: float32x4 x2), post-increment c_ptr3 */
41544259
41554260
41564261#define GEMM_INT8_INT8_OUT \
4262+ GEMM_INT8_TRANS_INT32_TO_FP32_N_Direction \
41574263 GEMM_INT8_TRANS_INT32_TO_FP32 \
41584264 GEMM_INT8_RELU \
41594265 GEMM_INT8_RELU6 \
@@ -4257,7 +4363,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
42574363 const float32_t * alpha,
42584364 int is_relu,
42594365 int k,
4260- int rem) {
4366+ int rem,
4367+ int bias_direction) {
42614368 float new_ptr[16 ] = {alpha[0 ],
42624369 alpha[1 ],
42634370 alpha[2 ],
@@ -4287,7 +4394,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
42874394 [bias] " r" (bias),
42884395 [alpha] " r" (new_ptr),
42894396 [rem] " r" (rem),
4290- [scale] " r" (scale)
4397+ [scale] " r" (scale),
4398+ [bias_direction] " r" (bias_direction)
42914399 : " q0" ,
42924400 " q1" ,
42934401 " q2" ,
@@ -4320,7 +4428,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
43204428 const float32_t * alpha,
43214429 int is_relu,
43224430 int k,
4323- int rem) {
4431+ int rem,
4432+ int bias_direction) {
43244433 float new_ptr[16 ] = {alpha[0 ],
43254434 alpha[1 ],
43264435 alpha[2 ],
@@ -4350,7 +4459,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
43504459 [alpha] " r" (new_ptr),
43514460 [bias] " r" (bias),
43524461 [rem] " r" (rem),
4353- [scale] " r" (scale)
4462+ [scale] " r" (scale),
4463+ [bias_direction] " r" (bias_direction)
43544464 : " q0" ,
43554465 " q1" ,
43564466 " q2" ,
@@ -4384,7 +4494,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
43844494 const float32_t * alpha,
43854495 int is_relu,
43864496 int k,
4387- int rem) {
4497+ int rem,
4498+ int bias_direction) {
43884499 float new_ptr[16 ] = {alpha[0 ],
43894500 alpha[1 ],
43904501 alpha[2 ],
@@ -4511,24 +4622,27 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
45114622 Dtype* tmp1 = nullptr ;
45124623 Dtype* tmp2 = nullptr ;
45134624 Dtype* tmp3 = nullptr ;
4514- float32_t scale_local[4 ] = {0 , 0 , 0 , 0 };
4515- float32_t bias_local[4 ] = {0 , 0 , 0 , 0 };
4516- if (is_bias) {
4517- if (y + 4 <= M) {
4518- bias_local[0 ] = bias[y];
4519- bias_local[1 ] = bias[y + 1 ];
4520- bias_local[2 ] = bias[y + 2 ];
4521- bias_local[3 ] = bias[y + 3 ];
4522- } else {
4523- switch (M - y) {
4524- case 3 :
4525- bias_local[2 ] = bias[y + 2 ];
4526- case 2 :
4527- bias_local[1 ] = bias[y + 1 ];
4528- case 1 :
4529- bias_local[0 ] = bias[y + 0 ];
4530- default :
4531- break ;
4625+ float32_t scale_local[16 ] = {
4626+ 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 };
4627+ float32_t bias_local[16 ] = {0 };
4628+ if (bias_direction != GemmNBias) {
4629+ if (is_bias) {
4630+ if (y + 4 <= M) {
4631+ bias_local[0 ] = bias[y];
4632+ bias_local[1 ] = bias[y + 1 ];
4633+ bias_local[2 ] = bias[y + 2 ];
4634+ bias_local[3 ] = bias[y + 3 ];
4635+ } else {
4636+ switch (M - y) {
4637+ case 3 :
4638+ bias_local[2 ] = bias[y + 2 ];
4639+ case 2 :
4640+ bias_local[1 ] = bias[y + 1 ];
4641+ case 1 :
4642+ bias_local[0 ] = bias[y + 0 ];
4643+ default :
4644+ break ;
4645+ }
45324646 }
45334647 }
45344648 }
@@ -4566,6 +4680,18 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
45664680 const int8_t * a_ptr_l = A_packed + y * KUP;
45674681 const int8_t * b_ptr = b_pannel;
45684682 for (int xb = 0 ; xb < bblocks; xb++) {
4683+ if (bias_direction == GemmNBias) {
4684+ if (scale) {
4685+ for (int j = 0 ; j < NBLOCK_INT8_OTH; j++) {
4686+ scale_local[j] = scale[xb * NBLOCK_INT8_OTH + j + x0];
4687+ }
4688+ }
4689+ if (bias) {
4690+ for (int j = 0 ; j < NBLOCK_INT8_OTH; j++) {
4691+ bias_local[j] = bias[xb * NBLOCK_INT8_OTH + j + x0];
4692+ }
4693+ }
4694+ }
45694695 if (flag_rem && (xb == bblocks - 1 )) {
45704696 tmp0 = c_ptr0;
45714697 tmp1 = c_ptr1;
@@ -4587,7 +4713,8 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
45874713 alpha,
45884714 flag_act,
45894715 k,
4590- k_rem);
4716+ k_rem,
4717+ bias_direction);
45914718 if (flag_rem && (xb == bblocks - 1 )) {
45924719 for (int i = 0 ; i < n_rem; ++i) {
45934720 *(tmp0++) = out0[i];
@@ -7994,9 +8121,8 @@ GEMM_PREPACK_INT8(float_t);
79948121GEMM_PREPACK_INT8 (int32_t );
79958122
79968123#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
7997- #define IN_PARAMS_NO_BIAS_DIRECTION \
7998- A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, \
7999- scale, alpha, ctx
8124+ #define IN_PARAMS_NO_BIAS_DIRECTION \
8125+ A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx
80008126template <typename dtype>
80018127void gemm_prepack_int8_nopack (const int8_t * A_packed,
80028128 const int8_t * B,
0 commit comments