Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q1_0_g128,
.from_float_ref = (ggml_from_float_t)quantize_row_q1_0_g128_ref,
.vec_dot = vec_dot_q1_0_g128_q8_0,
#if defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
.nrows = 1,
.row_meta_size = 0,
},
Expand Down Expand Up @@ -21701,10 +21705,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
q->data, k->data, v->data, mask ? mask->data : NULL, sinks ? sinks->data : NULL,
scale, softcap, (float *)dst->data,
params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth, dst->op_params[4])) return;
printf("iqk_flash_attn_noalibi returned false for Dk = %ld, Dv = %ld, mask = %p:\n", Dk, Dv, (const void *)mask);
printf(" q(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(q->type), q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
printf(" k(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(k->type), k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
printf(" v(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(v->type), v->ne[0], v->ne[1], v->ne[2], v->ne[3]);

// if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
// //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
Expand Down
86 changes: 64 additions & 22 deletions ggml/src/iqk/iqk_gemm_1bit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1448,20 +1448,17 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
template <int nrc_y>
static void mul_mat_q1_0_g128_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
constexpr int n4 = QK1_0_G128 / QK8_0;
Q8<nrc_y, block_q8_0> q8(info);
const block_q8_0_x4 * y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
}
Q8<nrc_y, block_q8_2_x4> q8(info);
#ifndef HAVE_FANCY_SIMD
__m256i shuffle[4] = {
_mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000),
_mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404),
_mm256_set_epi64x(0x0b0b0b0b0b0b0b0b, 0x0a0a0a0a0a0a0a0a, 0x0909090909090909, 0x0808080808080808),
_mm256_set_epi64x(0x0f0f0f0f0f0f0f0f, 0x0e0e0e0e0e0e0e0e, 0x0d0d0d0d0d0d0d0d, 0x0c0c0c0c0c0c0c0c),
};
auto mask = _mm256_set1_epi64x(0x8040201008040201);
auto mp1 = _mm256_set1_epi8( 1);
auto mm1 = _mm256_set1_epi8(-1);
#endif
auto mp2 = _mm256_set1_epi8 (2);
auto m1 = _mm256_set1_epi16(1);
int nb = n / QK1_0_G128;
__m256i qx[4];
Expand All @@ -1472,27 +1469,33 @@ static void mul_mat_q1_0_g128_q8_0(int n, const void * vx, size_t bx, const Data
for (int ib = 0; ib < nb; ++ib) {
float d = GGML_FP16_TO_FP32(x[ib].d);
auto vd = _mm256_set1_ps(d);
#ifdef HAVE_FANCY_SIMD
auto m32 = (const __mmask32 *)x[ib].qs;
for (int k = 0; k < 4; ++k) {
qx[k] = _mm256_mask_blend_epi8(m32[k], _mm256_setzero_si256(), mp2);
}
#else
auto bits128 = _mm_loadu_si128((const __m128i *)x[ib].qs);
auto bits = MM256_SET_M128I(bits128, bits128);
for (int k = 0; k < 4; ++k) {
qx[k] = _mm256_shuffle_epi8(bits, shuffle[k]);
qx[k] = _mm256_cmpeq_epi8(_mm256_and_si256(qx[k], mask), mask);
qx[k] = _mm256_or_si256(_mm256_and_si256(qx[k], mp1), _mm256_andnot_si256(qx[k], mm1));
qx[k] = _mm256_and_si256(qx[k], mp2);
}
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
for (int k = 0; k < n4; ++k) {
auto qy = _mm256_loadu_si256((const __m256i *)y[iy][ib].qs + k);
#ifdef HAVE_VNNI256
sumi[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), mp1, _mm256_sign_epi8(qy, qx[k]));
#else
sumi[k] = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(mp1, _mm256_sign_epi8(qy, qx[k])));
#endif
auto qy = _mm256_loadu_si256((const __m256i *)q8.y[iy][ib].qs + k);
sumi[k] = _mm256_maddubs_epi16(qx[k], qy);
}
sumi[0] = _mm256_madd_epi16(m1, _mm256_packs_epi32(sumi[0], sumi[1]));
sumi[2] = _mm256_madd_epi16(m1, _mm256_packs_epi32(sumi[2], sumi[3]));
sumi[0] = _mm256_madd_epi16(m1, _mm256_packs_epi32(sumi[0], sumi[2]));
auto dy = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y[iy][ib].d));
sumi[0] = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
sumi[2] = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
sumi[0] = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
sumi[0] = _mm256_madd_epi16(m1, sumi[0]);
auto dy = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].d)), 16));
auto sy = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][ib].d+4)));
auto dxy= _mm256_mul_ps(vd, _mm256_set_m128(dy, dy));
sumi[0] = _mm256_sub_epi32(sumi[0], MM256_SET_M128I(sy, _mm_setzero_si128()));
acc[iy] = _mm256_fmadd_ps(dxy, _mm256_cvtepi32_ps(sumi[0]), acc[iy]);
}
}
Expand Down Expand Up @@ -1962,7 +1965,7 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
break;
case GGML_TYPE_Q1_0_G128:
if (ne00 % QK1_0_G128 != 0) return false;
expected_typeB = GGML_TYPE_Q8_0_X4;
expected_typeB = GGML_TYPE_Q8_2_X4;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q1_0_g128_q8_0, funcs);
break;

