@@ -359,6 +359,45 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);

#define QK 32

362+ // AVX routines provided by GH user Const-me
363+ // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
364+ #if __AVX2__
365+ // Unpack 32 4-bit fields into 32 bytes
366+ // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
367+ inline __m256i bytesFromNibbles ( const uint8_t * rsi )
368+ {
369+ // Load 16 bytes from memory
370+ __m128i tmp = _mm_loadu_si128 ( ( const __m128i * )rsi );
371+
372+ // Expand bytes into uint16_t values
373+ __m256i bytes = _mm256_cvtepu8_epi16 ( tmp );
374+
375+ // Unpack values into individual bytes
376+ const __m256i lowMask = _mm256_set1_epi8 ( 0xF );
377+ __m256i high = _mm256_andnot_si256 ( lowMask , bytes );
378+ __m256i low = _mm256_and_si256 ( lowMask , bytes );
379+ high = _mm256_slli_epi16 ( high , 4 );
380+ bytes = _mm256_or_si256 ( low , high );
381+ return bytes ;
382+ }
383+
384+ inline __m128i packNibbles ( __m256i bytes )
385+ {
386+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
387+ const __m256i lowByte = _mm256_set1_epi16 ( 0xFF );
388+ __m256i high = _mm256_andnot_si256 ( lowByte , bytes );
389+ __m256i low = _mm256_and_si256 ( lowByte , bytes );
390+ high = _mm256_srli_epi16 ( high , 4 );
391+ bytes = _mm256_or_si256 ( low , high );
392+
393+ // Compress uint16_t lanes into bytes
394+ __m128i r0 = _mm256_castsi256_si128 ( bytes );
395+ __m128i r1 = _mm256_extracti128_si256 ( bytes , 1 );
396+ return _mm_packus_epi16 ( r0 , r1 );
397+ }
398+ #endif
399+
400+
// method 5
// blocks of QK elements
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -414,6 +453,77 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
#else
#error "not implemented for QK"
#endif
456+ #elif defined(__AVX2__ )
457+ #if QK == 32
458+ for (int i = 0 ; i < nb ; i ++ ) {
459+ // Load elements into 4 AVX vectors
460+ __m256 v0 = _mm256_loadu_ps ( x );
461+ __m256 v1 = _mm256_loadu_ps ( x + 8 );
462+ __m256 v2 = _mm256_loadu_ps ( x + 16 );
463+ __m256 v3 = _mm256_loadu_ps ( x + 24 );
464+ x += 32 ;
465+
466+ // Compute max(abs(e)) for the block
467+ const __m256 signBit = _mm256_set1_ps ( -0.0f );
468+ __m256 maxAbs = _mm256_andnot_ps ( signBit , v0 );
469+ maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v1 ) );
470+ maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v2 ) );
471+ maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v3 ) );
472+
473+ __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( maxAbs , 1 ), _mm256_castps256_ps128 ( maxAbs ) );
474+ max4 = _mm_max_ps ( max4 , _mm_movehl_ps ( max4 , max4 ) );
475+ max4 = _mm_max_ss ( max4 , _mm_movehdup_ps ( max4 ) );
476+ const float maxScalar = _mm_cvtss_f32 ( max4 );
477+
478+ // Quantize these floats
479+ const float d = maxScalar / 7.0f ;
480+ * (float * )pd = d ;
481+ pd += bs ;
482+ const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f ;
483+ const __m256 mul = _mm256_set1_ps ( id );
484+
485+ // Apply the multiplier
486+ v0 = _mm256_mul_ps ( v0 , mul );
487+ v1 = _mm256_mul_ps ( v1 , mul );
488+ v2 = _mm256_mul_ps ( v2 , mul );
489+ v3 = _mm256_mul_ps ( v3 , mul );
490+
491+ // Round to nearest integer
492+ v0 = _mm256_round_ps ( v0 , _MM_ROUND_NEAREST );
493+ v1 = _mm256_round_ps ( v1 , _MM_ROUND_NEAREST );
494+ v2 = _mm256_round_ps ( v2 , _MM_ROUND_NEAREST );
495+ v3 = _mm256_round_ps ( v3 , _MM_ROUND_NEAREST );
496+
497+ // Convert floats to integers
498+ __m256i i0 = _mm256_cvtps_epi32 ( v0 );
499+ __m256i i1 = _mm256_cvtps_epi32 ( v1 );
500+ __m256i i2 = _mm256_cvtps_epi32 ( v2 );
501+ __m256i i3 = _mm256_cvtps_epi32 ( v3 );
502+
503+ // Convert int32 to int16
504+ i0 = _mm256_packs_epi32 ( i0 , i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
505+ i2 = _mm256_packs_epi32 ( i2 , i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
506+ // Convert int16 to int8
507+ i0 = _mm256_packs_epi16 ( i0 , i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
508+
509+ // We got our precious signed bytes, but the order is now wrong
510+ // These AVX2 pack instructions process 16-byte pieces independently
511+ // The following instruction is fixing the order
512+ const __m256i perm = _mm256_setr_epi32 ( 0 , 4 , 1 , 5 , 2 , 6 , 3 , 7 );
513+ i0 = _mm256_permutevar8x32_epi32 ( i0 , perm );
514+
515+ // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
516+ const __m256i off = _mm256_set1_epi8 ( 8 );
517+ i0 = _mm256_add_epi8 ( i0 , off );
518+
519+ // Compress the vector into 4 bit/value, and store
520+ __m128i res = packNibbles ( i0 );
521+ _mm_storeu_si128 ( ( __m128i * )pb , res );
522+ pb += bs ;
523+ }
524+ #else
525+ #error "not implemented for QK"
526+ #endif
#elif defined(__wasm_simd128__)
#if QK == 32
    for (int i = 0; i < nb; i++) {
@@ -1285,6 +1395,61 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
#else
#error "not implemented for QK"
#endif
1398+ #elif defined(__AVX2__ )
1399+ #if QK == 32
1400+ const size_t countBlocks = nb ;
1401+
1402+ // Initialize accumulator with zeros
1403+ __m256 acc = _mm256_setzero_ps ();
1404+
1405+ // Main loop
1406+ for (int i = 0 ; i < nb ; ++ i ) {
1407+ const float * d0_0 = (const float * ) (pd0 + i * bs );
1408+ const float * d1_0 = (const float * ) (pd1 + i * bs );
1409+
1410+ const uint8_t * restrict p0 = pb0 + i * bs ;
1411+ const uint8_t * restrict p1 = pb1 + i * bs ;
1412+
1413+ // Compute combined scale for the block
1414+ const __m256 scale = _mm256_mul_ps ( _mm256_broadcast_ss ( d0_0 ), _mm256_broadcast_ss ( d1_0 ) );
1415+
1416+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1417+ __m256i bx = bytesFromNibbles ( p0 );
1418+ __m256i by = bytesFromNibbles ( p1 );
1419+
1420+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1421+ const __m256i off = _mm256_set1_epi8 ( 8 );
1422+ bx = _mm256_sub_epi8 ( bx , off );
1423+ by = _mm256_sub_epi8 ( by , off );
1424+
1425+ // Sign-extend first 16 signed bytes into int16_t
1426+ __m256i x16 = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( bx ) );
1427+ __m256i y16 = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( by ) );
1428+ // Compute products of int16_t integers, add pairwise
1429+ __m256i i32 = _mm256_madd_epi16 ( x16 , y16 );
1430+
1431+ // Sign-extend last 16 signed bytes into int16_t vectors
1432+ x16 = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( bx , 1 ) );
1433+ y16 = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( by , 1 ) );
1434+ // Accumulate products of int16_t integers
1435+ i32 = _mm256_add_epi32 ( i32 , _mm256_madd_epi16 ( x16 , y16 ) );
1436+
1437+ // Convert int32_t to float
1438+ __m256 p = _mm256_cvtepi32_ps ( i32 );
1439+ // Apply the scale, and accumulate
1440+ acc = _mm256_fmadd_ps ( scale , p , acc );
1441+ }
1442+
1443+ // Return horizontal sum of the acc vector
1444+ __m128 res = _mm256_extractf128_ps ( acc , 1 );
1445+ res = _mm_add_ps ( res , _mm256_castps256_ps128 ( acc ) );
1446+ res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
1447+ res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
1448+
1449+ sumf = _mm_cvtss_f32 ( res );
1450+ #else
1451+ #error "not implemented for QK"
1452+ #endif
#elif defined(__wasm_simd128__)
#if QK == 32
    // wasm simd
0 commit comments