Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,9 @@ struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
// Constructs the dequantizer by forwarding vx and bx unchanged to the
// BaseDequantizer constructor; no additional state is initialised here.
DequantizerIQ2TN(const void * vx, size_t bx)
    : BaseDequantizer(vx, bx)
{
}
// Four-argument overload of new_block. The q8, accm and scales arguments are
// accepted but ignored (all are marked [[maybe_unused]]); the call simply
// delegates to the single-argument new_block(i).
template <typename Q8>
inline void new_block(int i,
                      [[maybe_unused]] const Q8& q8,
                      [[maybe_unused]] __m256 * accm,
                      [[maybe_unused]] __m512i * scales) {
    return new_block(i);
}
// Sets up block i of the row currently pointed to by x:
//  - converts the block's fp16 scale to float and stores it in d,
//  - passes the block's packed quants (qs) to bits.prepare().
inline void new_block(int i) {
    const auto& blk = x[i];
    d = GGML_FP16_TO_FP32(blk.d);
    bits.prepare(blk.qs);
}
Expand Down Expand Up @@ -1125,6 +1128,64 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da
}
}

// AVX512 kernel multiplying IQ2_TN rows with nrc_y Q8 columns.
//
// Rows are processed two at a time so that every 64-byte chunk of q8 quants
// that is loaded gets used for both rows.
//
// Template parameter:
//   nrc_y - number of right-hand-side (Q8) columns handled per call.
// Parameters:
//   n     - number of values per row; must be a multiple of QK_K (asserted).
//   vx,bx - forwarded to the two DequantizerIQ2TN instances
//           (presumably row data pointer and row stride in bytes - TODO confirm).
//   info  - used to construct the Q8 accessor and to store results through
//           info.store(row, column, value).
//   nrc_x - number of rows to process. An odd value is handled: the last row
//           is computed on its own and nothing is read or stored for row nrc_x.
template <int nrc_y>
static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n % QK_K == 0);
    const int nb = n / QK_K;

    Q8<nrc_y> q8(info);

    DequantizerIQ2TN deq1(vx, bx), deq2(vx, bx);

    // accd[2*iy+0] accumulates row ix, accd[2*iy+1] accumulates row ix+1.
    __m512 accd[2*nrc_y];

    for (int ix = 0; ix < nrc_x; ix += 2) {

        // With an odd nrc_x the final pair has no second row. In that case deq2 is
        // pointed at the same row as deq1 (so no out-of-range row is read) and the
        // second result is not stored.
        const bool has_second_row = ix + 1 < nrc_x;

        for (int iy = 0; iy < 2*nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();

        deq1.new_row(ix+0);
        deq2.new_row(has_second_row ? ix+1 : ix+0);

        for (int i = 0; i < nb; ++i) {

            deq1.new_block(i);
            deq2.new_block(i);
            float d = 0.5f*(deq1.d + deq2.d); // The scale is supposed to be per tensor, so we can use the same scale for both rows

            for (int iy = 0; iy < nrc_y; ++iy) {
                // Start the integer accumulators from the negated pairwise sums of the q8
                // block sums, placed in the lower 256 bits (upper half zero).
                // NOTE(review): presumably this compensates for the quants being stored
                // with a +1 offset - confirm against the IQ2_TN packing.
                auto sumi_scales_256 = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
                auto sumi_scales_512 = _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales_256, 0);
                auto q8q = q8.load_quants64(iy, i, 0);
                auto sumi_1 = _mm512_dpbusd_epi32(sumi_scales_512, deq1.bits.values[0], q8q);
                auto sumi_2 = _mm512_dpbusd_epi32(sumi_scales_512, deq2.bits.values[0], q8q);
                q8q = q8.load_quants64(iy, i, 1);
                sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[1], q8q);
                sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[1], q8q);
                q8q = q8.load_quants64(iy, i, 2);
                sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[2], q8q);
                sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[2], q8q);
                q8q = q8.load_quants64(iy, i, 3);
                sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[3], q8q);
                sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
                // The scale is supposed to be per tensor, so we can use the same scale
                auto vd = _mm512_set1_ps(d*q8.scale(iy, i));
                accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
                accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
                // Leaving this here just in case ternary models start using per row scales
                //accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
                //accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
            }

        }

        // Horizontal reduction of each accumulator gives the final dot product.
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0]));
            if (has_second_row) {
                info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1]));
            }
        }

    }
}

template <typename Dequantizer, int nrc_y>
static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
Expand Down Expand Up @@ -3589,7 +3650,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_IQ2_TN:
assert (ne00 % QK_K == 0);
#ifdef HAVE_FANCY_SIMD
MulMat::set_functions<DequantizerIQ2TN>(mm);
//MulMat::set_functions<DequantizerIQ2TN>(mm);
mm.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<DequantizerIQ2TN>;
//mm.funcs[0] = mul_mat_iq2tn_q8_K_AVX512<1>;
mm.funcs[1] = mul_mat_iq2tn_q8_K_AVX512<2>;
mm.funcs[2] = mul_mat_iq2tn_q8_K_AVX512<3>;
mm.funcs[3] = mul_mat_iq2tn_q8_K_AVX512<4>;
mm.funcs[4] = mul_mat_iq2tn_q8_K_AVX512<5>;
mm.funcs[5] = mul_mat_iq2tn_q8_K_AVX512<6>;
mm.funcs[6] = mul_mat_iq2tn_q8_K_AVX512<7>;
mm.funcs[7] = mul_mat_iq2tn_q8_K_AVX512<8>;
#else
mm.funcs[0] = mul_mat_iq2tn_q8_K<1>;
mm.funcs[1] = mul_mat_iq2tn_q8_K<2>;
Expand Down