Skip to content

Commit 7e97d30

Browse files
committed
ggml-cpu: add repack GEMM and GEMV for floating-point (#4)
1 parent a17373d commit 7e97d30

File tree

3 files changed

+101
-86
lines changed

3 files changed

+101
-86
lines changed

ggml/src/ggml-cpu/arch/riscv/repack.cpp

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
343343

344344
template<int ncols_interleaved>
345345
static inline void ggml_gemv_f16_1xM_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
346+
GGML_UNUSED(bs);
347+
346348
const int nb = n / 1;
347349

348350
assert (nr == 1);
@@ -369,39 +371,41 @@ static inline void ggml_gemv_f16_1xM_f16(int n, float * GGML_RESTRICT s, size_t
369371
}
370372

371373
void ggml_gemv_f16_1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
372-
#if defined __riscv_v_intrinsic
374+
#if defined __riscv_zvfh
373375
ggml_gemv_f16_1xM_f16<16>(n, s, bs, vx, vy, nr, nc);
374-
return;
375-
#endif
376+
#else
376377
ggml_gemv_f16_1x16_f16_generic(n, s, bs, vx, vy, nr, nc);
378+
#endif
377379
}
378380

379381
void ggml_gemv_f16_1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
380-
#if defined __riscv_v_intrinsic
382+
#if defined __riscv_zvfh
381383
ggml_gemv_f16_1xM_f16<32>(n, s, bs, vx, vy, nr, nc);
382-
return;
383-
#endif
384+
#else
384385
ggml_gemv_f16_1x32_f16_generic(n, s, bs, vx, vy, nr, nc);
386+
#endif
385387
}
386388

387389
void ggml_gemv_f16_1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
388-
#if defined __riscv_v_intrinsic
390+
#if defined __riscv_zvfh
389391
ggml_gemv_f16_1xM_f16<64>(n, s, bs, vx, vy, nr, nc);
390-
return;
391-
#endif
392+
#else
392393
ggml_gemv_f16_1x64_f16_generic(n, s, bs, vx, vy, nr, nc);
394+
#endif
393395
}
394396

395397
void ggml_gemv_f16_1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
396-
#if defined __riscv_v_intrinsic
398+
#if defined __riscv_zvfh
397399
ggml_gemv_f16_1xM_f16<128>(n, s, bs, vx, vy, nr, nc);
398-
return;
399-
#endif
400+
#else
400401
ggml_gemv_f16_1x128_f16_generic(n, s, bs, vx, vy, nr, nc);
402+
#endif
401403
}
402404

403405
template<int ncols_interleaved>
404406
static inline void ggml_gemv_f32_1xM_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
407+
GGML_UNUSED(bs);
408+
405409
const int nb = n / 1;
406410

407411
assert (nr == 1);
@@ -428,35 +432,35 @@ static inline void ggml_gemv_f32_1xM_f32(int n, float * GGML_RESTRICT s, size_t
428432
}
429433

430434
void ggml_gemv_f32_1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
431-
#if defined __riscv_v_intrinsic
435+
#if defined __riscv_zvfh
432436
ggml_gemv_f32_1xM_f32<16>(n, s, bs, vx, vy, nr, nc);
433-
return;
434-
#endif
437+
#else
435438
ggml_gemv_f32_1x16_f32_generic(n, s, bs, vx, vy, nr, nc);
439+
#endif
436440
}
437441

438442
void ggml_gemv_f32_1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
439-
#if defined __riscv_v_intrinsic
443+
#if defined __riscv_zvfh
440444
ggml_gemv_f32_1xM_f32<32>(n, s, bs, vx, vy, nr, nc);
441-
return;
442-
#endif
445+
#else
443446
ggml_gemv_f32_1x32_f32_generic(n, s, bs, vx, vy, nr, nc);
447+
#endif
444448
}
445449

446450
void ggml_gemv_f32_1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
447-
#if defined __riscv_v_intrinsic
451+
#if defined __riscv_zvfh
448452
ggml_gemv_f32_1xM_f32<64>(n, s, bs, vx, vy, nr, nc);
449-
return;
450-
#endif
453+
#else
451454
ggml_gemv_f32_1x64_f32_generic(n, s, bs, vx, vy, nr, nc);
455+
#endif
452456
}
453457

