diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index b8d0d8defe..0f3caece74 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -708,6 +708,7 @@ void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, in
 }
 
 void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q4_0_ref(x, y, k);
+    iqk_quantize_q4_0(x, y, k);
+    //quantize_row_q4_0_ref(x, y, k);
 }
diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
index 5503bdf971..f4ef0c8c34 100644
--- a/ggml/src/iqk/iqk_common.h
+++ b/ggml/src/iqk/iqk_common.h
@@ -248,6 +248,13 @@ static inline float hmax_float_8(__m256 x) {
     max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
     return _mm_cvtss_f32(max4);
 }
+// Horizontal minimum of the 8 floats in an AVX register (counterpart of hmax_float_8 above).
+static inline float hmin_float_8(__m256 x) {
+    __m128 min4 = _mm_min_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
+    min4 = _mm_min_ps( min4, _mm_movehl_ps(min4, min4));
+    min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4));
+    return _mm_cvtss_f32(min4);
+}
 
 static inline __m128 hsum_float_4x4(__m128 * accm) {
     accm[0] = _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[2]), _mm_unpackhi_ps(accm[0], accm[2]));
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 19ddcea424..86a082230f 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -713,6 +713,97 @@ void quantize_row_q8_K16(const float * x, void * vy, int64_t nk) {
 #endif
 }
 
+// Quantize k floats to q4_0 (blocks of QK4_0 = 32 4-bit values sharing one fp16 scale).
+// The scale is first chosen so the largest-magnitude extremum maps to quant -8, then
+// refined with a weighted least-squares fit (weights x^2) of d*q against x over the block.
+void iqk_quantize_q4_0(const float * x, void * vy, int64_t k) {
+    const int nb = k / QK4_0;
+    auto y = (block_q4_0 *)vy;
+#ifdef __AVX2__
+    static_assert(QK4_0 == 32);
+    __m256 vx[4], rx[4];
+    __m256i ix[4];
+    auto v7 = _mm256_set1_ps(7.0f);
+    auto perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+    for (int ib = 0; ib < nb; ++ib) {
+        for (int l = 0; l < 4; ++l) {   // 'l', not 'k': do not shadow the function argument
+            vx[l] = _mm256_loadu_ps(x + 8*l);
+        }
+        auto max1 = _mm256_max_ps(vx[0], vx[1]);
+        auto max2 = _mm256_max_ps(vx[2], vx[3]);
+        auto vmax = _mm256_max_ps(max1, max2);
+        auto min1 = _mm256_min_ps(vx[0], vx[1]);
+        auto min2 = _mm256_min_ps(vx[2], vx[3]);
+        auto vmin = _mm256_min_ps(min1, min2);
+        float max = hmax_float_8(vmax);
+        float min = hmin_float_8(vmin);
+        float amax = std::abs(max);
+        float amin = std::abs(min);
+        float d, id;
+        if (amax > amin) {  // the larger-magnitude extremum is mapped to quant value -8
+            d = max / -8;
+            id = amax > 1e-13f ? 1/d : 0.0f;
+        } else {
+            d = min / -8;
+            id = amin > 1e-13f ? 1/d : 0.0f;
+        }
+        auto vid = _mm256_set1_ps(id);
+        auto vsumqx = _mm256_setzero_ps();
+        auto vsumq2 = _mm256_setzero_ps();
+        for (int l = 0; l < 4; ++l) {
+            rx[l] = _mm256_mul_ps(vid, vx[l]);
+            rx[l] = _mm256_round_ps(rx[l], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+            rx[l] = _mm256_min_ps(rx[l], v7);   // quants live in [-8, 7]; only the high side can overflow
+            ix[l] = _mm256_cvtps_epi32(rx[l]);
+            ix[l] = _mm256_add_epi32(ix[l], _mm256_set1_epi32(8));
+            auto w = _mm256_mul_ps(vx[l], vx[l]);
+            auto wr = _mm256_mul_ps(w, rx[l]);
+            vsumqx = _mm256_fmadd_ps(wr, vx[l], vsumqx);    // sum w * q * x
+            vsumq2 = _mm256_fmadd_ps(wr, rx[l], vsumq2);    // sum w * q * q
+        }
+        auto sumq2 = hsum_float_8(vsumq2);
+        if (sumq2 > 0) {
+            auto sumqx = hsum_float_8(vsumqx);
+            d = sumqx/sumq2;    // refined scale; the stored quants are unchanged
+        }
+        y[ib].d = GGML_FP32_TO_FP16(d);
+        auto i0 = _mm256_packs_epi32(ix[0], ix[1]);
+        auto i2 = _mm256_packs_epi32(ix[2], ix[3]);
+        i0 = _mm256_packs_epi16(i0, i2);
+        i0 = _mm256_permutevar8x32_epi32(i0, perm); // undo the per-lane interleave of the two packs
+        auto q = _mm_or_si128(_mm256_castsi256_si128(i0), _mm_slli_epi16(_mm256_extracti128_si256(i0, 1), 4));
+        _mm_storeu_si128((__m128i *)y[ib].qs, q);
+        x += QK4_0;
+    }
+#else
+    for (int ib = 0; ib < nb; ++ib) {
+        float max = 0, amax = 0;
+        for (int j = 0; j < QK4_0; ++j) {
+            float ax = std::abs(x[j]);
+            if (ax > amax) {
+                amax = ax; max = x[j];
+            }
+        }
+        float d = max / -8;
+        float id = amax > 1e-13f ? 1/d : 0.0f;
+        float sumqx = 0, sumq2 = 0;
+        for (int j = 0; j < QK4_0/2; ++j) {
+            float v0 = x[j], v1 = x[j+QK4_0/2];
+            int i0 = nearest_int(id*v0), i1 = nearest_int(id*v1);
+            i0 = std::min(i0, 7);
+            i1 = std::min(i1, 7);
+            float w0 = v0*v0, w1 = v1*v1;
+            sumqx += w0*i0*v0 + w1*i1*v1;
+            sumq2 += w0*i0*i0 + w1*i1*i1;
+            y[ib].qs[j] = (i0 + 8) | ((i1 + 8) << 4);
+        }
+        if (sumq2 > 0) d = sumqx/sumq2;
+        y[ib].d = GGML_FP32_TO_FP16(d);
+        x += QK4_0;
+    }
+#endif
+}
+
 void quantize_row_q8_0_x4(const float * x, void * vy, int64_t k) {
     const int nb = k / QK8_0;
     const int nb4 = 4*(nb/4);
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 5f0622615f..6e4cc5870a 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -305,6 +305,7 @@ void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
 void quantize_row_q8_0_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_1_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_2_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void iqk_quantize_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
 void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
 void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);