Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion ggml/src/iqk/iqk_gemm_iquants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1671,7 +1671,9 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat
auto smask = _mm256_set1_epi64x(0x8040201008040201);
auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
auto m4 = _mm256_set1_epi8(4);
#ifndef HAVE_VNNI256
auto m1 = _mm256_set1_epi16(1);
#endif
#endif
__m256 acc[nrc_y] = {};
__m256i isum[nrc_y] = {};
Expand All @@ -1692,7 +1694,7 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat
iq3xxs_grid[iq3[ibl].qs[32*ib+27]], iq3xxs_grid[iq3[ibl].qs[32*ib+26]], iq3xxs_grid[iq3[ibl].qs[32*ib+25]], iq3xxs_grid[iq3[ibl].qs[32*ib+24]]);
auto sas = _mm_loadu_si128((const __m128i *)iq3[ibl].sas + ib);
auto scales = _mm_and_si128(sas, _mm_set1_epi8(1));
-#ifdef HAVE_FANCY_SIMD
+#ifdef HAVE_VNNI256
scales = _mm_dpbusd_epi32(_mm_set1_epi32(1), scales, _mm_set1_epi32(0x10080402));
#else
scales = _mm_maddubs_epi16(scales, _mm_set1_epi32(0x10080402));
Expand Down Expand Up @@ -1729,10 +1731,17 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat
auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
#ifdef HAVE_VNNI256
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_sign_epi8(y, s1));
auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_sign_epi8(y, s2));
auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_sign_epi8(y, s3));
auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_sign_epi8(y, s4));
#else
auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)));
auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2)));
auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3)));
auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4)));
#endif
auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
Expand Down Expand Up @@ -1808,21 +1817,33 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_mask_sub_epi8(ys, mask[3], _mm256_setzero_si256(), ys));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(sumi, scales));
}
#else
#ifdef HAVE_VNNI256
auto scales = _mm256_cvtepi8_epi32(_mm_set1_epi32(helper.val[ib]));
#else
auto scales16 = _mm256_cvtepi8_epi16(_mm_set1_epi32(helper.val[ib]));
auto scales = _mm256_unpacklo_epi16(scales16, scales16);
#endif
auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
auto sumi = _mm256_setzero_si256();
#ifdef HAVE_VNNI256
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), s1));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), s2));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), s3));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), s4));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(sumi, scales));
#else
sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), s1)));
sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), s2)));
sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), s3)));
sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), s4)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, sumi));
#endif
}
#endif
}
Expand Down