454458
void ggml_gemv_f32_1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
455-
#if defined __riscv_v_intrinsic
459+
#if defined __riscv_zvfh
456460
ggml_gemv_f32_1xM_f32<128>(n, s, bs, vx, vy, nr, nc);
457-
return;
458-
#endif
461+
#else
459462
ggml_gemv_f32_1x128_f32_generic(n, s, bs, vx, vy, nr, nc);
463+
#endif
460464
}
461465

462466
template<int ncols_interleaved>
@@ -506,35 +510,35 @@ static inline void ggml_gemm_f16_7x1xM_f16(int n, float * GGML_RESTRICT s, size_
506510
}
507511

508512
void ggml_gemm_f16_7x1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
509-
#if defined __riscv_v_intrinsic
513+
#if defined __riscv_zvfh
510514
ggml_gemm_f16_7x1xM_f16<16>(n, s, bs, vx, vy, nr, nc);
511-
return;
512-
#endif
515+
#else
513516
ggml_gemm_f16_7x1x16_f16_generic(n, s, bs, vx, vy, nr, nc);
517+
#endif
514518
}
515519

516520
void ggml_gemm_f16_7x1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
517-
#if defined __riscv_v_intrinsic
521+
#if defined __riscv_zvfh
518522
ggml_gemm_f16_7x1xM_f16<32>(n, s, bs, vx, vy, nr, nc);
519-
return;
520-
#endif
523+
#else
521524
ggml_gemm_f16_7x1x32_f16_generic(n, s, bs, vx, vy, nr, nc);
525+
#endif
522526
}
523527

524528
void ggml_gemm_f16_7x1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
525-
#if defined __riscv_v_intrinsic
529+
#if defined __riscv_zvfh
526530
ggml_gemm_f16_7x1xM_f16<64>(n, s, bs, vx, vy, nr, nc);
527-
return;
528-
#endif
531+
#else
529532
ggml_gemm_f16_7x1x64_f16_generic(n, s, bs, vx, vy, nr, nc);
533+
#endif
530534
}
531535

532536
void ggml_gemm_f16_7x1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
533-
#if defined __riscv_v_intrinsic
537+
#if defined __riscv_zvfh
534538
ggml_gemm_f16_7x1xM_f16<128>(n, s, bs, vx, vy, nr, nc);
535-
return;
536-
#endif
539+
#else
537540
ggml_gemm_f16_7x1x128_f16_generic(n, s, bs, vx, vy, nr, nc);
541+
#endif
538542
}
539543

540544
template<int ncols_interleaved>
@@ -584,33 +588,33 @@ static inline void ggml_gemm_f32_7x1xM_f32(int n, float * GGML_RESTRICT s, size_
584588
}
585589

586590
void ggml_gemm_f32_7x1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
587-
#if defined __riscv_v_intrinsic
591+
#if defined __riscv_zvfh
588592
ggml_gemm_f32_7x1xM_f32<16>(n, s, bs, vx, vy, nr, nc);
589-
return;
590-
#endif
593+
#else
591594
ggml_gemm_f32_7x1x16_f32_generic(n, s, bs, vx, vy, nr, nc);
595+
#endif
592596
}
593597

594598
void ggml_gemm_f32_7x1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
595-
#if defined __riscv_v_intrinsic
599+
#if defined __riscv_zvfh
596600
ggml_gemm_f32_7x1xM_f32<32>(n, s, bs, vx, vy, nr, nc);
597-
return;
598-
#endif
601+
#else
599602
ggml_gemm_f32_7x1x32_f32_generic(n, s, bs, vx, vy, nr, nc);
603+
#endif
600604
}
601605

602606
void ggml_gemm_f32_7x1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
603-
#if defined __riscv_v_intrinsic
607+
#if defined __riscv_zvfh
604608
ggml_gemm_f32_7x1xM_f32<64>(n, s, bs, vx, vy, nr, nc);
605-
return;
606-
#endif
609+
#else
607610
ggml_gemm_f32_7x1x64_f32_generic(n, s, bs, vx, vy, nr, nc);
611+
#endif
608612
}
609613

610614
void ggml_gemm_f32_7x1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
611-
#if defined __riscv_v_intrinsic
615+
#if defined __riscv_zvfh
612616
ggml_gemm_f32_7x1xM_f32<128>(n, s, bs, vx, vy, nr, nc);
613-
return;
614-
#endif
617+
#else
615618
ggml_gemm_f32_7x1x128_f32_generic(n, s, bs, vx, vy, nr, nc);
619+
#endif
616620
}

0 commit comments

Comments
 (0)