@@ -607,10 +607,11 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
607607 assert (k % QK == 0 );
608608
609609 const int nb = k / QK ;
610+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
610611
611- float * restrict pm = (float * ) ( y );
612- float * restrict pd = (float * ) ( pm + nb );
613- uint8_t * restrict pb = (uint8_t * ) ( pd + nb );
612+ uint8_t * restrict pd = (( uint8_t * )y + 0 * bs );
613+ uint8_t * restrict pm = (( uint8_t * )y + 0 * bs + sizeof ( float ) );
614+ uint8_t * restrict pb = (( uint8_t * )y + 0 * bs + 2 * sizeof ( float ) );
614615
615616 uint8_t pp [QK /2 ];
616617
@@ -627,8 +628,10 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
627628 const float d = (max - min ) / ((1 << 4 ) - 1 );
628629 const float id = d ? 1.0f /d : 0.0f ;
629630
630- pm [i ] = min ;
631- pd [i ] = d ;
631+ * (float * )pm = min ;
632+ * (float * )pd = d ;
633+ pm += bs ;
634+ pd += bs ;
632635
633636 for (int l = 0 ; l < QK ; l += 2 ) {
634637 const float v0 = (x [i * QK + l + 0 ] - min )* id ;
@@ -643,7 +646,8 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
643646 pp [l /2 ] = vi0 | (vi1 << 4 );
644647 }
645648
646- memcpy (pb + i * QK /2 , pp , sizeof (pp ));
649+ memcpy (pb , pp , sizeof (pp ));
650+ pb += bs ;
647651 }
648652}
649653
@@ -687,16 +691,17 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
687691 assert (k % QK == 0 );
688692
689693 const int nb = k / QK ;
694+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
690695
691- const float * restrict pm = (const float * ) ( x );
692- const float * restrict pd = (const float * ) ( pm + nb );
693- const uint8_t * restrict pb = (const uint8_t * ) ( pd + nb );
696+ const uint8_t * restrict pd = (( const uint8_t * )x + 0 * bs );
697+ const uint8_t * restrict pm = (( const uint8_t * )x + 0 * bs + sizeof ( float ) );
698+ const uint8_t * restrict pb = (( const uint8_t * )x + 0 * bs + 2 * sizeof ( float ) );
694699
695700 for (int i = 0 ; i < nb ; i ++ ) {
696- const float m = pm [ i ] ;
697- const float d = pd [ i ] ;
701+ const float d = * ( const float * ) ( pd + i * bs ) ;
702+ const float m = * ( const float * ) ( pm + i * bs ) ;
698703
699- const uint8_t * restrict pp = pb + i * QK / 2 ;
704+ const uint8_t * restrict pp = pb + i * bs ;
700705
701706 for (int l = 0 ; l < QK ; l += 2 ) {
702707 const uint8_t vi = pp [l /2 ];
@@ -1584,28 +1589,109 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
15841589inline static void ggml_vec_dot_q4_1 (const int n , float * restrict s , const void * restrict x , const void * restrict y ) {
15851590 const int nb = n / QK ;
15861591
1587- const float * restrict pm0 = (const float * ) x ;
1588- const float * restrict pm1 = (const float * ) y ;
1592+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
15891593
1590- const float * restrict pd0 = (const float * ) (pm0 + nb );
1591- const float * restrict pd1 = (const float * ) (pm1 + nb );
1594+ const uint8_t * restrict pd0 = ((const uint8_t * )x + 0 * bs );
1595+ const uint8_t * restrict pd1 = ((const uint8_t * )y + 0 * bs );
1596+
1597+ const uint8_t * restrict pm0 = ((const uint8_t * )x + 0 * bs + sizeof (float ));
1598+ const uint8_t * restrict pm1 = ((const uint8_t * )y + 0 * bs + sizeof (float ));
15921599
1593- const uint8_t * restrict pb0 = (const uint8_t * ) ( pd0 + nb );
1594- const uint8_t * restrict pb1 = (const uint8_t * ) ( pd1 + nb );
1600+ const uint8_t * restrict pb0 = (( const uint8_t * )x + 0 * bs + 2 * sizeof ( float ) );
1601+ const uint8_t * restrict pb1 = (( const uint8_t * )y + 0 * bs + 2 * sizeof ( float ) );
15951602
15961603 float sumf = 0.0 ;
15971604
1598- #if 1
1605+ #if defined(__AVX2__ )
1606+ #if QK == 32
1607+ // Initialize accumulator with zeros
1608+ __m256 acc = _mm256_setzero_ps ();
1609+ // Accumulator for constant offsets
1610+ float acc_offset = 0.0f ;
1611+
1612+ // Main loop
1613+ for (int i = 0 ; i < nb ; ++ i ) {
1614+ const float * m0 = (const float * ) (pm0 + i * bs );
1615+ const float * m1 = (const float * ) (pm1 + i * bs );
1616+
1617+ const float * d0 = (const float * ) (pd0 + i * bs );
1618+ const float * d1 = (const float * ) (pd1 + i * bs );
1619+
1620+ const uint8_t * restrict p0 = pb0 + i * bs ;
1621+ const uint8_t * restrict p1 = pb1 + i * bs ;
1622+
1623+ const __m256 d0v = _mm256_broadcast_ss ( d0 );
1624+ const __m256 d1v = _mm256_broadcast_ss ( d1 );
1625+ const __m256 m0v = _mm256_broadcast_ss ( m0 );
1626+ const __m256 m1v = _mm256_broadcast_ss ( m1 );
1627+
1628+
1629+ // Compute combined scale for the block
1630+ const __m256 scale_01 = _mm256_mul_ps ( d0v , d1v );
1631+
1632+ // Compute cross scales for the block
1633+ const __m256 scale_0 = _mm256_mul_ps ( d0v , m1v );
1634+ const __m256 scale_1 = _mm256_mul_ps ( m0v , d1v );
1635+ const __m256 cross_scales = _mm256_blend_ps ( scale_0 , scale_1 , 0b10101010 );
1636+
1637+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1638+ __m256i bx = bytesFromNibbles ( p0 );
1639+ __m256i by = bytesFromNibbles ( p1 );
1640+
1641+ // Now we have a vector with bytes in [ 0 .. 15 ] interval.
1642+
1643+ // Sign-extend first 16 signed bytes into int16_t
1644+ __m256i x16 = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( bx ) );
1645+ __m256i y16 = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( by ) );
1646+ // Compute products of int16_t integers, add pairwise
1647+ __m256i i32 = _mm256_madd_epi16 ( x16 , y16 );
1648+
1649+ // Sign-extend last 16 signed bytes into int16_t vectors
1650+ __m256i x16_h = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( bx , 1 ) );
1651+ __m256i y16_h = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( by , 1 ) );
1652+ // Accumulate products of int16_t integers
1653+ i32 = _mm256_add_epi32 ( i32 , _mm256_madd_epi16 ( x16_h , y16_h ) );
1654+
1655+ // compute sums of unsigned bytes in bx, by in blocks of 8.
1656+ // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
1657+ // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
1658+ // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
1659+ __m256i xsumi = _mm256_sad_epu8 ( bx , _mm256_setzero_si256 () );
1660+ __m256i ysumi = _mm256_sad_epu8 ( by , _mm256_setzero_si256 () );
1661+ __m256i sumsi = _mm256_or_si256 ( xsumi , _mm256_slli_si256 ( ysumi , 4 ) );
1662+ __m256 sums = _mm256_cvtepi32_ps ( sumsi );
1663+
1664+ // Convert int32_t to float
1665+ __m256 p = _mm256_cvtepi32_ps ( i32 );
1666+ // Apply the scale, and accumulate
1667+ // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
1668+ acc = _mm256_fmadd_ps ( scale_01 , p , acc );
1669+ acc = _mm256_fmadd_ps ( cross_scales , sums , acc );
1670+ // acc_offset += m0*m1 (for each entry in the block)
1671+ acc_offset += (* m0 )* (* m1 );
1672+ }
1673+
1674+ // Return horizontal sum of the acc vector
1675+ __m128 res = _mm256_extractf128_ps ( acc , 1 );
1676+ res = _mm_add_ps ( res , _mm256_castps256_ps128 ( acc ) );
1677+ res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
1678+ res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
1679+
1680+ sumf = _mm_cvtss_f32 ( res ) + acc_offset * QK ;
1681+ #else
1682+ #error "not implemented for QK"
1683+ #endif
1684+ #else
15991685 // scalar
16001686 for (int i = 0 ; i < nb ; i ++ ) {
1601- const float m0 = pm0 [ i ] ;
1602- const float m1 = pm1 [ i ] ;
1687+ const float m0 = * ( const float * ) ( pm0 + i * bs ) ;
1688+ const float m1 = * ( const float * ) ( pm1 + i * bs ) ;
16031689
1604- const float d0 = pd0 [ i ] ;
1605- const float d1 = pd1 [ i ] ;
1690+ const float d0 = * ( const float * ) ( pd0 + i * bs ) ;
1691+ const float d1 = * ( const float * ) ( pd1 + i * bs ) ;
16061692
1607- const uint8_t * restrict p0 = pb0 + i * QK / 2 ;
1608- const uint8_t * restrict p1 = pb1 + i * QK / 2 ;
1693+ const uint8_t * restrict p0 = pb0 + i * bs ;
1694+ const uint8_t * restrict p1 = pb1 + i * bs ;
16091695
16101696 for (int j = 0 ; j < QK /2 ; j ++ ) {
16111697 const uint8_t v0 = p0 [j ];
@@ -1839,16 +1925,17 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
18391925 assert (n % QK == 0 );
18401926
18411927 const int nb = n / QK ;
1928+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
18421929
1843- const float * restrict pm = (const float * ) ( x );
1844- const float * restrict pd = (const float * ) ( pm + nb );
1845- const uint8_t * restrict pb = (const uint8_t * ) ( pd + nb );
1930+ const uint8_t * restrict pd = (( const uint8_t * )x + 0 * bs );
1931+ const uint8_t * restrict pm = (( const uint8_t * )x + 0 * bs + sizeof ( float ));
1932+ const uint8_t * restrict pb = (( const uint8_t * )x + 0 * bs + 2 * sizeof ( float ) );
18461933
18471934 for (int i = 0 ; i < nb ; i ++ ) {
1848- const float m = pm [ i ] ;
1849- const float d = pd [ i ] ;
1935+ const float d = * ( const float * ) ( pd + i * bs ) ;
1936+ const float m = * ( const float * ) ( pm + i * bs ) ;
18501937
1851- const uint8_t * restrict pp = pb + i * QK / 2 ;
1938+ const uint8_t * restrict pp = pb + i * bs ;
18521939
18531940 for (int l = 0 ; l < QK ; l += 2 ) {
18541941 const uint8_t vi = pp [l /2 ];
0 commit comments