@@ -607,10 +607,11 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
607
607
assert (k % QK == 0 );
608
608
609
609
const int nb = k / QK ;
610
+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
610
611
611
- float * restrict pm = (float * ) ( y );
612
- float * restrict pd = (float * ) ( pm + nb );
613
- uint8_t * restrict pb = (uint8_t * ) ( pd + nb );
612
+ uint8_t * restrict pd = (( uint8_t * )y + 0 * bs );
613
+ uint8_t * restrict pm = (( uint8_t * )y + 0 * bs + sizeof ( float ) );
614
+ uint8_t * restrict pb = (( uint8_t * )y + 0 * bs + 2 * sizeof ( float ) );
614
615
615
616
uint8_t pp [QK /2 ];
616
617
@@ -627,8 +628,10 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
627
628
const float d = (max - min ) / ((1 << 4 ) - 1 );
628
629
const float id = d ? 1.0f /d : 0.0f ;
629
630
630
- pm [i ] = min ;
631
- pd [i ] = d ;
631
+ * (float * )pm = min ;
632
+ * (float * )pd = d ;
633
+ pm += bs ;
634
+ pd += bs ;
632
635
633
636
for (int l = 0 ; l < QK ; l += 2 ) {
634
637
const float v0 = (x [i * QK + l + 0 ] - min )* id ;
@@ -643,7 +646,8 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
643
646
pp [l /2 ] = vi0 | (vi1 << 4 );
644
647
}
645
648
646
- memcpy (pb + i * QK /2 , pp , sizeof (pp ));
649
+ memcpy (pb , pp , sizeof (pp ));
650
+ pb += bs ;
647
651
}
648
652
}
649
653
@@ -687,16 +691,17 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
687
691
assert (k % QK == 0 );
688
692
689
693
const int nb = k / QK ;
694
+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
690
695
691
- const float * restrict pm = (const float * ) ( x );
692
- const float * restrict pd = (const float * ) ( pm + nb );
693
- const uint8_t * restrict pb = (const uint8_t * ) ( pd + nb );
696
+ const uint8_t * restrict pd = (( const uint8_t * )x + 0 * bs );
697
+ const uint8_t * restrict pm = (( const uint8_t * )x + 0 * bs + sizeof ( float ) );
698
+ const uint8_t * restrict pb = (( const uint8_t * )x + 0 * bs + 2 * sizeof ( float ) );
694
699
695
700
for (int i = 0 ; i < nb ; i ++ ) {
696
- const float m = pm [ i ] ;
697
- const float d = pd [ i ] ;
701
+ const float d = * ( const float * ) ( pd + i * bs ) ;
702
+ const float m = * ( const float * ) ( pm + i * bs ) ;
698
703
699
- const uint8_t * restrict pp = pb + i * QK / 2 ;
704
+ const uint8_t * restrict pp = pb + i * bs ;
700
705
701
706
for (int l = 0 ; l < QK ; l += 2 ) {
702
707
const uint8_t vi = pp [l /2 ];
@@ -1584,14 +1589,16 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
1584
1589
inline static void ggml_vec_dot_q4_1 (const int n , float * restrict s , const void * restrict x , const void * restrict y ) {
1585
1590
const int nb = n / QK ;
1586
1591
1587
- const float * restrict pm0 = (const float * ) x ;
1588
- const float * restrict pm1 = (const float * ) y ;
1592
+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
1593
+
1594
+ const uint8_t * restrict pd0 = ((const uint8_t * )x + 0 * bs );
1595
+ const uint8_t * restrict pd1 = ((const uint8_t * )y + 0 * bs );
1589
1596
1590
- const float * restrict pd0 = (const float * ) ( pm0 + nb );
1591
- const float * restrict pd1 = (const float * ) ( pm1 + nb );
1597
+ const uint8_t * restrict pm0 = (( const uint8_t * )x + 0 * bs + sizeof ( float ) );
1598
+ const uint8_t * restrict pm1 = (( const uint8_t * )y + 0 * bs + sizeof ( float ) );
1592
1599
1593
- const uint8_t * restrict pb0 = (const uint8_t * ) ( pd0 + nb );
1594
- const uint8_t * restrict pb1 = (const uint8_t * ) ( pd1 + nb );
1600
+ const uint8_t * restrict pb0 = (( const uint8_t * )x + 0 * bs + 2 * sizeof ( float ) );
1601
+ const uint8_t * restrict pb1 = (( const uint8_t * )y + 0 * bs + 2 * sizeof ( float ) );
1595
1602
1596
1603
float sumf = 0.0 ;
1597
1604
@@ -1604,14 +1611,14 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
1604
1611
1605
1612
// Main loop
1606
1613
for (int i = 0 ; i < nb ; ++ i ) {
1607
- const float * m0 = (const float * ) (pm0 + i );
1608
- const float * m1 = (const float * ) (pm1 + i );
1614
+ const float * m0 = (const float * ) (pm0 + i * bs );
1615
+ const float * m1 = (const float * ) (pm1 + i * bs );
1609
1616
1610
- const float * d0 = (const float * ) (pd0 + i );
1611
- const float * d1 = (const float * ) (pd1 + i );
1617
+ const float * d0 = (const float * ) (pd0 + i * bs );
1618
+ const float * d1 = (const float * ) (pd1 + i * bs );
1612
1619
1613
- const uint8_t * restrict p0 = pb0 + i * QK / 2 ;
1614
- const uint8_t * restrict p1 = pb1 + i * QK / 2 ;
1620
+ const uint8_t * restrict p0 = pb0 + i * bs ;
1621
+ const uint8_t * restrict p1 = pb1 + i * bs ;
1615
1622
1616
1623
const __m256 d0v = _mm256_broadcast_ss ( d0 );
1617
1624
const __m256 d1v = _mm256_broadcast_ss ( d1 );
@@ -1677,14 +1684,14 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
1677
1684
#else
1678
1685
// scalar
1679
1686
for (int i = 0 ; i < nb ; i ++ ) {
1680
- const float m0 = pm0 [ i ] ;
1681
- const float m1 = pm1 [ i ] ;
1687
+ const float * m0 = ( const float * ) ( pm0 + i * bs ) ;
1688
+ const float * m1 = ( const float * ) ( pm1 + i * bs ) ;
1682
1689
1683
- const float d0 = pd0 [ i ] ;
1684
- const float d1 = pd1 [ i ] ;
1690
+ const float * d0 = ( const float * ) ( pd0 + i * bs ) ;
1691
+ const float * d1 = ( const float * ) ( pd1 + i * bs ) ;
1685
1692
1686
- const uint8_t * restrict p0 = pb0 + i * QK / 2 ;
1687
- const uint8_t * restrict p1 = pb1 + i * QK / 2 ;
1693
+ const uint8_t * restrict p0 = pb0 + i * bs ;
1694
+ const uint8_t * restrict p1 = pb1 + i * bs ;
1688
1695
1689
1696
for (int j = 0 ; j < QK /2 ; j ++ ) {
1690
1697
const uint8_t v0 = p0 [j ];
0 commit comments