Expand Down Expand Up @@ -2344,13 +2347,52 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
template <int nrc_y>
static void mul_mat_q1_0_g128_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
Q8<nrc_y, block_q8_0_x4> q8(info);
const uint8x16_t shuffle[8] = {
vcombine_u8(vdup_n_u8( 0), vdup_n_u8( 1)), vcombine_u8(vdup_n_u8( 2), vdup_n_u8( 3)),
vcombine_u8(vdup_n_u8( 4), vdup_n_u8( 5)), vcombine_u8(vdup_n_u8( 6), vdup_n_u8( 7)),
vcombine_u8(vdup_n_u8( 8), vdup_n_u8( 9)), vcombine_u8(vdup_n_u8(10), vdup_n_u8(11)),
vcombine_u8(vdup_n_u8(12), vdup_n_u8(13)), vcombine_u8(vdup_n_u8(14), vdup_n_u8(15)),
};
auto mask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
auto m2 = vdupq_n_s8(2);
auto m1 = vdupq_n_s8(1);
int nb = n / QK1_0_G128;
int8x16_t qx[8];
int32x4_t sumi[4];
for (int ix = 0; ix < nrc_x; ++ix) {
auto x = (const block_q1_0_g128 *)((const char *)vx + ix*bx);
float32x4_t acc[nrc_y] = {};
for (int ib = 0; ib < nb; ++ib) {
auto dx = vdupq_n_f32(GGML_FP16_TO_FP32(x[ib].d));
auto bits = vld1q_u8(x[ib].qs);
for (int k = 0; k < 8; ++k) {
auto val = vqtbl1q_u8(bits, shuffle[k]);
val = vceqq_u8(vandq_u8(val, mask), mask);
qx[k] = vsubq_s8(vandq_s8(val, m2), m1);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto dy = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib].d));
auto vd = vmulq_f32(dx, dy);
auto qy1 = vld1q_s8_x4(q8.y[iy][ib].qs+ 0);
auto qy2 = vld1q_s8_x4(q8.y[iy][ib].qs+64);
for (int k = 0; k < 2; ++k) {
sumi[k+0] = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2*k+0], qy1.val[2*k+0]), qx[2*k+1], qy1.val[2*k+1]);
sumi[k+2] = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2*k+4], qy2.val[2*k+0]), qx[2*k+5], qy2.val[2*k+1]);
}
sumi[0] = vpaddq_s32(sumi[0], sumi[1]);
sumi[2] = vpaddq_s32(sumi[2], sumi[3]);
sumi[0] = vpaddq_s32(sumi[0], sumi[2]);
acc[iy] = vfmaq_f32(acc[iy], vd, vcvtq_f32_s32(sumi[0]));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
float s;
vec_dot_q1_0_g128_q8_0(n, &s, 0, x, bx, q8.y[iy], 0, 1);
info.store(ix, iy, s);
info.store(ix, iy, vaddvq_f32(acc[iy]));
}
//for (int iy = 0; iy < nrc_y; ++iy) {
// float s;
// vec_dot_q1_0_g128_q8_0(n, &s, 0, x, bx, q8.y[iy], 0, 1);
// info.store(ix, iy, s);
//}
}
}

Expand Down