Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ggml/src/ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,8 @@ void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, in
}

void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q4_0_ref(x, y, k);
iqk_quantize_q4_0(x, y, k);
//quantize_row_q4_0_ref(x, y, k);
}


Expand Down
6 changes: 6 additions & 0 deletions ggml/src/iqk/iqk_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,12 @@ static inline float hmax_float_8(__m256 x) {
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
return _mm_cvtss_f32(max4);
}
static inline float hmin_float_8(__m256 x) {
__m128 min4 = _mm_min_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
min4 = _mm_min_ps( min4, _mm_movehl_ps(min4, min4));
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4));
return _mm_cvtss_f32(min4);
}

static inline __m128 hsum_float_4x4(__m128 * accm) {
accm[0] = _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[2]), _mm_unpackhi_ps(accm[0], accm[2]));
Expand Down
88 changes: 88 additions & 0 deletions ggml/src/iqk/iqk_quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,94 @@ void quantize_row_q8_K16(const float * x, void * vy, int64_t nk) {
#endif
}

void iqk_quantize_q4_0(const float * x, void * vy, int64_t k) {
const int nb = k / QK4_0;
auto y = (block_q4_0 *)vy;
#ifdef __AVX2__
static_assert(QK4_0 == 32);
__m256 vx[4], rx[4];
__m256i ix[4];
auto v7 = _mm256_set1_ps(7.0f);
auto perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
for (int ib = 0; ib < nb; ++ib) {
for (int k = 0; k < 4; ++k) {
vx[k] = _mm256_loadu_ps(x + 8*k);
}
auto max1 = _mm256_max_ps(vx[0], vx[1]);
auto max2 = _mm256_max_ps(vx[2], vx[3]);
auto vmax = _mm256_max_ps(max1, max2);
auto min1 = _mm256_min_ps(vx[0], vx[1]);
auto min2 = _mm256_min_ps(vx[2], vx[3]);
auto vmin = _mm256_min_ps(min1, min2);
float max = hmax_float_8(vmax);
float min = hmin_float_8(vmin);
float amax = std::abs(max);
float amin = std::abs(min);
float d, id;
if (amax > amin) {
d = max / -8;
id = amax > 1e-13f ? 1/d : 0.0f;
} else {
d = min / -8;
id = amin > 1e-13f ? 1/d : 0.0f;
}
auto vid = _mm256_set1_ps(id);
auto vsumqx = _mm256_setzero_ps();
auto vsumq2 = _mm256_setzero_ps();
for (int k = 0; k < 4; ++k) {
rx[k] = _mm256_mul_ps(vid, vx[k]);
rx[k] = _mm256_round_ps(rx[k], _MM_ROUND_NEAREST);
rx[k] = _mm256_min_ps(rx[k], v7);
ix[k] = _mm256_cvtps_epi32(rx[k]);
ix[k] = _mm256_add_epi32(ix[k], _mm256_set1_epi32(8));
auto w = _mm256_mul_ps(vx[k], vx[k]);
auto wr = _mm256_mul_ps(w, rx[k]);
vsumqx = _mm256_fmadd_ps(wr, vx[k], vsumqx);
vsumq2 = _mm256_fmadd_ps(wr, rx[k], vsumq2);
}
auto sumq2 = hsum_float_8(vsumq2);
if (sumq2 > 0) {
auto sumqx = hsum_float_8(vsumqx);
d = sumqx/sumq2;
}
y[ib].d = GGML_FP32_TO_FP16(d);
auto i0 = _mm256_packs_epi32(ix[0], ix[1]);
auto i2 = _mm256_packs_epi32(ix[2], ix[3]);
i0 = _mm256_packs_epi16(i0, i2);
i0 = _mm256_permutevar8x32_epi32(i0, perm);
auto q = _mm_or_si128(_mm256_castsi256_si128(i0), _mm_slli_epi16(_mm256_extracti128_si256(i0, 1), 4));
_mm_storeu_si128((__m128i *)y[ib].qs, q);
x += QK4_0;
}
#else
for (int ib = 0; ib < nb; ++ib) {
float max = 0, amax = 0;
for (int j = 0; j < QK4_0; ++j) {
float ax = std::abs(x[j]);
if (ax > amax) {
amax = ax; max = x[j];
}
}
float d = max / -8;
float id = amax > 1e-13f ? 1/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_0/2; ++j) {
float v0 = x[j], v1 = x[j+QK4_0/2];
int i0 = nearest_int(id*v0), i1 = nearest_int(id*v1);
i0 = std::min(i0, 7);
i1 = std::min(i1, 7);
float w0 = v0*v0, w1 = v1*v1;
sumqx += w0*i0*v0 + w1*i1*v1;
sumq2 += w0*i0*i0 + w1*i1*i1;
y[ib].qs[j] = (i0 + 8) | ((i1 + 8) << 4);
}
if (sumq2 > 0) d = sumqx/sumq2;
y[ib].d = GGML_FP32_TO_FP16(d);
x += QK4_0;
}
#endif
}

void quantize_row_q8_0_x4(const float * x, void * vy, int64_t k) {
const int nb = k / QK8_0;
const int nb4 = 4*(nb/4);
Expand Down
1 change: 1 addition & 0 deletions ggml/src/iqk/iqk_quantize.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
void quantize_row_q8_0_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_2_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void iqk_quantize_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
Expand Down