@@ -189,8 +189,6 @@ namespace fp16 {
189189 : "cc", \
190190 "memory", \
191191 ASM_VAR);
192-
193-
194192#else
195193#define INIT_1 \
196194 " vld1.16 {d0-d1}, [%[dinx_ptr]]! \n " \
@@ -260,10 +258,10 @@ namespace fp16 {
260258
261259#define SIMPLE_COMPUTE_TYPE (op ) \
262260 asm volatile (INIT SIMPLE_COMPUTE (v##op.f16 ) STORE \
261+ : [dinx_ptr] "+r"(dinx_ptr), \
262+ [diny_ptr] "+r"(diny_ptr), \
263+ [dout_ptr] "+r"(dout_ptr) \
263264 : \
264- : [dinx_ptr] "r"(dinx_ptr), \
265- [diny_ptr] "r"(diny_ptr), \
266- [dout_ptr] "r"(dout_ptr) \
267265 : "cc", \
268266 "memory", \
269267 ASM_VAR);
@@ -281,11 +279,10 @@ namespace fp16 {
281279
282280#define SIMPLE_COMPUTE_TYPE_RELU (op ) \
283281 asm volatile (INIT SIMPLE_COMPUTE (v##op.f16 ) RELU STORE \
284- : \
285- : [dinx_ptr] "r"(dinx_ptr), \
286- [diny_ptr] "r"(diny_ptr), \
287- [dout_ptr] "r"(dout_ptr), \
288- [vzero] "w"(vzero) \
282+ : [dinx_ptr] "+r"(dinx_ptr), \
283+ [diny_ptr] "+r"(diny_ptr), \
284+ [dout_ptr] "+r"(dout_ptr) \
285+ : [vzero] "w"(vzero) \
289286 : "cc", \
290287 "memory", \
291288 ASM_VAR);
@@ -303,10 +300,9 @@ namespace fp16 {
303300
304301#define SIMPLE_COMPUTE_TYPE_BROADCAST (op ) \
305302 asm volatile (INIT_BROADCAST SIMPLE_COMPUTE_BROADCAST (v##op.f16 ) STORE \
306- : \
307- : [dinx_ptr] "r"(dinx_ptr_1), \
308- [dout_ptr] "r"(dout_ptr_1), \
309- [val_y] "w"(val_y) \
303+ : [dinx_ptr] "+r"(dinx_ptr_1), \
304+ [dout_ptr] "+r"(dout_ptr_1) \
305+ : [val_y] "w"(val_y) \
310306 : "cc", \
311307 "memory", \
312308 ASM_VAR);
@@ -323,17 +319,16 @@ namespace fp16 {
323319
324320#define SIMPLE_COMPUTE_TYPE_BROADCAST_RELU (op ) \
325321 asm volatile (INIT_BROADCAST SIMPLE_COMPUTE_BROADCAST (v##op.f16 ) RELU STORE \
326- : \
327- : [dinx_ptr] "r"(dinx_ptr_1), \
328- [dout_ptr] "r"(dout_ptr_1), \
329- [val_y] "w"(val_y), \
322+ : [dinx_ptr] "+r"(dinx_ptr_1), \
323+ [dout_ptr] "+r"(dout_ptr_1) \
324+ : [val_y] "w"(val_y), \
330325 [vzero] "w"(vzero) \
331326 : "cc", \
332327 "memory", \
333328 ASM_VAR);
334329
335330#define SIMPLE_COMPUTE_TYPE_BROADCAST_RELU_1 (op ) \
336- asm volatile (INIT_1_BROADCAST SIMPLE_COMPUTE_1_BROADCAST (v##op.f16 ) RELU STORE_1 \
331+ asm volatile (INIT_1_BROADCAST SIMPLE_COMPUTE_1_BROADCAST (v##op.f16 ) RELU_1 STORE_1 \
337332 : [cnt_num] "+r"(cnt_num), \
338333 [dinx_ptr] "+r"(dinx_ptr_1), \
339334 [dout_ptr] "+r"(dout_ptr_1) \
@@ -352,7 +347,6 @@ namespace fp16 {
352347 float16_t * dout, \
353348 int num) { \
354349 LOOP_CNT (num) \
355- \
356350 for (int i = 0 ; i < cnt; i++) { \
357351 int stride = i << 5 ; \
358352 const float16_t * dinx_ptr = dinx + stride; \
@@ -517,6 +511,110 @@ elmentwise_simple_compute(mul);
517511elmentwise_simple_compute (sub);
518512#ifdef __aarch64__
519513elmentwise_simple_compute (div);
514+ #else
515+ void elementwise_div (const float16_t * dinx,
516+ const float16_t * diny,
517+ float16_t * dout,
518+ int num) {
519+ LOOP_CNT (num)
520+ for (int i = 0 ; i < cnt; i++) {
521+ int stride = i << 5 ;
522+ const float16_t * dinx_ptr = dinx + stride;
523+ const float16_t * diny_ptr = diny + stride;
524+ float16_t * dout_ptr = dout + stride;
525+ float16x8_t vec_a1 = vld1q_f16 (dinx_ptr);
526+ float16x8_t vec_a2 = vld1q_f16 (dinx_ptr + 8 );
527+ float16x8_t vec_b1 = vld1q_f16 (diny_ptr);
528+ float16x8_t vec_b2 = vld1q_f16 (diny_ptr + 8 );
529+ vst1q_f16 (dout_ptr, divq_ps_f16 (vec_a1, vec_b1));
530+ vst1q_f16 (dout_ptr + 8 , divq_ps_f16 (vec_a2, vec_b2));
531+ vec_a1 = vld1q_f16 (dinx_ptr + 16 );
532+ vec_a2 = vld1q_f16 (dinx_ptr + 24 );
533+ vec_b1 = vld1q_f16 (diny_ptr + 16 );
534+ vec_b2 = vld1q_f16 (diny_ptr + 24 );
535+ vst1q_f16 (dout_ptr + 16 , divq_ps_f16 (vec_a1, vec_b1));
536+ vst1q_f16 (dout_ptr + 24 , divq_ps_f16 (vec_a2, vec_b2));
537+ }
538+ int stride = cnt << 5 ;
539+ if (rem_cnt > 0 ) {
540+ const float16_t * dinx_ptr = dinx + stride;
541+ const float16_t * diny_ptr = diny + stride;
542+ float16_t * dout_ptr = dout + stride;
543+ int cnt_num = rem_cnt;
544+ for (int loop = 0 ; loop < rem_cnt; loop++) {
545+ float16x8_t vec_a1 = vld1q_f16 (dinx_ptr + loop * 8 );
546+ float16x8_t vec_b1 = vld1q_f16 (diny_ptr + loop * 8 );
547+ vst1q_f16 (dout_ptr + loop * 8 , divq_ps_f16 (vec_a1, vec_b1));
548+ }
549+ }
550+ if (rem_rem > 0 ) {
551+ stride += (rem_cnt << 3 );
552+ const float16_t * dinx_ptr = dinx + stride;
553+ const float16_t * diny_ptr = diny + stride;
554+ float16_t * dout_ptr = dout + stride;
555+ for (int i = 0 ; i < rem_rem; i++) {
556+ *dout_ptr = naive_div (*dinx_ptr, *diny_ptr);
557+ dout_ptr++;
558+ dinx_ptr++;
559+ diny_ptr++;
560+ }
561+ }
562+ }
563+
564+ void elementwise_div_broadcast (const float16_t * dinx,
565+ const float16_t * diny,
566+ float16_t * dout,
567+ int batch,
568+ int channels,
569+ int num) {
570+ OMP_PARA_INTERNAL_COLLASPE_2
571+ for (int i = 0 ; i < batch; ++i) {
572+ for (int j = 0 ; j < channels; ++j) {
573+ int offset = (i * channels + j) * num;
574+ const auto * dinx_ptr = dinx + offset;
575+ const auto * diny_ptr = diny + j;
576+ auto * dout_ptr = dout + offset;
577+ LOOP_CNT (num)
578+ for (int k = 0 ; k < cnt; k++) {
579+ int stride = k << 5 ;
580+ const float16_t * dinx_ptr_1 = dinx_ptr + stride;
581+ float16_t * dout_ptr_1 = dout_ptr + stride;
582+ float16x8_t val_y = vdupq_n_f16 (diny_ptr[0 ]);
583+ float16x8_t vec_x1 = vld1q_f16 (dinx_ptr_1);
584+ float16x8_t vec_x2 = vld1q_f16 (dinx_ptr_1 + 8 );
585+ vst1q_f16 (dout_ptr_1, divq_ps_f16 (vec_x1, val_y));
586+ vst1q_f16 (dout_ptr_1 + 8 , divq_ps_f16 (vec_x2, val_y));
587+ vec_x1 = vld1q_f16 (dinx_ptr_1 + 16 );
588+ vec_x2 = vld1q_f16 (dinx_ptr_1 + 24 );
589+ vst1q_f16 (dout_ptr_1 + 16 , divq_ps_f16 (vec_x1, val_y));
590+ vst1q_f16 (dout_ptr_1 + 24 , divq_ps_f16 (vec_x2, val_y));
591+ }
592+ int stride = cnt << 5 ;
593+ if (rem_cnt > 0 ) {
594+ const float16_t * dinx_ptr_1 = dinx_ptr + stride;
595+ float16_t * dout_ptr_1 = dout_ptr + stride;
596+ float16x8_t val_y = vdupq_n_f16 (diny_ptr[0 ]);
597+ int cnt_num = rem_cnt;
598+ for (int loop = 0 ; loop < rem_cnt; loop++) {
599+ float16x8_t vec_x1 = vld1q_f16 (dinx_ptr_1 + loop * 8 );
600+ vst1q_f16 (dout_ptr_1 + loop * 8 , divq_ps_f16 (vec_x1, val_y));
601+ }
602+ }
603+ if (rem_rem > 0 ) {
604+ stride += (rem_cnt << 3 );
605+ const float16_t * dinx_ptr_1 = dinx_ptr + stride;
606+ float16_t * dout_ptr_1 = dout_ptr + stride;
607+ float16_t val = diny_ptr[0 ];
608+ for (int i = 0 ; i < rem_rem; i++) {
609+ *dout_ptr_1 = naive_div (*dinx_ptr_1, val);
610+ dinx_ptr_1++;
611+ dout_ptr_1++;
612+ }
613+ }
614+ }
615+ }
616+ }
617+
520618#endif
521619} // namespace fp16
522620} // namespace math
0 commit comments