Skip to content

Commit 7984480

Browse files
authored
[ARMv7] add elementwise_div_fp16 && fix elementwise_fp16 bug (#10050)
* [ARMv7] add elementwise_div_fp16 && fix elementwise_fp16 bug * fix f32_2_fp16 bug in armv7
1 parent df15938 commit 7984480

4 files changed

Lines changed: 145 additions & 31 deletions

File tree

lite/backends/arm/math/fp16/elementwise_fp16.cc

Lines changed: 118 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,6 @@ namespace fp16 {
189189
: "cc", \
190190
"memory", \
191191
ASM_VAR);
192-
193-
194192
#else
195193
#define INIT_1 \
196194
"vld1.16 {d0-d1}, [%[dinx_ptr]]! \n" \
@@ -260,10 +258,10 @@ namespace fp16 {
260258

261259
#define SIMPLE_COMPUTE_TYPE(op) \
262260
asm volatile(INIT SIMPLE_COMPUTE(v##op.f16) STORE \
261+
: [dinx_ptr] "+r"(dinx_ptr), \
262+
[diny_ptr] "+r"(diny_ptr), \
263+
[dout_ptr] "+r"(dout_ptr) \
263264
: \
264-
: [dinx_ptr] "r"(dinx_ptr), \
265-
[diny_ptr] "r"(diny_ptr), \
266-
[dout_ptr] "r"(dout_ptr) \
267265
: "cc", \
268266
"memory", \
269267
ASM_VAR);
@@ -281,11 +279,10 @@ namespace fp16 {
281279

282280
#define SIMPLE_COMPUTE_TYPE_RELU(op) \
283281
asm volatile(INIT SIMPLE_COMPUTE(v##op.f16) RELU STORE \
284-
: \
285-
: [dinx_ptr] "r"(dinx_ptr), \
286-
[diny_ptr] "r"(diny_ptr), \
287-
[dout_ptr] "r"(dout_ptr), \
288-
[vzero] "w"(vzero) \
282+
: [dinx_ptr] "+r"(dinx_ptr), \
283+
[diny_ptr] "+r"(diny_ptr), \
284+
[dout_ptr] "+r"(dout_ptr) \
285+
: [vzero] "w"(vzero) \
289286
: "cc", \
290287
"memory", \
291288
ASM_VAR);
@@ -303,10 +300,9 @@ namespace fp16 {
303300

304301
#define SIMPLE_COMPUTE_TYPE_BROADCAST(op) \
305302
asm volatile(INIT_BROADCAST SIMPLE_COMPUTE_BROADCAST(v##op.f16) STORE \
306-
: \
307-
: [dinx_ptr] "r"(dinx_ptr_1), \
308-
[dout_ptr] "r"(dout_ptr_1), \
309-
[val_y] "w"(val_y) \
303+
: [dinx_ptr] "+r"(dinx_ptr_1), \
304+
[dout_ptr] "+r"(dout_ptr_1) \
305+
: [val_y] "w"(val_y) \
310306
: "cc", \
311307
"memory", \
312308
ASM_VAR);
@@ -323,17 +319,16 @@ namespace fp16 {
323319

324320
#define SIMPLE_COMPUTE_TYPE_BROADCAST_RELU(op) \
325321
asm volatile(INIT_BROADCAST SIMPLE_COMPUTE_BROADCAST(v##op.f16) RELU STORE \
326-
: \
327-
: [dinx_ptr] "r"(dinx_ptr_1), \
328-
[dout_ptr] "r"(dout_ptr_1), \
329-
[val_y] "w"(val_y), \
322+
: [dinx_ptr] "+r"(dinx_ptr_1), \
323+
[dout_ptr] "+r"(dout_ptr_1) \
324+
: [val_y] "w"(val_y), \
330325
[vzero] "w"(vzero) \
331326
: "cc", \
332327
"memory", \
333328
ASM_VAR);
334329

335330
#define SIMPLE_COMPUTE_TYPE_BROADCAST_RELU_1(op) \
336-
asm volatile(INIT_1_BROADCAST SIMPLE_COMPUTE_1_BROADCAST(v##op.f16) RELU STORE_1 \
331+
asm volatile(INIT_1_BROADCAST SIMPLE_COMPUTE_1_BROADCAST(v##op.f16) RELU_1 STORE_1 \
337332
: [cnt_num] "+r"(cnt_num), \
338333
[dinx_ptr] "+r"(dinx_ptr_1), \
339334
[dout_ptr] "+r"(dout_ptr_1) \
@@ -352,7 +347,6 @@ namespace fp16 {
352347
float16_t* dout, \
353348
int num) { \
354349
LOOP_CNT(num) \
355-
\
356350
for (int i = 0; i < cnt; i++) { \
357351
int stride = i << 5; \
358352
const float16_t* dinx_ptr = dinx + stride; \
@@ -517,6 +511,110 @@ elmentwise_simple_compute(mul);
517511
elmentwise_simple_compute(sub);
518512
#ifdef __aarch64__
519513
elmentwise_simple_compute(div);
514+
#else
515+
void elementwise_div(const float16_t* dinx,
516+
const float16_t* diny,
517+
float16_t* dout,
518+
int num) {
519+
LOOP_CNT(num)
520+
for (int i = 0; i < cnt; i++) {
521+
int stride = i << 5;
522+
const float16_t* dinx_ptr = dinx + stride;
523+
const float16_t* diny_ptr = diny + stride;
524+
float16_t* dout_ptr = dout + stride;
525+
float16x8_t vec_a1 = vld1q_f16(dinx_ptr);
526+
float16x8_t vec_a2 = vld1q_f16(dinx_ptr + 8);
527+
float16x8_t vec_b1 = vld1q_f16(diny_ptr);
528+
float16x8_t vec_b2 = vld1q_f16(diny_ptr + 8);
529+
vst1q_f16(dout_ptr, divq_ps_f16(vec_a1, vec_b1));
530+
vst1q_f16(dout_ptr + 8, divq_ps_f16(vec_a2, vec_b2));
531+
vec_a1 = vld1q_f16(dinx_ptr + 16);
532+
vec_a2 = vld1q_f16(dinx_ptr + 24);
533+
vec_b1 = vld1q_f16(diny_ptr + 16);
534+
vec_b2 = vld1q_f16(diny_ptr + 24);
535+
vst1q_f16(dout_ptr + 16, divq_ps_f16(vec_a1, vec_b1));
536+
vst1q_f16(dout_ptr + 24, divq_ps_f16(vec_a2, vec_b2));
537+
}
538+
int stride = cnt << 5;
539+
if (rem_cnt > 0) {
540+
const float16_t* dinx_ptr = dinx + stride;
541+
const float16_t* diny_ptr = diny + stride;
542+
float16_t* dout_ptr = dout + stride;
543+
int cnt_num = rem_cnt;
544+
for (int loop = 0; loop < rem_cnt; loop++) {
545+
float16x8_t vec_a1 = vld1q_f16(dinx_ptr + loop * 8);
546+
float16x8_t vec_b1 = vld1q_f16(diny_ptr + loop * 8);
547+
vst1q_f16(dout_ptr + loop * 8, divq_ps_f16(vec_a1, vec_b1));
548+
}
549+
}
550+
if (rem_rem > 0) {
551+
stride += (rem_cnt << 3);
552+
const float16_t* dinx_ptr = dinx + stride;
553+
const float16_t* diny_ptr = diny + stride;
554+
float16_t* dout_ptr = dout + stride;
555+
for (int i = 0; i < rem_rem; i++) {
556+
*dout_ptr = naive_div(*dinx_ptr, *diny_ptr);
557+
dout_ptr++;
558+
dinx_ptr++;
559+
diny_ptr++;
560+
}
561+
}
562+
}
563+
564+
void elementwise_div_broadcast(const float16_t* dinx,
565+
const float16_t* diny,
566+
float16_t* dout,
567+
int batch,
568+
int channels,
569+
int num) {
570+
OMP_PARA_INTERNAL_COLLASPE_2
571+
for (int i = 0; i < batch; ++i) {
572+
for (int j = 0; j < channels; ++j) {
573+
int offset = (i * channels + j) * num;
574+
const auto* dinx_ptr = dinx + offset;
575+
const auto* diny_ptr = diny + j;
576+
auto* dout_ptr = dout + offset;
577+
LOOP_CNT(num)
578+
for (int k = 0; k < cnt; k++) {
579+
int stride = k << 5;
580+
const float16_t* dinx_ptr_1 = dinx_ptr + stride;
581+
float16_t* dout_ptr_1 = dout_ptr + stride;
582+
float16x8_t val_y = vdupq_n_f16(diny_ptr[0]);
583+
float16x8_t vec_x1 = vld1q_f16(dinx_ptr_1);
584+
float16x8_t vec_x2 = vld1q_f16(dinx_ptr_1 + 8);
585+
vst1q_f16(dout_ptr_1, divq_ps_f16(vec_x1, val_y));
586+
vst1q_f16(dout_ptr_1 + 8, divq_ps_f16(vec_x2, val_y));
587+
vec_x1 = vld1q_f16(dinx_ptr_1 + 16);
588+
vec_x2 = vld1q_f16(dinx_ptr_1 + 24);
589+
vst1q_f16(dout_ptr_1 + 16, divq_ps_f16(vec_x1, val_y));
590+
vst1q_f16(dout_ptr_1 + 24, divq_ps_f16(vec_x2, val_y));
591+
}
592+
int stride = cnt << 5;
593+
if (rem_cnt > 0) {
594+
const float16_t* dinx_ptr_1 = dinx_ptr + stride;
595+
float16_t* dout_ptr_1 = dout_ptr + stride;
596+
float16x8_t val_y = vdupq_n_f16(diny_ptr[0]);
597+
int cnt_num = rem_cnt;
598+
for (int loop = 0; loop < rem_cnt; loop++) {
599+
float16x8_t vec_x1 = vld1q_f16(dinx_ptr_1 + loop * 8);
600+
vst1q_f16(dout_ptr_1 + loop * 8, divq_ps_f16(vec_x1, val_y));
601+
}
602+
}
603+
if (rem_rem > 0) {
604+
stride += (rem_cnt << 3);
605+
const float16_t* dinx_ptr_1 = dinx_ptr + stride;
606+
float16_t* dout_ptr_1 = dout_ptr + stride;
607+
float16_t val = diny_ptr[0];
608+
for (int i = 0; i < rem_rem; i++) {
609+
*dout_ptr_1 = naive_div(*dinx_ptr_1, val);
610+
dinx_ptr_1++;
611+
dout_ptr_1++;
612+
}
613+
}
614+
}
615+
}
616+
}
617+
520618
#endif
521619
} // namespace fp16
522620
} // namespace math

lite/backends/arm/math/fp16/elementwise_fp16.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,21 @@ typedef __fp16 float16_t;
4848
elementwise_simple_compute_declare(add);
4949
elementwise_simple_compute_declare(mul);
5050
elementwise_simple_compute_declare(sub);
51+
#ifdef __aarch64__
5152
elementwise_simple_compute_declare(div);
53+
#else
54+
void elementwise_div(const float16_t* dinx,
55+
const float16_t* diny,
56+
float16_t* dout,
57+
int num);
58+
59+
void elementwise_div_broadcast(const float16_t* dinx,
60+
const float16_t* diny,
61+
float16_t* dout,
62+
int batch,
63+
int channels,
64+
int num);
65+
#endif
5266

5367
} // namespace fp16
5468
} // namespace math

lite/backends/arm/math/fp16/type_trans_fp16.cc

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -304,15 +304,6 @@ void fp32_to_fp16(const float* in, float16_t* out, int size) {
304304
"vst1.32 {d16-d17}, [%[out]]!\n"
305305
"bne 4b\n"
306306
"2: \n"
307-
"cmp %[remain_remain], #1\n"
308-
"blt 3f\n"
309-
"5: \n"
310-
"vld1.16 d0[0], [%[in]]!\n"
311-
"subs %[remain_remain], #1\n"
312-
"vcvt.f16.f32 d16, q0\n"
313-
"vst1.32 d16[0], [%[out]]!\n"
314-
"bne 5b\n"
315-
"3: \n"
316307
: [in] "+r"(in),
317308
[out] "+r"(out),
318309
[cnt] "+r"(cnt),
@@ -333,6 +324,11 @@ void fp32_to_fp16(const float* in, float16_t* out, int size) {
333324
"q9",
334325
"q10",
335326
"q11");
327+
for (int i = 0; i < remain_remain; i++) {
328+
*out = static_cast<float16_t>(*in);
329+
out++;
330+
in++;
331+
}
336332
#endif
337333
}
338334
} // namespace fp16

lite/kernels/arm/elementwise_compute.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,12 +452,18 @@ void ElementwiseDivCompute<float16_t, PRECISION(kFP16)>::Run() {
452452
OprandSwapable::NO,
453453
arm_math::NullNeonConfig>(
454454
this,
455-
456455
lite::arm::math::fp16::elementwise_div_broadcast<float16_t>,
457456
lite::arm::math::fp16::elementwise_div<float16_t>,
458457
paddle::lite::kernels::host::naive_div<float16_t>);
459458
#else
460-
LOG(FATAL) << "it doesn't support v7 fp16 elementwise_div compute";
459+
elementwise_compute_template<operators::ElementwiseParam,
460+
float16_t,
461+
OprandSwapable::NO,
462+
arm_math::NullNeonConfig>(
463+
this,
464+
lite::arm::math::fp16::elementwise_div_broadcast,
465+
lite::arm::math::fp16::elementwise_div,
466+
paddle::lite::kernels::host::naive_div<float16_t>);
461467
#endif
462468
}
463469

0 commit comments

Comments
 (0)