diff --git a/lite/backends/arm/math/fp16/conv3x3s2_direct_fp16.cc b/lite/backends/arm/math/fp16/conv3x3s2_direct_fp16.cc index 50335a8f461..ad3e5948c25 100644 --- a/lite/backends/arm/math/fp16/conv3x3s2_direct_fp16.cc +++ b/lite/backends/arm/math/fp16/conv3x3s2_direct_fp16.cc @@ -28,7 +28,13 @@ namespace fp16 { const int OUT_C_BLOCK = 8; const int OUT_H_BLOCK = 2; +#ifdef __aarch64__ const int OUT_W_BLOCK = 8; +#else +const int OUT_W_BLOCK = 4; +#endif + +#ifdef __aarch64__ #define COMPUT_INIT \ float16_t* ptr_out0 = pre_out0; \ float16_t* ptr_out1 = pre_out1; \ @@ -46,6 +52,16 @@ const int OUT_W_BLOCK = 8; const float16_t* r2 = inr2; \ const float16_t* r3 = inr3; \ const float16_t* r4 = inr4; +#else +#define COMPUT_INIT \ + float16_t* ptr_out0 = pre_out0; \ + float16_t* ptr_out1 = pre_out1; \ + const float16_t* r0 = inr0; \ + const float16_t* r1 = inr1; \ + const float16_t* r2 = inr2; \ + const float16_t* r3 = inr3; \ + const float16_t* r4 = inr4; +#endif size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param, ARMContext* ctx) { @@ -562,6 +578,344 @@ size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param, "add %[ptr_out0], %[ptr_out0], #0x80\n"\ "bne 1b\n" #else + +#define INIT_FIRST \ + "2:\n" \ + "vld1.16 {d10-d13}, [%[wc0]]! @ load w0, w1\n" \ + "vld1.16 {d0-d2}, [%[r0]] @ load r0\n" \ + "add %[r0], %[r0], #16\n" \ + "vmul.f16 q8, q5, d0[0] @ w0 * inr00\n" \ + "vmul.f16 q9, q5, d0[2] @ w0 * inr02\n" \ + "vmul.f16 q10, q5, d1[0] @ w0 * inr04\n" \ + "vmul.f16 q11, q5, d1[2] @ w0 * inr06\n" /* mul r0, with w0*/ \ + "vld1.16 {d3-d5}, [%[r2]] @ load r2\n" \ + "add %[r2], %[r2], #16\n" \ + "vmul.f16 q12, q5, d3[0] @ w0 * inr20\n" \ + "vmul.f16 q13, q5, d3[2] @ w0 * inr22\n" \ + "vld1.16 {d14-d15}, [%[wc0]]! @ load w2\n" \ + "vmul.f16 q14, q5, d4[0] @ w0 * inr24\n" \ + "vmul.f16 q15, q5, d4[2] @ w0 * inr26\n" + +#define INIT \ + "2:\n" \ + "vld1.16 {d10-d13}, [%[wc0]]! @ load w0, w1\n" \ + "vld1.16 {d16-d19}, [%[ptr_out0]]! @ load outr0\n" \ + "vld1.16 {d0-d2}, [%[r0]] @ load r0\n" \ + "add %[r0], %[r0], #16\n" \ + "vld1.16 {d14-d15}, [%[wc0]]! @ load w2\n" \ + "vmla.f16 q8, q5, d0[0] @ w0 * inr00\n" \ + "vld1.16 {d20-d23}, [%[ptr_out0]] @ load outr0\n" \ + "sub %[ptr_out0], %[ptr_out0], #32\n" \ + "vmla.f16 q9, q5, d0[2] @ w0 * inr02\n" \ + "vmla.f16 q10, q5, d1[0] @ w0 * inr04\n" \ + "vld1.16 {d24-d27}, [%[ptr_out1]]! @ load outr0\n" \ + "vmla.f16 q11, q5, d1[2] @ w0 * inr06\n" /* mul r0, with w0*/ \ + "vld1.16 {d3-d5}, [%[r2]] @ load r2\n" \ + "add %[r2], %[r2], #16\n" \ + "vmla.f16 q12, q5, d3[0] @ w0 * inr20\n" \ + "vmla.f16 q13, q5, d3[2] @ w0 * inr22\n" \ + "vld1.16 {d28-d31}, [%[ptr_out1]] @ load outr0\n" \ + "sub %[ptr_out1], %[ptr_out1], #32\n" \ + "vmla.f16 q14, q5, d4[0] @ w0 * inr24\n" \ + "vmla.f16 q15, q5, d4[2] @ w0 * inr26\n" + +#define COMPUTE \ + /* r0-1 */ \ + "vld1.16 {d6-d8}, [%[r1]] @ load r1\n" \ + "add %[r1], %[r1], #16\n" \ + "vmla.f16 q8, q6, d0[1] @ w0 * inr00\n" \ + "vmla.f16 q9, q6, d0[3] @ w0 * inr02\n" \ + "vmla.f16 q10, q6, d1[1] @ w0 * inr04\n" \ + "vmla.f16 q11, q6, d1[3] @ w0 * inr06\n" /* mul r0, with w0*/ \ + /* r2-1 */ \ + "vmla.f16 q12, q6, d3[1] @ w0 * inr20\n" \ + "vmla.f16 q13, q6, d3[3] @ w0 * inr22\n" \ + "vmla.f16 q14, q6, d4[1] @ w0 * inr24\n" \ + "vmla.f16 q15, q6, d4[3] @ w0 * inr26\n" \ + "vld1.16 {d10-d13}, [%[wc0]]! 
@ load w5, to q7\n" /* mul r1, with*/ \ + /* r0-2 */ \ + "vmla.f16 q8, q7, d0[2] @ w0 * inr00\n" \ + "vmla.f16 q9, q7, d1[0] @ w0 * inr02\n" \ + "vmla.f16 q10, q7, d1[2] @ w0 * inr04\n" \ + "vmla.f16 q11, q7, d2[0] @ w0 * inr06\n" /* mul r0, with w0*/ \ + /* r2-2 */ \ + "vmla.f16 q12, q7, d3[2] @ w0 * inr20\n" \ + "vmla.f16 q13, q7, d4[0] @ w0 * inr22\n" \ + "vmla.f16 q14, q7, d4[2] @ w0 * inr24\n" \ + "vmla.f16 q15, q7, d5[0] @ w0 * inr26\n" \ + "vld1.16 {d14-d15}, [%[wc0]]! @ load w5, to q7\n" /* mul r1, with*/ \ + /* r1-0 */ \ + "vmla.f16 q8, q5, d6[0] @ w0 * inr00\n" \ + "vmla.f16 q9, q5, d6[2] @ w0 * inr02\n" \ + "vmla.f16 q10, q5, d7[0] @ w0 * inr04\n" \ + "vmla.f16 q11, q5, d7[2] @ w0 * inr06\n" /* mul r0, with w0*/ \ + "vld1.16 {d0-d2}, [%[r3]] @ load r1\n" \ + "add %[r3], %[r3], #16\n" \ + /* r1-1 */ \ + "vmla.f16 q8, q6, d6[1] @ w0 * inr00\n" \ + "vmla.f16 q9, q6, d6[3] @ w0 * inr02\n" \ + "vmla.f16 q10, q6, d7[1] @ w0 * inr04\n" \ + "vmla.f16 q11, q6, d7[3] @ w0 * inr06\n" /* mul r0, with w0*/ \ + /* r1-2 */ \ + "vmla.f16 q8, q7, d6[2] @ w0 * inr00\n" \ + "vmla.f16 q9, q7, d7[0] @ w0 * inr02\n" \ + "vmla.f16 q10, q7, d7[2] @ w0 * inr04\n" \ + "vmov d6, d8 \n" \ + "vmla.f16 q11, q7, d6[0] @ w0 * inr06\n" /* mul r0, with w0*/ \ + /* r3-0 */ \ + "vmla.f16 q12, q5, d0[0] @ w0 * inr00\n" \ + "vmla.f16 q13, q5, d0[2] @ w0 * inr02\n" \ + "vmla.f16 q14, q5, d1[0] @ w0 * inr04\n" \ + "vmla.f16 q15, q5, d1[2] @ w0 * inr06\n" /* mul r0, with w0*/ \ + /* r3-1 */ \ + "vmla.f16 q12, q6, d0[1] @ w0 * inr00\n" \ + "vmla.f16 q13, q6, d0[3] @ w0 * inr02\n" \ + "vmla.f16 q14, q6, d1[1] @ w0 * inr04\n" \ + "vmla.f16 q15, q6, d1[3] @ w0 * inr06\n" /* mul r0, with w0*/ \ + "vld1.16 {d10-d13}, [%[wc0]]! @ load w0, w1\n" \ + /* r3-2 */ \ + "vmla.f16 q12, q7, d0[2] @ w0 * inr00\n" \ + "vmla.f16 q13, q7, d1[0] @ w0 * inr02\n" \ + "vmla.f16 q14, q7, d1[2] @ w0 * inr04\n" \ + "vmla.f16 q15, q7, d2[0] @ w0 * inr06\n" /* mul r0, with w0*/ \ + "vld1.16 {d6-d8}, [%[r4]] @ load r3\n" \ + "add %[r4], %[r4], #16\n" \ + "vld1.32 {d14-d15}, [%[wc0]]! @ load w5, to q7\n" /* mul r1, with*/ \ + "sub %[wc0], %[wc0], #144\n" \ + /* r2-0 */ \ + "vmla.f16 q8, q5, d3[0] @ w0 * inr00\n" \ + "vmla.f16 q9, q5, d3[2] @ w0 * inr02\n" \ + "vmla.f16 q10, q5, d4[0] @ w0 * inr04\n" \ + "vmla.f16 q11, q5, d4[2] @ w0 * inr06\n" /* mul r0, with w0*/ \ + /* r2-1 */ \ + "vmla.f16 q8, q6, d3[1] @ w0 * inr00\n" \ + "vmla.f16 q9, q6, d3[3] @ w0 * inr02\n" \ + "vmla.f16 q10, q6, d4[1] @ w0 * inr04\n" \ + "vmla.f16 q11, q6, d4[3] @ w0 * inr06\n" /* mul r0, with w0*/ \ + /* r2-2 */ \ + "vmla.f16 q8, q7, d3[2] @ w0 * inr00\n" \ + "vmla.f16 q9, q7, d4[0] @ w0 * inr02\n" \ + "vmla.f16 q10, q7, d4[2] @ w0 * inr04\n" \ + "vmla.f16 q11, q7, d5[0] @ w0 * inr06\n" /* mul r0, with w0*/ \ + /* r4-0 */ \ + "vmla.f16 q12, q5, d6[0] @ w0 * inr00\n" \ + "vmla.f16 q13, q5, d6[2] @ w0 * inr02\n" \ + "vmla.f16 q14, q5, d7[0] @ w0 * inr04\n" \ + "vmla.f16 q15, q5, d7[2] @ w0 * inr06\n" /* mul r0, with w0*/ \ + /* r4-1 */ \ + "vmla.f16 q12, q6, d6[1] @ w0 * inr00\n" \ + "vmla.f16 q13, q6, d6[3] @ w0 * inr02\n" \ + "vmla.f16 q14, q6, d7[1] @ w0 * inr04\n" \ + "vmla.f16 q15, q6, d7[3] @ w0 * inr06\n" /* mul r0, with w0*/ \ + /* r4-2 */ \ + "vmla.f16 q12, q7, d6[2] @ w0 * inr00\n" \ + "vmla.f16 q13, q7, d7[0] @ w0 * inr02\n" \ + "vmla.f16 q14, q7, d7[2] @ w0 * inr04\n" \ + "vmov d6, d8 \n" \ + "vmla.f16 q15, q7, d6[0] @ w0 * inr06\n" /* mul r0, with w0*/ \ + "vst1.16 {d16-d19}, [%[ptr_out0]]! @ load outr0\n" \ + "vst1.16 {d20-d23}, [%[ptr_out0]]! 
@ load outr0\n" \ + "vst1.16 {d24-d27}, [%[ptr_out1]]! @ save r10, r11\n" \ + "vst1.16 {d28-d31}, [%[ptr_out1]]! @ save r12, r13\n" \ + "subs %[cnt], %[cnt], #1\n" \ + "bne 2b\n" +#define ASM_PARAM \ + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), \ + [r2] "+r"(r2), [r3] "+r"(r3), [r4] "+r"(r4), \ + [wc0] "+r"(wc0), \ + [ptr_out0] "+r"(ptr_out0), \ + [ptr_out1] "+r"(ptr_out1) \ + : \ + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", \ + "q5", "q6", "q7", "q8", "q9", "q10", \ + "q11", "q12", "q13", "q14", "q15" + +#define COMPUTE_C3 \ + "1: \n" \ + "vld1.16 {d0-d3}, [%[r0]]! @ load q0, q1\n" \ + "vld1.16 {d18-d19}, [%[wc0]]! @ load w0, w1\n" \ + "vld1.16 {d4-d7}, [%[r0]]! @ load q2, q3\n" \ + "vld1.16 {d20-d21}, [%[wc0]]! @ load w2, w3\n" \ + "vld1.16 {d8}, [%[r0]] @ load q4, q5\n" \ + "vld1.16 {d22-d23}, [%[wc0]]! @ load w2, w3\n" \ + /* line 0*/ \ + /* line 0 0 c0*/ \ + "vmul.f16 q15, q9, d6[0] @ mul \n" \ + "vmul.f16 q12, q9, d0[0] @ mul \n" \ + "vmul.f16 q13, q9, d2[0] @ mul \n" \ + "vmul.f16 q14, q9, d4[0] @ mul \n" \ + "vld1.16 {d18-d19}, [%[wc0]]! @ load w0, w1\n" \ + /* c1*/\ + "vmla.f16 q15, q10, d6[1] @ mul \n" \ + "vmla.f16 q12, q10, d0[1] @ mul \n" \ + "vmla.f16 q13, q10, d2[1] @ mul \n" \ + "vmla.f16 q14, q10, d4[1] @ mul \n" \ + "vld1.16 {d20-d21}, [%[wc0]]! @ load w2, w3\n" \ + /* c2*/\ + "vmla.f16 q15, q11, d6[2] @ mul \n" \ + "vmla.f16 q12, q11, d0[2] @ mul \n" \ + "vmla.f16 q13, q11, d2[2] @ mul \n" \ + "vmla.f16 q14, q11, d4[2] @ mul \n" \ + "vld1.16 {d22-d23}, [%[wc0]]! @ load w2, w3\n" \ + /* line 0 01 c0*/ \ + "vmla.f16 q15, q9, d7[0] @ mul \n" \ + "vmla.f16 q12, q9, d1[0] @ mul \n" \ + "vmla.f16 q13, q9, d3[0] @ mul \n" \ + "vmla.f16 q14, q9, d5[0] @ mul \n" \ + "vld1.16 {d18-d19}, [%[wc0]]! @ load w0, w1\n" \ + /* c1*/\ + "vmla.f16 q15, q10, d7[1] @ mul \n" \ + "vmla.f16 q12, q10, d1[1] @ mul \n" \ + "vmla.f16 q13, q10, d3[1] @ mul \n" \ + "vmla.f16 q14, q10, d5[1] @ mul \n" \ + "vld1.16 {d20-d21}, [%[wc0]]! @ load w2, w3\n" \ + /* c2*/\ + "vmla.f16 q15, q11, d7[2] @ mul \n" \ + "vmla.f16 q12, q11, d1[2] @ mul \n" \ + "vmla.f16 q13, q11, d3[2] @ mul \n" \ + "vmla.f16 q14, q11, d5[2] @ mul \n" \ + "vld1.16 {d22-d23}, [%[wc0]]! @ load w2, w3\n" \ + /* line 0 02 c0*/ \ + "vmov d0, d8 \n" \ + "vmla.f16 q15, q9, d0[0] @ mul \n" \ + "vmla.f16 q12, q9, d2[0] @ mul \n" \ + "vmla.f16 q13, q9, d4[0] @ mul \n" \ + "vmla.f16 q14, q9, d6[0] @ mul \n" \ + "vld1.16 {d18-d19}, [%[wc0]]! @ load w0, w1\n" \ + /* c1*/\ + "vmla.f16 q15, q10, d0[1] @ mul \n" \ + "vmla.f16 q12, q10, d2[1] @ mul \n" \ + "vmla.f16 q13, q10, d4[1] @ mul \n" \ + "vmla.f16 q14, q10, d6[1] @ mul \n" \ + "vld1.16 {d20-d21}, [%[wc0]]! @ load w2, w3\n" \ + /* c2*/\ + "vmla.f16 q15, q11, d0[2] @ mul \n" \ + "vmla.f16 q12, q11, d2[2] @ mul \n" \ + "vmla.f16 q13, q11, d4[2] @ mul \n" \ + "vmla.f16 q14, q11, d6[2] @ mul \n" \ + "vld1.16 {d22-d23}, [%[wc0]]! @ load w2, w3\n" \ + /*i1*/\ + "vld1.16 {d0-d3}, [%[r1]]! @ load q0, q1\n" \ + "vld1.16 {d4-d7}, [%[r1]]! @ load q2, q3\n" \ + "vld1.16 {d8}, [%[r1]] @ load q4, q5\n" \ + /* line 0*/ \ + /* line 0 0 c0*/ \ + "vmla.f16 q15, q9, d6[0] @ mul \n" \ + "vmla.f16 q12, q9, d0[0] @ mul \n" \ + "vmla.f16 q13, q9, d2[0] @ mul \n" \ + "vmla.f16 q14, q9, d4[0] @ mul \n" \ + "vld1.16 {d18-d19}, [%[wc0]]! @ load w0, w1\n" \ + /* c1*/\ + "vmla.f16 q15, q10, d6[1] @ mul \n" \ + "vmla.f16 q12, q10, d0[1] @ mul \n" \ + "vmla.f16 q13, q10, d2[1] @ mul \n" \ + "vmla.f16 q14, q10, d4[1] @ mul \n" \ + "vld1.16 {d20-d21}, [%[wc0]]! 
@ load w2, w3\n" \ + /* c2*/\ + "vmla.f16 q15, q11, d6[2] @ mul \n" \ + "vmla.f16 q12, q11, d0[2] @ mul \n" \ + "vmla.f16 q13, q11, d2[2] @ mul \n" \ + "vmla.f16 q14, q11, d4[2] @ mul \n" \ + "vld1.16 {d22-d23}, [%[wc0]]! @ load w2, w3\n" \ + /* line 0 01 c0*/ \ + "vmla.f16 q15, q9, d7[0] @ mul \n" \ + "vmla.f16 q12, q9, d1[0] @ mul \n" \ + "vmla.f16 q13, q9, d3[0] @ mul \n" \ + "vmla.f16 q14, q9, d5[0] @ mul \n" \ + "vld1.16 {d18-d19}, [%[wc0]]! @ load w0, w1\n" \ + /* c1*/\ + "vmla.f16 q15, q10, d7[1] @ mul \n" \ + "vmla.f16 q12, q10, d1[1] @ mul \n" \ + "vmla.f16 q13, q10, d3[1] @ mul \n" \ + "vmla.f16 q14, q10, d5[1] @ mul \n" \ + "vld1.16 {d20-d21}, [%[wc0]]! @ load w2, w3\n" \ + /* c2*/\ + "vmla.f16 q15, q11, d7[2] @ mul \n" \ + "vmla.f16 q12, q11, d1[2] @ mul \n" \ + "vmla.f16 q13, q11, d3[2] @ mul \n" \ + "vmla.f16 q14, q11, d5[2] @ mul \n" \ + "vld1.16 {d22-d23}, [%[wc0]]! @ load w2, w3\n" \ + /* line 0 02 c0*/ \ + "vmov d0, d8 \n" \ + "vmla.f16 q15, q9, d0[0] @ mul \n" \ + "vmla.f16 q12, q9, d2[0] @ mul \n" \ + "vmla.f16 q13, q9, d4[0] @ mul \n" \ + "vmla.f16 q14, q9, d6[0] @ mul \n" \ + "vld1.16 {d18-d19}, [%[wc0]]! @ load w0, w1\n" \ + /* c1*/\ + "vmla.f16 q15, q10, d0[1] @ mul \n" \ + "vmla.f16 q12, q10, d2[1] @ mul \n" \ + "vmla.f16 q13, q10, d4[1] @ mul \n" \ + "vmla.f16 q14, q10, d6[1] @ mul \n" \ + "vld1.16 {d20-d21}, [%[wc0]]! @ load w2, w3\n" \ + /* c2*/\ + "vmla.f16 q15, q11, d0[2] @ mul \n" \ + "vmla.f16 q12, q11, d2[2] @ mul \n" \ + "vmla.f16 q13, q11, d4[2] @ mul \n" \ + "vmla.f16 q14, q11, d6[2] @ mul \n" \ + "vld1.16 {d22-d23}, [%[wc0]]! @ load w2, w3\n" \ + /*i2*/\ + "vld1.16 {d0-d3}, [%[r2]]! @ load q0, q1\n" \ + "vld1.16 {d4-d7}, [%[r2]]! @ load q2, q3\n" \ + "vld1.16 {d8}, [%[r2]] @ load q4, q5\n" \ + /* line 0*/ \ + /* line 0 0 c0*/ \ + "vmla.f16 q15, q9, d6[0] @ mul \n" \ + "vmla.f16 q12, q9, d0[0] @ mul \n" \ + "vmla.f16 q13, q9, d2[0] @ mul \n" \ + "vmla.f16 q14, q9, d4[0] @ mul \n" \ + "vld1.16 {d18-d19}, [%[wc0]]! @ load w0, w1\n" \ + /* c1*/\ + "vmla.f16 q15, q10, d6[1] @ mul \n" \ + "vmla.f16 q12, q10, d0[1] @ mul \n" \ + "vmla.f16 q13, q10, d2[1] @ mul \n" \ + "vmla.f16 q14, q10, d4[1] @ mul \n" \ + "vld1.16 {d20-d21}, [%[wc0]]! @ load w2, w3\n" \ + /* c2*/\ + "vmla.f16 q15, q11, d6[2] @ mul \n" \ + "vmla.f16 q12, q11, d0[2] @ mul \n" \ + "vmla.f16 q13, q11, d2[2] @ mul \n" \ + "vmla.f16 q14, q11, d4[2] @ mul \n" \ + "vld1.16 {d22-d23}, [%[wc0]]! @ load w2, w3\n" \ + /* line 0 01 c0*/ \ + "vmla.f16 q15, q9, d7[0] @ mul \n" \ + "vmla.f16 q12, q9, d1[0] @ mul \n" \ + "vmla.f16 q13, q9, d3[0] @ mul \n" \ + "vmla.f16 q14, q9, d5[0] @ mul \n" \ + "vld1.16 {d18-d19}, [%[wc0]]! @ load w0, w1\n" \ + /* c1*/\ + "vmla.f16 q15, q10, d7[1] @ mul \n" \ + "vmla.f16 q12, q10, d1[1] @ mul \n" \ + "vmla.f16 q13, q10, d3[1] @ mul \n" \ + "vmla.f16 q14, q10, d5[1] @ mul \n" \ + "vld1.16 {d20-d21}, [%[wc0]]! @ load w2, w3\n" \ + /* c2*/\ + "vmla.f16 q15, q11, d7[2] @ mul \n" \ + "vmla.f16 q12, q11, d1[2] @ mul \n" \ + "vmla.f16 q13, q11, d3[2] @ mul \n" \ + "vmla.f16 q14, q11, d5[2] @ mul \n" \ + "vld1.16 {d22-d23}, [%[wc0]]! 
@ load w2, w3\n" \
+  /* line 0 02 c0*/ \
+  "vmov d0, d8 \n" \
+  "vmla.f16 q15, q9, d0[0] @ mul \n" \
+  "vmla.f16 q12, q9, d2[0] @ mul \n" \
+  "vmla.f16 q13, q9, d4[0] @ mul \n" \
+  "vmla.f16 q14, q9, d6[0] @ mul \n" \
+  /* c1*/ \
+  "vmla.f16 q15, q10, d0[1] @ mul \n" \
+  "vmla.f16 q12, q10, d2[1] @ mul \n" \
+  "vmla.f16 q13, q10, d4[1] @ mul \n" \
+  "vmla.f16 q14, q10, d6[1] @ mul \n" \
+  /* c2*/ \
+  "vmla.f16 q15, q11, d0[2] @ mul \n" \
+  "vmla.f16 q12, q11, d2[2] @ mul \n" \
+  "vmla.f16 q13, q11, d4[2] @ mul \n" \
+  "vmla.f16 q14, q11, d6[2] @ mul \n" \
+  "sub %[wc0], %[wc0], #432 @ wc0 - 432 to start address\n" \
+  "vst1.16 {d24-d27}, [%[ptr_out0]]! \n" \
+  "vst1.16 {d28-d31}, [%[ptr_out0]]! \n" \
+  "subs %[cnt], #1\n" \
+  "bne 1b\n"
 #endif
 // clang-format on
 void conv_3x3s2_direct_fp16_c3(const float16_t* i_data,
@@ -590,7 +944,7 @@ void conv_3x3s2_direct_fp16(const float16_t* i_data,
                               const float16_t* bias,
                               const operators::ConvParam& param,
                               ARMContext* ctx) {
-  if (ic == 3 && (oc % 4 == 0)) {
+  if (ic == 3 && (oc % 8 == 0)) {
     conv_3x3s2_direct_fp16_c3(
         i_data, o_data, bs, oc, oh, ow, ic, ih, win, weights, bias, param, ctx);
     return;
@@ -615,7 +969,11 @@ void conv_3x3s2_direct_fp16(const float16_t* i_data,
   int ws = -pad_w;
   int we = ws + win_round;
+#ifdef __aarch64__
   int w_loop = wout_round >> 3;
+#else
+  int w_loop = wout_round >> 2;
+#endif
   int c_remain = oc - (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
   int c_round_down = (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
@@ -685,7 +1043,6 @@ void conv_3x3s2_direct_fp16(const float16_t* i_data,
         float16_t* pre_out0 = pre_out + hk * out_row_stride;
         float16_t* pre_out1 = pre_out0 + out_row_stride;
-#ifdef __aarch64__
         // first
         if (1) {
           COMPUT_INIT
@@ -711,8 +1068,6 @@ void conv_3x3s2_direct_fp16(const float16_t* i_data,
           inr3 += win_round;
           inr4 += win_round;
         }
-#else   // not __aarch64__
-#endif  // __aarch64__
         block_inr0 = block_inr4;
         block_inr1 = block_inr0 + in_len;
         block_inr2 = block_inr1 + in_len;
@@ -777,7 +1132,11 @@ void conv_3x3s2_direct_fp16_c3(const float16_t* i_data,
   int ws = -pad_w;
   int we = ws + win_round;
+#ifdef __aarch64__
   int w_loop = wout_round >> 3;
+#else
+  int w_loop = wout_round >> 2;
+#endif
   int c_remain = oc - (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
   int c_round_down = (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
@@ -813,8 +1172,6 @@ void conv_3x3s2_direct_fp16_c3(const float16_t* i_data,
   const float16_t* cblock_inr0 = pre_din;
   const float16_t* cblock_inr1 = cblock_inr0 + in_len;
   const float16_t* cblock_inr2 = cblock_inr1 + in_len;
-  const float16_t* cblock_inr3 = cblock_inr2 + in_len;
-  const float16_t* cblock_inr4 = cblock_inr3 + in_len;
   LITE_PARALLEL_COMMON_BEGIN(c, tid, c_round_down, 0, OUT_C_BLOCK) {
 #ifdef LITE_USE_THREAD_POOL
@@ -828,8 +1185,6 @@ void conv_3x3s2_direct_fp16_c3(const float16_t* i_data,
     const float16_t* block_inr0 = cblock_inr0;
     const float16_t* block_inr1 = cblock_inr1;
     const float16_t* block_inr2 = cblock_inr2;
-    const float16_t* block_inr3 = cblock_inr3;
-    const float16_t* block_inr4 = cblock_inr4;
     const float16_t* weight_c = weights + c * w_stride;
     const float16_t* bias_ptr = ptr_zero;
@@ -899,6 +1254,33 @@ void conv_3x3s2_direct_fp16_c3(const float16_t* i_data,
                      "v30",
                      "v31");
 #else   // not __aarch64__
+        int cnt = w_loop;
+        asm volatile(COMPUTE_C3
+                     : [cnt] "+r"(cnt),
+                       [r0] "+r"(inr0),
+                       [r1] "+r"(inr1),
+                       [r2] "+r"(inr2),
+                       [ptr_out0] "+r"(pre_out0),
+                       [wc0] "+r"(wc0)
+                     :
+                     : "cc",
+                       "memory",
+                       "q0",
+                       "q1",
+                       "q2",
+                       "q3",
+                       "q4",
+                       "q5",
+                       "q6",
+                       "q7",
+                       "q8",
+                       "q9",
+                       "q10",
+                       "q11",
+                       "q12",
+                       "q13",
+                       "q14",
+                       "q15");
 #endif  // __aarch64__
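+        // Hedged reference for the ARMv7 COMPUTE_C3 block above: each loop
+        // iteration produces 4 stride-2 output pixels x 8 output channels
+        // from the 3-channel input, which is packed 4 halfs per pixel
+        // (c0..c2 plus one pad lane). Scalar model, illustrative names only:
+        //   for (int w = 0; w < 4; ++w)          // OUT_W_BLOCK on ARMv7
+        //     for (int o = 0; o < 8; ++o)        // OUT_C_BLOCK
+        //       for (int kh = 0; kh < 3; ++kh)   // input rows r0..r2
+        //         for (int kw = 0; kw < 3; ++kw)
+        //           for (int c = 0; c < 3; ++c)
+        //             out[w][o] += in[kh][2 * w + kw][c] * wgt[kh][kw][c][o];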
block_inr0 = block_inr2; block_inr1 = block_inr0 + in_len; diff --git a/lite/backends/arm/math/fp16/conv5x5s1_depthwise_fp16.cc b/lite/backends/arm/math/fp16/conv5x5s1_depthwise_fp16.cc index 55a0e48ef2c..8f23f88a5ef 100644 --- a/lite/backends/arm/math/fp16/conv5x5s1_depthwise_fp16.cc +++ b/lite/backends/arm/math/fp16/conv5x5s1_depthwise_fp16.cc @@ -320,6 +320,185 @@ namespace fp16 { "st1 {v7.8h}, [%[outc7]], #16\n" #else + +#define COMPUTE \ + "vld1.16 {d24-d25}, [%[bias]] \n" /* load bias to out00 */ \ + "vmov.u32 q13, q12 \n" /* mov bias to out01 */ \ + "vld1.16 {d0-d3}, [%[wc0]]! \n" /* load w0-w1 */ \ + "vmov.u32 q14, q12 \n" /* mov bias to out02 */ \ + "vld1.16 {d8-d11}, [%[inr0]]! \n" /* load inr0, 0-1 */ \ + "vmov.u32 q15, q12 \n" /* mov bias to out03 */ \ + "vld1.16 {d12-d15}, [%[inr0]]! \n" /* load inr0, 2-3 */ \ + /* out row0*/ \ + "vmla.f16 q12, q4, q0 \n" /* out00 = w0 * inr00 */ \ + "vmla.f16 q13, q5, q0 \n" /* out01 = w0 * inr01 */ \ + "vld1.16 {d16-d19}, [%[inr0]]! \n" /* load inr0, 4-5 */ \ + "vmla.f16 q14, q6, q0 \n" /* out02 = w0 * inr02 */ \ + "vmla.f16 q15, q7, q0 \n" /* out03 = w0 * inr03 */ \ + "vld1.16 {d20-d23}, [%[inr0]]! \n" /* load inr0, 6-7 */ \ + "vmla.f16 q12, q5, q1 \n" /* out00 = w1 * inr01 */ \ + "vmla.f16 q13, q6, q1 \n" /* out01 = w1 * inr02 */ \ + "vld1.16 {d4-d7}, [%[wc0]]! \n" /* load w2-w3 */ \ + "vmla.f16 q14, q7, q1 \n" /* out02 = w1 * inr03 */ \ + "vmla.f16 q15, q8, q1 \n" /* out03 = w1 * inr04 */ \ + "vld1.16 {d8-d11}, [%[inr1]]!\n" /* load inr1, 0-1 */ \ + "vmla.f16 q12, q6, q2 \n" /* out00 = w2 * inr02 */ \ + "vmla.f16 q13, q7, q2 \n" /* out01 = w2 * inr03 */ \ + "vmla.f16 q14, q8, q2 \n" /* out02 = w2 * inr04 */ \ + "vmla.f16 q15, q9, q2 \n" /* out03 = w2 * inr05 */ \ + "vld1.16 {d0-d3}, [%[wc0]]! \n" /* load w4-w5 */ \ + "vmla.f16 q12, q7, q3 \n" /* out00 = w3 * inr03 */ \ + "vmla.f16 q13, q8, q3 \n" /* out01 = w3 * inr04 */ \ + "vmla.f16 q14, q9, q3 \n" /* out02 = w3 * inr05 */ \ + "vmla.f16 q15, q10, q3 \n" /* out03 = w3 * inr06 */ \ + "vld1.16 {d12-d15}, [%[inr1]]!\n" /* load inr1, 2-3 */ \ + "vmla.f16 q12, q8, q0 \n" /* out00 = w4 * inr04 */ \ + "vmla.f16 q13, q9, q0 \n" /* out01 = w4 * inr05 */ \ + "vmla.f16 q14, q10, q0 \n" /* out02 = w4 * inr06 */ \ + "vmla.f16 q15, q11, q0 \n" /* out03 = w4 * inr07 */ \ + "vld1.16 {d4-d7}, [%[wc0]]! \n" /* load w6-w7 */ \ + /* out row1*/\ + "vmla.f16 q12, q4, q1 \n" /* out00 = w5 * inr10 */ \ + "vmla.f16 q13, q5, q1 \n" /* out01 = w5 * inr11 */ \ + "vmla.f16 q14, q6, q1 \n" /* out02 = w5 * inr12 */ \ + "vmla.f16 q15, q7, q1 \n" /* out03 = w5 * inr13 */ \ + "vld1.16 {d16-d19}, [%[inr1]]!\n" /* load inr1, 4-5 */ \ + "vmla.f16 q12, q5, q2 \n" /* out00 = w6 * inr11 */ \ + "vmla.f16 q13, q6, q2 \n" /* out01 = w6 * inr12 */ \ + "vmla.f16 q14, q7, q2 \n" /* out02 = w6 * inr13 */ \ + "vmla.f16 q15, q8, q2 \n" /* out03 = w6 * inr14 */ \ + "vld1.16 {d0-d3}, [%[wc0]]! \n" /* load w8-w9 */ \ + "vmla.f16 q12, q6, q3 \n" /* out00 = w7 * inr12 */ \ + "vmla.f16 q13, q7, q3 \n" /* out01 = w7 * inr13 */ \ + "vld1.16 {d20-d23}, [%[inr1]]!\n" /* load inr1, 6-7 */ \ + "vmla.f16 q14, q8, q3 \n" /* out02 = w7 * inr14 */ \ + "vmla.f16 q15, q9, q3 \n" /* out03 = w7 * inr15 */ \ + "vmla.f16 q12, q7, q0 \n" /* out00 = w8 * inr13 */ \ + "vmla.f16 q13, q8, q0 \n" /* out01 = w8 * inr14 */ \ + "vld1.16 {d8-d11}, [%[inr2]]!\n" /* load inr2, 0-1 */ \ + "vmla.f16 q14, q9, q0 \n" /* out02 = w8 * inr15 */ \ + "vmla.f16 q15, q10, q0 \n" /* out03 = w8 * inr16 */ \ + "vld1.16 {d4-d7}, [%[wc0]]! 
\n" /* load w10-w11 */ \ + "vmla.f16 q12, q8, q1 \n" /* out00 = w9 * inr14 */ \ + "vmla.f16 q13, q9, q1 \n" /* out01 = w9 * inr15 */ \ + "vld1.16 {d12-d15}, [%[inr2]]!\n" /* load inr2, 2-3 */ \ + "vmla.f16 q14, q10, q1 \n" /* out02 = w9 * inr16 */ \ + "vmla.f16 q15, q11, q1 \n" /* out03 = w9 * inr17 */ \ + /* out row3*/ \ + "vmla.f16 q12, q4, q2 \n" /* out00 = w10 * inr20 */ \ + "vmla.f16 q13, q5, q2 \n" /* out01 = w10 * inr21 */ \ + "vld1.16 {d16-d19}, [%[inr2]]!\n" /* load inr2, 4-5 */ \ + "vmla.f16 q14, q6, q2 \n" /* out02 = w10 * inr22 */ \ + "vmla.f16 q15, q7, q2 \n" /* out03 = w10 * inr23 */ \ + "vld1.16 {d0-d3}, [%[wc0]]! \n" /* load w12-w13 */ \ + "vmla.f16 q12, q5, q3 \n" /* out00 = w11 * inr21 */ \ + "vmla.f16 q13, q6, q3 \n" /* out01 = w11 * inr22 */ \ + "vld1.16 {d20-d23}, [%[inr2]]!\n" /* load inr2, 6-7 */ \ + "vmla.f16 q14, q7, q3 \n" /* out02 = w11 * inr23 */ \ + "vmla.f16 q15, q8, q3 \n" /* out03 = w11 * inr24 */ \ + "vld1.16 {d4-d7}, [%[wc0]]! \n" /* load w14-w15 */ \ + "vmla.f16 q12, q6, q0 \n" /* out00 = w12 * inr22 */ \ + "vmla.f16 q13, q7, q0 \n" /* out01 = w12 * inr23 */ \ + "vmla.f16 q14, q8, q0 \n" /* out02 = w12 * inr24 */ \ + "vmla.f16 q15, q9, q0 \n" /* out03 = w12 * inr25 */ \ + "vld1.16 {d8-d11}, [%[inr3]]!\n" /* load inr3, 0-1 */ \ + "vmla.f16 q12, q7, q1 \n" /* out00 = w13 * inr23 */ \ + "vmla.f16 q13, q8, q1 \n" /* out01 = w13 * inr24 */ \ + "vmla.f16 q14, q9, q1 \n" /* out02 = w13 * inr25 */ \ + "vmla.f16 q15, q10, q1 \n" /* out03 = w13 * inr26 */ \ + "vld1.16 {d0-d3}, [%[wc0]]! \n" /* load w16-w17 */ \ + "vmla.f16 q12, q8, q2 \n" /* out00 = w14 * inr24 */ \ + "vmla.f16 q13, q9, q2 \n" /* out01 = w14 * inr25 */ \ + "vld1.16 {d12-d15}, [%[inr3]]!\n" /* load inr3, 2-3 */ \ + "vmla.f16 q14, q10, q2 \n" /* out02 = w14 * inr26 */ \ + "vmla.f16 q15, q11, q2 \n" /* out03 = w14 * inr27 */ \ + /* out row3*/ \ + "vmla.f16 q12, q4, q3 \n" /* out00 = w15 * inr30 */ \ + "vmla.f16 q13, q5, q3 \n" /* out01 = w15 * inr31 */ \ + "vld1.16 {d16-d19}, [%[inr3]]!\n" /* load inr3, 4-5 */ \ + "vmla.f16 q14, q6, q3 \n" /* out02 = w15 * inr32 */ \ + "vmla.f16 q15, q7, q3 \n" /* out03 = w15 * inr33 */ \ + "vld1.16 {d4-d7}, [%[wc0]]! \n" /* load w18-w19 */ \ + "vmla.f16 q12, q5, q0 \n" /* out00 = w16 * inr31 */ \ + "vmla.f16 q13, q6, q0 \n" /* out01 = w16 * inr32 */ \ + "vld1.16 {d20-d23}, [%[inr3]]!\n" /* load inr3, 6-7 */ \ + "vmla.f16 q14, q7, q0 \n" /* out02 = w16 * inr33 */ \ + "vmla.f16 q15, q8, q0 \n" /* out03 = w16 * inr34 */ \ + "vmla.f16 q12, q6, q1 \n" /* out00 = w17 * inr32 */ \ + "vmla.f16 q13, q7, q1 \n" /* out01 = w17 * inr33 */ \ + "vmla.f16 q14, q8, q1 \n" /* out02 = w17 * inr34 */ \ + "vmla.f16 q15, q9, q1 \n" /* out03 = w17 * inr35 */ \ + "vld1.16 {d0-d3}, [%[wc0]]! 
\n" /* load w20-w21 */ \ + "vmla.f16 q12, q7, q2 \n" /* out00 = w18 * inr33 */ \ + "vmla.f16 q13, q8, q2 \n" /* out01 = w18 * inr34 */ \ + "vmla.f16 q14, q9, q2 \n" /* out02 = w18 * inr35 */ \ + "vmla.f16 q15, q10, q2 \n" /* out03 = w18 * inr36 */ \ + "vld1.16 {d8-d11}, [%[inr4]]!\n" /* load inr4, 0-1 */ \ + "vmla.f16 q12, q8, q3 \n" /* out00 = w19 * inr34 */ \ + "vmla.f16 q13, q9, q3 \n" /* out01 = w19 * inr35 */ \ + "vld1.16 {d12-d15}, [%[inr4]]!\n" /* load inr4, 2-3 */ \ + "vmla.f16 q14, q10, q3 \n" /* out02 = w19 * inr36 */ \ + "vmla.f16 q15, q11, q3 \n" /* out03 = w19 * inr37 */ \ + /* out row4 */ \ + "vmla.f16 q12, q4, q0 \n" /* out00 = w20 * inr40 */ \ + "vmla.f16 q13, q5, q0 \n" /* out01 = w20 * inr41 */ \ + "vld1.16 {d16-d19}, [%[inr4]]!\n" /* load inr4, 4-5 */ \ + "vmla.f16 q14, q6, q0 \n" /* out02 = w20 * inr42 */ \ + "vmla.f16 q15, q7, q0 \n" /* out03 = w20 * inr43 */ \ + "vld1.16 {d4-d7}, [%[wc0]]! \n" /* load w22-w23 */ \ + "vmla.f16 q12, q5, q1 \n" /* out00 = w21 * inr41 */ \ + "vmla.f16 q13, q6, q1 \n" /* out01 = w21 * inr42 */ \ + "vmla.f16 q14, q7, q1 \n" /* out02 = w21 * inr43 */ \ + "vmla.f16 q15, q8, q1 \n" /* out03 = w21 * inr44 */ \ + "vld1.16 {d20-d23}, [%[inr4]]!\n" /* load inr4, 6-7 */ \ + "vmla.f16 q12, q6, q2 \n" /* out00 = w22 * inr42 */ \ + "vmla.f16 q13, q7, q2 \n" /* out01 = w22 * inr43 */ \ + "vmla.f16 q14, q8, q2 \n" /* out02 = w22 * inr44 */ \ + "vmla.f16 q15, q9, q2 \n" /* out03 = w22 * inr45 */ \ + "vld1.16 {d4-d5}, [%[wc0]] \n" /* load w24 */ \ + "vmla.f16 q12, q7, q3 \n" /* out00 = w23 * inr43 */ \ + "vmla.f16 q13, q8, q3 \n" /* out01 = w23 * inr44 */ \ + "vmla.f16 q14, q9, q3 \n" /* out02 = w23 * inr45 */ \ + "vmla.f16 q15, q10, q3 \n" /* out03 = w23 * inr46 */ \ + "vmla.f16 q12, q8, q2 \n" /* out00 = w24 * inr44 */ \ + "vmla.f16 q13, q9, q2 \n" /* out01 = w24 * inr45 */ \ + "vmla.f16 q14, q10, q2 \n" /* out02 = w24 * inr46 */ \ + "vmla.f16 q15, q11, q2 \n" /* out03 = w24 * inr47 */ \ + +#define RELU /* relu */ \ + "vmov.u16 q0, #0\n" \ + "vld1.16 {d2-d3}, [%[six_ptr]]\n" \ + "vmax.f16 q12, q12, q0\n" \ + "vmax.f16 q13, q13, q0\n" \ + "vmax.f16 q14, q14, q0\n" \ + "vmax.f16 q15, q15, q0\n" +#define RELU6 /* relu6 */ \ + "vmin.f16 q12, q12, q1\n" \ + "vmin.f16 q13, q13, q1\n" \ + "vmin.f16 q14, q14, q1\n" \ + "vmin.f16 q15, q15, q1\n" +#define LEAKY_RELU /* LeakyRelu */ \ + "vmov.u16 q0, #0\n" \ + "vld1.16 {d2-d3}, [%[scale_ptr]]\n" \ + "vcge.f16 q2, q12, q0 @ q0 > 0 \n" \ + "vcge.f16 q4, q13, q0 @ q0 > 0 \n" \ + "vcge.f16 q6, q14, q0 @ q0 > 0 \n" \ + "vcge.f16 q8, q15, q0 @ q0 > 0 \n" \ + "vmul.f16 q3, q12, q1 @ mul \n" \ + "vmul.f16 q5, q13, q1 @ mul \n" \ + "vmul.f16 q7, q14, q1 @ mul \n" \ + "vmul.f16 q9, q15, q1 @ mul \n" \ + "vbif q12, q3, q2 @ choose \n" \ + "vbif q13, q5, q4 @ choose \n" \ + "vbif q14, q7, q6 @ choose \n" \ + "vbif q15, q9, q8 @ choose \n" +#define STORE /* save result */ \ + "vst1.16 {d24-d25}, [%[outc0]]\n" /* save outc0*/ \ + "vst1.16 {d26-d27}, [%[outc1]]\n" /* save outc1*/ \ + "vst1.16 {d28-d29}, [%[outc2]]\n" /* save outc2*/ \ + "vst1.16 {d30-d31}, [%[outc3]]\n" /* save outc3*/ + + #endif // clang-format on @@ -352,6 +531,9 @@ void act_switch_5x5s1(const float16_t* inr0, #ifdef __aarch64__ float16x8_t vsix = vdupq_n_f16(tmp); float16x8_t vscale = vdupq_n_f16(ss); +#else + float16x8_t vsix[8] = {tmp, tmp, tmp, tmp, tmp, tmp, tmp, tmp}; + float16x8_t vscale[8] = {ss, ss, ss, ss, ss, ss, ss, ss}; #endif switch (act_param.active_type) { case lite_api::ActivationType::kRelu: @@ -405,6 +587,36 @@ void act_switch_5x5s1(const 
float16_t* inr0, "v25", "v26"); #else + asm volatile(COMPUTE RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif break; case lite_api::ActivationType::kRelu6: @@ -459,6 +671,36 @@ void act_switch_5x5s1(const float16_t* inr0, "v25", "v26"); #else + asm volatile(COMPUTE RELU RELU6 STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif break; case lite_api::ActivationType::kLeakyRelu: @@ -513,6 +755,36 @@ void act_switch_5x5s1(const float16_t* inr0, "v25", "v26"); #else + asm volatile(COMPUTE LEAKY_RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [scale_ptr] "r"(vscale) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif break; default: @@ -571,6 +843,36 @@ void act_switch_5x5s1(const float16_t* inr0, "v25", "v26"); #else + asm volatile(COMPUTE STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif } } @@ -593,7 +895,11 @@ void conv_depthwise_5x5s1_fp16(const float16_t* i_data, const int pad_w = paddings[2]; const int out_c_block = 8; const int out_h_kernel = 1; +#ifdef __aarch64__ const int out_w_kernel = 8; +#else + const int out_w_kernel = 4; +#endif const int win_ext = ow + 4; const int ow_round = ROUNDUP(ow, out_w_kernel); const int win_round = ROUNDUP(win_ext, out_w_kernel); @@ -644,7 +950,6 @@ void conv_depthwise_5x5s1_fp16(const float16_t* i_data, const float16_t* weight_c = weights + c * 25; // kernel_w * kernel_h float16_t* dout_c00 = dout_batch + c * size_out_channel; float16_t bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - #ifdef __aarch64__ float16x8_t w0 = vld1q_f16(weight_c); // w0, v23 float16x8_t w1 = vld1q_f16(weight_c + 8); // w1, v24 @@ -664,7 +969,16 @@ void conv_depthwise_5x5s1_fp16(const float16_t* i_data, } weight_c += 40; #else + if (flag_bias) { + for (int k = 0; k < 8 && c + k < oc; k++) { + bias_local[k] = bias[c + k]; + } + } + float16_t pre_out_[32]; + float16_t *pre_din0_ = &(pre_out_[0]), *pre_din1_ = &(pre_out_[8]), + *pre_din2_ = &(pre_out_[16]), *pre_din3_ = &(pre_out_[24]); #endif + for (int h = 0; h < oh; h += out_h_kernel) { float16_t* outc0 = dout_c00 + h * ow; float16_t* outc1 = outc0 + size_out_channel; @@ -753,6 
+1067,62 @@ void conv_depthwise_5x5s1_fp16(const float16_t* i_data, bias_local, act_param); #else + act_switch_5x5s1(inr0, + inr1, + inr2, + inr3, + inr4, + pre_din0_, + pre_din1_, + pre_din2_, + pre_din3_, + pre_din0_, + pre_din1_, + pre_din2_, + pre_din3_, + vzero, + vzero, + vzero, + vzero, + vzero, + vzero, + weight_c, + bias_local, + act_param); + asm volatile( + "vld1.32 {d0-d1}, [%[r0]]\n" + "vld1.32 {d2-d3}, [%[r1]]\n" + "vld1.32 {d4-d5}, [%[r2]]\n" + "vld1.32 {d6-d7}, [%[r3]]\n" + "vtrn.16 q0, q1\n" + "vtrn.16 q2, q3\n" + "vtrn.32 q0, q2\n" + "vtrn.32 q1, q3\n" + "vswp d1, d2\n" + "vswp d5, d6\n" + "vst1.16 {d0}, [%[outc0]]\n" + "vst1.16 {d1}, [%[outc1]]\n" + "vst1.16 {d4}, [%[outc2]]\n" + "vst1.16 {d5}, [%[outc3]]\n" + "vst1.16 {d2}, [%[outc4]]\n" + "vst1.16 {d3}, [%[outc5]]\n" + "vst1.16 {d6}, [%[outc6]]\n" + "vst1.16 {d7}, [%[outc7]]\n" + + : [r0] "+r"(pre_din0_), + [r1] "+r"(pre_din1_), + [r2] "+r"(pre_din2_), + [r3] "+r"(pre_din3_), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3), + [outc4] "+r"(outc4), + [outc5] "+r"(outc5), + [outc6] "+r"(outc6), + [outc7] "+r"(outc7) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); #endif if (flag_mask) { for (int i = 0; i < remain; ++i) { @@ -766,6 +1136,7 @@ void conv_depthwise_5x5s1_fp16(const float16_t* i_data, c7[i] = pre_out[i + 56]; } } +#ifdef __aarch64__ inr0 += 64; inr1 += 64; inr2 += 64; @@ -779,6 +1150,21 @@ void conv_depthwise_5x5s1_fp16(const float16_t* i_data, outc5 += 8; outc6 += 8; outc7 += 8; +#else + inr0 += 32; + inr1 += 32; + inr2 += 32; + inr3 += 32; + inr4 += 32; + outc0 += 4; + outc1 += 4; + outc2 += 4; + outc3 += 4; + outc4 += 4; + outc5 += 4; + outc6 += 4; + outc7 += 4; +#endif } } } diff --git a/lite/backends/arm/math/fp16/conv5x5s2_depthwise_fp16.cc b/lite/backends/arm/math/fp16/conv5x5s2_depthwise_fp16.cc index 358759389b5..a1d53de8b2c 100644 --- a/lite/backends/arm/math/fp16/conv5x5s2_depthwise_fp16.cc +++ b/lite/backends/arm/math/fp16/conv5x5s2_depthwise_fp16.cc @@ -215,6 +215,200 @@ namespace fp16 { "st1 {v7.4h}, [%[outc7]], #8\n" #else + +#define COMPUTE \ + /* fill with bias */ \ + "vld1.16 {d12-d13}, [%[bias]]\n" /* load bias */ /* load weights */ \ + "vld1.16 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vld1.16 {d0-d3}, [%[inr0]]!\n" /* load input r0, 0,1*/ \ + "vand.i16 q12, q6, q6\n" \ + "vld1.16 {d4-d7}, [%[inr0]]!\n" /* load input r0, 2,3*/ \ + "vand.i16 q13, q6, q6\n" \ + "vld1.16 {d8-d11}, [%[inr0]]!\n" /* load input r0, 4,5*/ \ + "vand.i16 q14, q6, q6\n" \ + "vand.i16 q15, q6, q6\n" \ + "vld1.16 {d12-d13}, [%[inr0]]!\n" /* load input r0, 6*/ \ + "vmla.f16 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f16 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.16 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-q10 */ \ + "vmla.f16 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f16 q15, q7, q6 @ w0 * inr6\n" \ + "vmla.f16 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f16 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f16 q14, q8, q5 @ w1 * inr5\n" \ + "vld1.16 {d22-d23}, [%[wc0]]!\n" /* load w4, to q11 */ \ + "vmla.f16 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f16 q13, q9, q4 @ w2 * inr6\n" \ + "vmla.f16 q14, q9, q6 @ w2 * inr4\n" \ + "vld1.16 {d0-d3}, [%[inr0]]! 
\n" /* load r0, 7-8 */ \ + "vmla.f16 q12, q10, q3 @ w3 * inr3\n" \ + "vmla.f16 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f16 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f16 q15, q8, q0 @ w1 * inr7\n" \ + "vld1.16 {d4-d7}, [%[inr0]] \n" /* load r0, 9-10 */ \ + "vmla.f16 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f16 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f16 q14, q11, q1 @ w4 * inr8\n" \ + "vmla.f16 q15, q9, q1 @ w2 * inr8\n" \ + "vld1.16 {d0-d3}, [%[inr1]]! @ load r1, 0, 1\n" \ + "vld1.16 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vmla.f16 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.16 {d4-d5}, [%[inr1]]! @ load r1, 2\n" \ + "sub %[inr0], %[inr0], #16 @ r0 - 16 to nextline address\n" \ + "vld1.16 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f16 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f16 q13, q7, q2 @ w0 * inr2\n" \ + "vmla.f16 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.16 {d6-d9}, [%[inr1]]! @ load r1, 3, 4\n" \ + "vld1.16 {d22-d23}, [%[wc0]]!\n" /* load w4, to q11 */ \ + "vld1.16 {d10-d13}, [%[inr1]]! @ load r1, 5, 6\n" \ + "vmla.f16 q14, q7, q4 @ w0 * inr0\n" \ + "vmla.f16 q15, q7, q6 @ w0 * inr2\n" \ + "vmla.f16 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f16 q13, q8, q3 @ w1 * inr3\n" \ + "vld1.16 {d0-d3}, [%[inr1]]! @ load r1, 7, 8\n" \ + "vmla.f16 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f16 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f16 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f16 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f16 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f16 q15, q9, q1 @ w2 * inr8\n" \ + "vmla.f16 q12, q10, q3 @ w3 * inr3\n" \ + "vld1.16 {d4-d7}, [%[inr1]] @ load r1, 9, 10\n" \ + "vmla.f16 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f16 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f16 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.16 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vmla.f16 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f16 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f16 q14, q11, q1 @ w4 * inr8\n" \ + "vmla.f16 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.16 {d0-d3}, [%[inr2]]! @ load r2, 0, 1\n" \ + "vld1.16 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "sub %[inr1], %[inr1], #16 @ r1 - 16 to nextline address\n" \ + "vld1.16 {d4-d7}, [%[inr2]]! @ load r2, 2, 3\n" \ + "vld1.16 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vmla.f16 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f16 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.16 {d8-d11}, [%[inr2]]! @ load r2, 4, 5\n" \ + "vmla.f16 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f16 q13, q8, q3 @ w1 * inr3\n" \ + "vld1.16 {d12-d13}, [%[inr2]]! @ load r2, 6 \n" \ + "vmla.f16 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f16 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.16 {d0-d3}, [%[inr2]]! @ load r2, 7, 8\n" \ + "vmla.f16 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f16 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f16 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f16 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f16 q12, q10, q3 @ w3 * inr3\n" \ + "vmla.f16 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f16 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f16 q15, q9, q1 @ w2 * inr8\n" \ + "vld1.16 {d4-d7}, [%[inr2]] @ load r2, 9, 10\n" \ + "vmla.f16 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f16 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f16 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f16 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.16 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "sub %[inr2], %[inr2], #16 @ r1 - 16 to nextline address\n" \ + "vmla.f16 q14, q11, q1 @ w4 * inr8\n" \ + "vld1.16 {d0-d3}, [%[inr3]]! @ load r3, 0, 1\n" \ + "vmla.f16 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.16 {d4-d7}, [%[inr3]]! 
@ load r3, 2, 3\n" \ + "vld1.16 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f16 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f16 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.16 {d8-d11}, [%[inr3]]! @ load r3, 4, 5\n" \ + "vld1.16 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vld1.16 {d12-d13}, [%[inr3]]! @ load r3, 6, \n" \ + "vmla.f16 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f16 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f16 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f16 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.16 {d0-d3}, [%[inr3]]! @ load r3, 7, 8\n" \ + "vmla.f16 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f16 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f16 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f16 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f16 q12, q10, q3 @ w3 * inr3\n" \ + "vld1.16 {d4-d7}, [%[inr3]] @ load r3, 9, 10\n" \ + "vmla.f16 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f16 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f16 q15, q9, q1 @ w2 * inr8\n" \ + "vmla.f16 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f16 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f16 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f16 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.16 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "sub %[inr3], %[inr3], #16 @ r1 - 16 to nextline address\n" \ + "vmla.f16 q14, q11, q1 @ w4 * inr8\n" \ + "vld1.16 {d0-d3}, [%[inr4]]! @ load r4, 0, 1\n" \ + "vmla.f16 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.16 {d4-d7}, [%[inr4]]! @ load r4, 2, 3\n" \ + "vld1.16 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f16 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f16 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.16 {d8-d11}, [%[inr4]]! @ load r3, 4, 5\n" \ + "vld1.16 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vld1.16 {d12-d13}, [%[inr4]]! @ load r3, 6, \n" \ + "vmla.f16 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f16 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f16 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f16 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.16 {d0-d3}, [%[inr4]]! 
@ load r4, 7, 8\n" \
+  "vmla.f16 q12, q9, q2    @ w2 * inr2\n" \
+  "vmla.f16 q13, q9, q4    @ w2 * inr4\n" \
+  "vmla.f16 q14, q8, q5    @ w1 * inr5\n" \
+  "vmla.f16 q15, q8, q0    @ w1 * inr7\n" \
+  "vmla.f16 q12, q10, q3   @ w3 * inr3\n" \
+  "vld1.16 {d4-d7}, [%[inr4]]  @ load r4, 9, 10\n" \
+  "vmla.f16 q13, q10, q5   @ w3 * inr5\n" \
+  "vmla.f16 q14, q9, q6    @ w2 * inr6\n" \
+  "vmla.f16 q15, q9, q1    @ w2 * inr8\n" \
+  "vmla.f16 q12, q11, q4   @ w4 * inr4\n" \
+  "vmla.f16 q13, q11, q6   @ w4 * inr6\n" \
+  "vmla.f16 q14, q10, q0   @ w3 * inr7\n" \
+  "vmla.f16 q15, q10, q2   @ w3 * inr9\n" \
+  "sub %[wc0], %[wc0], #400 @ wc0 - 400 to start address\n" \
+  "sub %[inr4], %[inr4], #16 @ inr4 - 16 to nextline address\n" \
+  "vmla.f16 q14, q11, q1   @ w4 * inr8\n" \
+  "vmla.f16 q15, q11, q3   @ w4 * inr10\n" \
+
+#define RELU /* relu */ \
+  "vmov.u16 q0, #0\n" \
+  "vld1.16 {d2-d3}, [%[six_ptr]]\n" \
+  "vmax.f16 q12, q12, q0\n" \
+  "vmax.f16 q13, q13, q0\n" \
+  "vmax.f16 q14, q14, q0\n" \
+  "vmax.f16 q15, q15, q0\n"
+#define RELU6 /* relu6 */ \
+  "vmin.f16 q12, q12, q1\n" \
+  "vmin.f16 q13, q13, q1\n" \
+  "vmin.f16 q14, q14, q1\n" \
+  "vmin.f16 q15, q15, q1\n"
+#define LEAKY_RELU /* LeakyRelu */ \
+  "vmov.u16 q0, #0\n" \
+  "vld1.16 {d2-d3}, [%[scale_ptr]]\n" \
+  "vcge.f16 q2, q12, q0 @ q0 > 0 \n" \
+  "vcge.f16 q4, q13, q0 @ q0 > 0 \n" \
+  "vcge.f16 q6, q14, q0 @ q0 > 0 \n" \
+  "vcge.f16 q8, q15, q0 @ q0 > 0 \n" \
+  "vmul.f16 q3, q12, q1 @ mul \n" \
+  "vmul.f16 q5, q13, q1 @ mul \n" \
+  "vmul.f16 q7, q14, q1 @ mul \n" \
+  "vmul.f16 q9, q15, q1 @ mul \n" \
+  "vbif q12, q3, q2 @ choose \n" \
+  "vbif q13, q5, q4 @ choose \n" \
+  "vbif q14, q7, q6 @ choose \n" \
+  "vbif q15, q9, q8 @ choose \n"
+#define STORE /* save result */ \
+  "vst1.16 {d24-d25}, [%[outc0]]\n" /* save outc0*/ \
+  "vst1.16 {d26-d27}, [%[outc1]]\n" /* save outc1*/ \
+  "vst1.16 {d28-d29}, [%[outc2]]\n" /* save outc2*/ \
+  "vst1.16 {d30-d31}, [%[outc3]]\n" /* save outc3*/
+
+
 #endif
 // clang-format on
@@ -247,6 +441,9 @@ void act_switch_5x5s2(const float16_t* inr0,
 #ifdef __aarch64__
   float16x8_t vsix = vdupq_n_f16(tmp);
   float16x8_t vscale = vdupq_n_f16(ss);
+#else
+  float16_t vsix[8] = {tmp, tmp, tmp, tmp, tmp, tmp, tmp, tmp};
+  float16_t vscale[8] = {ss, ss, ss, ss, ss, ss, ss, ss};
 #endif
   switch (act_param.active_type) {
     case lite_api::ActivationType::kRelu:
@@ -295,6 +492,36 @@ void act_switch_5x5s2(const float16_t* inr0,
                      "v21",
                      "v22");
 #else
+      asm volatile(COMPUTE RELU STORE
+                   : [inr0] "+r"(inr0),
+                     [inr1] "+r"(inr1),
+                     [inr2] "+r"(inr2),
+                     [inr3] "+r"(inr3),
+                     [inr4] "+r"(inr4),
+                     [wc0] "+r"(weight_c),
+                     [outc0] "+r"(outc0),
+                     [outc1] "+r"(outc1),
+                     [outc2] "+r"(outc2),
+                     [outc3] "+r"(outc3)
+                   : [bias] "r"(bias_local), [six_ptr] "r"(vsix)
+                   : "cc",
+                     "memory",
+                     "q0",
+                     "q1",
+                     "q2",
+                     "q3",
+                     "q4",
+                     "q5",
+                     "q6",
+                     "q7",
+                     "q8",
+                     "q9",
+                     "q10",
+                     "q11",
+                     "q12",
+                     "q13",
+                     "q14",
+                     "q15");
 #endif
       break;
     case lite_api::ActivationType::kRelu6:
@@ -344,6 +571,36 @@ void act_switch_5x5s2(const float16_t* inr0,
                      "v21",
                      "v22");
 #else
+      asm volatile(COMPUTE RELU RELU6 STORE
+                   : [inr0] "+r"(inr0),
+                     [inr1] "+r"(inr1),
+                     [inr2] "+r"(inr2),
+                     [inr3] "+r"(inr3),
+                     [inr4] "+r"(inr4),
+                     [wc0] "+r"(weight_c),
+                     [outc0] "+r"(outc0),
+                     [outc1] "+r"(outc1),
+                     [outc2] "+r"(outc2),
+                     [outc3] "+r"(outc3)
+                   : [bias] "r"(bias_local), [six_ptr] "r"(vsix)
+                   : "cc",
+                     "memory",
+                     "q0",
+                     "q1",
+                     "q2",
+                     "q3",
+                     "q4",
+                     "q5",
+                     "q6",
+                     "q7",
+                     "q8",
+                     "q9",
+                     "q10",
+                     "q11",
+                     "q12",
+                     "q13",
+                     "q14",
+                     "q15");
 #endif
       break;
    case lite_api::ActivationType::kLeakyRelu:
@@ -393,6 
+650,36 @@ void act_switch_5x5s2(const float16_t* inr0, "v21", "v22"); #else + asm volatile(COMPUTE LEAKY_RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [scale_ptr] "r"(vscale) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif break; default: @@ -446,6 +733,36 @@ void act_switch_5x5s2(const float16_t* inr0, "v21", "v22"); #else + asm volatile(COMPUTE STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif } } @@ -471,8 +788,8 @@ void conv_depthwise_5x5s2_fp16(const float16_t* i_data, const int out_h_kernel = 1; const int out_w_kernel = 4; const int win_ext = ow * 2 + 3; - const int ow_round = ROUNDUP(ow, 4); - const int win_round = ROUNDUP(win_ext, 4); + const int ow_round = ROUNDUP(ow, out_w_kernel); + const int win_round = ROUNDUP(win_ext, out_w_kernel); const int hin_round = oh * 2 + 3; const int prein_size = win_round * hin_round * out_c_block; auto workspace_size = threads * prein_size + win_round + ow_round; @@ -493,10 +810,10 @@ void conv_depthwise_5x5s2_fp16(const float16_t* i_data, int we = ws + win_round; int hs = -pad_h; int he = hs + hin_round; - int w_loop = ow_round / 4; - auto remain = w_loop * 4 - ow; + int w_loop = ow_round / out_w_kernel; + auto remain = w_loop * out_w_kernel - ow; bool flag_remain = remain > 0; - remain = 4 - remain; + remain = out_w_kernel - remain; remain = remain > 0 ? 
remain : 0; int row_len = win_round * out_c_block; @@ -527,6 +844,7 @@ void conv_depthwise_5x5s2_fp16(const float16_t* i_data, float16x8_t w3 = vld1q_f16(weight_c + 24); // w3, v26 float16x8_t w4 = vld1q_f16(weight_c + 32); // w4, v27 float16x8_t vbias = vdupq_n_f16(0.f); + weight_c += 40; if (flag_bias) { if (c + out_c_block < oc) { vbias = vld1q_f16(&bias[c]); // v28 @@ -538,7 +856,12 @@ void conv_depthwise_5x5s2_fp16(const float16_t* i_data, vbias = vld1q_f16(bias_local); // v28 } } - weight_c += 40; +#else + if (flag_bias) { + for (int k = 0; k < 8 && c + k < oc; k++) { + bias_local[k] = bias[c + k]; + } + } #endif for (int h = 0; h < oh; h += out_h_kernel) { float16_t* outc0 = dout_c00 + h * ow; @@ -628,6 +951,67 @@ void conv_depthwise_5x5s2_fp16(const float16_t* i_data, bias_local, act_param); #else + float16_t pre_out_[32]; + float16_t *pre_din0_ = &(pre_out_[0]), *pre_din1_ = &(pre_out_[8]), + *pre_din2_ = &(pre_out_[16]), *pre_din3_ = &(pre_out_[24]); + act_switch_5x5s2(inr0, + inr1, + inr2, + inr3, + inr4, + &pre_out_[0], + &pre_out_[8], + &pre_out_[16], + &pre_out_[24], + &pre_out[32], + &pre_out[40], + &pre_out[48], + &pre_out[56], + vzero, + vzero, + vzero, + vzero, + vzero, + vzero, + weight_c, + bias_local, + act_param); + asm volatile( + "vld1.32 {d0-d1}, [%[r0]]\n" + "vld1.32 {d2-d3}, [%[r1]]\n" + "vld1.32 {d4-d5}, [%[r2]]\n" + "vld1.32 {d6-d7}, [%[r3]]\n" + "vtrn.16 q0, q1\n" + "vtrn.16 q2, q3\n" + "vtrn.32 q0, q2\n" + "vtrn.32 q1, q3\n" + + "vswp d1, d2\n" + "vswp d5, d6\n" + "vst1.16 {d0}, [%[outc0]]\n" + "vst1.16 {d1}, [%[outc1]]\n" + "vst1.16 {d4}, [%[outc2]]\n" + "vst1.16 {d5}, [%[outc3]]\n" + "vst1.16 {d2}, [%[outc4]]\n" + "vst1.16 {d3}, [%[outc5]]\n" + "vst1.16 {d6}, [%[outc6]]\n" + "vst1.16 {d7}, [%[outc7]]\n" + + : [r0] "+r"(pre_din0_), + [r1] "+r"(pre_din1_), + [r2] "+r"(pre_din2_), + [r3] "+r"(pre_din3_), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3), + [outc4] "+r"(outc4), + [outc5] "+r"(outc5), + [outc6] "+r"(outc6), + [outc7] "+r"(outc7) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); + #endif if (flag_mask) { for (int i = 0; i < remain; ++i) { diff --git a/lite/backends/arm/math/fp16/conv_block_utils_fp16.h b/lite/backends/arm/math/fp16/conv_block_utils_fp16.h index 5c2e3324041..dda2d04807b 100644 --- a/lite/backends/arm/math/fp16/conv_block_utils_fp16.h +++ b/lite/backends/arm/math/fp16/conv_block_utils_fp16.h @@ -231,12 +231,15 @@ inline void prepack_input_nxwc4(const float16_t* din, "vtrn.16 q2, q3\n" "vtrn.32 q0, q2\n" "vtrn.32 q1, q3\n" - "vswp d1, d4\n" - "vswp d5, d7\n" + + "vswp d1, d2\n" + "vswp d5, d6\n" "subs %[cnt], #1\n" - "vst1.32 {d0-d3}, [%[ptr_out]]!\n" - "vst1.32 {d4-d7}, [%[ptr_out]]!\n" + "vst1.16 {d0-d1}, [%[ptr_out]]!\n" + "vst1.16 {d4-d5}, [%[ptr_out]]!\n" + "vst1.16 {d2-d3}, [%[ptr_out]]!\n" + "vst1.16 {d6-d7}, [%[ptr_out]]!\n" "bne 1b\n" : [cnt] "+r"(cnt), [r0] "+r"(ptr_c0), diff --git a/lite/kernels/arm/conv_direct.h b/lite/kernels/arm/conv_direct.h index c2409438df9..6276ebc0f3e 100644 --- a/lite/kernels/arm/conv_direct.h +++ b/lite/kernels/arm/conv_direct.h @@ -214,7 +214,7 @@ inline bool direct_conv_trans_weights( wout->Resize({cround, ic, kh, kw}); auto w_in_data = win->data(); auto transed_w_data = wout->mutable_data(); - if (ic == 3 && stride == 2 && (oc % 4 == 0)) { + if (ic == 3 && stride == 2 && (oc % 8 == 0)) { // [chout, 3, kh, kw] -> [chout / cblock, kh, kw, 3, cblock] lite::arm::math::conv_trans_weights_c4toc12( w_in_data, transed_w_data, oc, ic, cblock, kh * kw);
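
Note on the new ARMv7 store paths (a hedged reference sketch, not part of the
patch itself): the vtrn.16 / vtrn.32 / "vswp d1, d2" / "vswp d5, d6" sequences
added after act_switch_5x5s1 and act_switch_5x5s2, and the corrected lane order
in prepack_input_nxwc4, all come down to an 8x4 <-> 4x8 fp16 transpose between
a pixel-major scratch tile and channel-major rows. A scalar model of that
contract, with illustrative names (pre, outc, and transpose_4x8_ref are not
from the patch):

    #include <arm_neon.h>  // float16_t (assumes an fp16-capable toolchain)

    // Scatter 4 pixels x 8 channels (pixel-major in pre, as left there by
    // the ARMv7 STORE macro) into 8 per-channel output rows of 4 pixels
    // each -- the same layout the vtrn/vswp sequence builds in d0..d7
    // before the final vst1.16 stores.
    static void transpose_4x8_ref(const float16_t pre[4][8],
                                  float16_t* outc[8]) {
      for (int ch = 0; ch < 8; ++ch) {
        for (int px = 0; px < 4; ++px) {
          outc[ch][px] = pre[px][ch];
        }
      }
    }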