@@ -657,9 +657,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
+    float   s;          // d * sum(qs[i])
     int8_t  qs[QK8_0];  // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");


 // reference implementation for deterministic creation of model files
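The new s field caches d * sum(qs[i]) for every Q8_0 block at quantization time. The Q4_0 and Q4_1 dot-product kernels further down use it to fold their per-block constant term (the implicit -8 bias of Q4_0 nibbles, or the m minimum of Q4_1) into a single scalar multiply against s, instead of adjusting every quant inside the SIMD loop. Below is a minimal scalar sketch of the Q4_0 identity being exploited; the helper name is invented for illustration, and the interleaved low/high-nibble layout with a +8 bias is assumed from the quantizer of this era.

#include <stdint.h>

// Illustrative only: scalar form of one Q4_0 x Q8_0 block dot product once
// s = dy * sum(q8) is available (QK8_0 == 32 assumed).
//
//   sum_l (q4[l] - 8) * q8[l] * dx * dy
// = dx*dy * sum_l q4[l]*q8[l]  -  8*dx * (dy * sum_l q8[l])
// = dx*dy * dot(q4, q8)        -  8*dx * s
static float dot_q4_0_q8_0_block(const uint8_t q4[16], float dx,
                                 const int8_t  q8[32], float dy, float s) {
    int dot = 0;
    for (int l = 0; l < 16; ++l) {
        dot += (q4[l] & 0x0F) * q8[2*l + 0];  // low nibble = element 2*l, still +8 biased
        dot += (q4[l] >>   4) * q8[2*l + 1];  // high nibble = element 2*l + 1, still +8 biased
    }
    return dx*dy*(float)dot - 8.0f*dx*s;      // bias removed with one multiply per block
}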
@@ -1299,12 +1300,38 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r

         y[i].d = d;

+        int sum = 0;
         for (int l = 0; l < QK8_0; ++l) {
             const float v = x[i*QK8_0 + l]*id;
             y[i].qs[l] = roundf(v);
-        }
-    }
+            sum += y[i].qs[l];
+        }
+        y[i].s = d * sum;
+    }
+}
+
+#ifdef __AVX2__
+// There is no better way of doing this?
+// I guess not, AVX is not very good at horizontal sums.
+// The commented-out horizontal sum below was suggested by @pubby as being slightly
+// faster than the solution used here. As I don't have an AVX2 system handy right now
+// to test, I'm keeping the original.
+// TODO: Please try it and, if it does make a difference, uncomment it and remove the implementation below.
+//static inline float horizontal_sum(__m256i a) {
+//    __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
+//    __m256i sum = _mm256_add_epi32(a, b);
+//    __m256i hi = _mm256_unpackhi_epi64(sum, sum);
+//    sum = _mm256_add_epi32(sum, hi);
+//    return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
+//}
+static inline float horizontal_sum(__m256i a) {
+    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
+    __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
 }
+#endif

 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK8_0 == 0);
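For reference, the horizontal_sum() helper added above reduces the eight int32 lanes of a __m256i to one scalar: it folds the upper 128-bit half onto the lower half, then the upper 64 bits, then the last two lanes, and returns the total as a float. A plain-C sketch of the same reduction, with an array-based signature invented purely for illustration:

#include <stdint.h>

// What horizontal_sum() computes, written out on the 8 lanes as a plain array:
// the sum of all eight 32-bit lanes, converted to float on return.
static inline float horizontal_sum_scalar(const int32_t lanes[8]) {
    int32_t total = 0;
    for (int i = 0; i < 8; ++i) {
        total += lanes[i];
    }
    return (float) total;
}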
@@ -1332,6 +1359,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int

         y[i].d = d;

+        int32x4_t accv = vdupq_n_s32(0);
+
         for (int l = 0; l < 8; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);
@@ -1340,7 +1369,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
         }
+        int32_t sum = vaddvq_s32(accv);
+        y[i].s = d * sum;
     }
 #elif defined(__AVX2__) || defined(__AVX__)
     for (int i = 0; i < nb; i++) {
@@ -1388,6 +1421,10 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m256i i3 = _mm256_cvtps_epi32( v3 );

 #if defined(__AVX2__)
+
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
         i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
@@ -1430,6 +1467,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
+#if defined __AVX__
+    // TODO: vectorize this
+    for (int i = 0; i < nb; ++i) {
+        int sum = 0;
+        for (int l = 0; l < QK8_0; ++l) sum += y[i].qs[l];
+        y[i].s = y[i].d * sum;
+    }
+#endif
 }

 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
@@ -2372,14 +2417,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);

+    float sum8 = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];

+        sum8 += x0->d*y0->s + x1->d*y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);
-        const int8x16_t  s8b = vdupq_n_s8(0x8);

         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2390,12 +2438,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

-        // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2410,21 +2452,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *

 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);

         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));

-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));

         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2436,7 +2478,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 #endif
     }

-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8*sum8;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2569,12 +2611,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);

+    float summs = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];

+        summs += x0->m*y0->s + x1->m*y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);

         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
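A Q4_1 element dequantizes to dx*q + mx, so its dot product against a Q8_0 block splits into a quant-by-quant part plus an offset part mx * (dy * sum(q8)), and that second factor is exactly the precomputed y->s. This is what the summs accumulator above exploits, and why the per-block sums over the y quants are dropped in the hunks that follow. A scalar sketch under the same assumptions as before (hypothetical helper name, interleaved nibble layout):

#include <stdint.h>

// Illustrative only: scalar form of one Q4_1 x Q8_0 block dot product.
// Q4_1 dequantizes element l to dx*q4[l] + mx, so
//   sum_l (dx*q4[l] + mx) * dy*q8[l]
// = dx*dy * dot(q4, q8)  +  mx * (dy * sum_l q8[l])
// = dx*dy * dot(q4, q8)  +  mx * s        with s = y->s precomputed at quantization
static float dot_q4_1_q8_0_block(const uint8_t q4[16], float dx, float mx,
                                 const int8_t  q8[32], float dy, float s) {
    int dot = 0;
    for (int l = 0; l < 16; ++l) {
        dot += (q4[l] & 0x0F) * q8[2*l + 0];  // low nibble = element 2*l
        dot += (q4[l] >>   4) * q8[2*l + 1];  // high nibble = element 2*l + 1
    }
    return dx*dy*(float)dot + mx*s;
}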
@@ -2598,17 +2644,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);

-        const int16x8_t s0i = vaddq_s16(
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
-
-        const int16x8_t s1i = vaddq_s16(
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
@@ -2637,24 +2672,26 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 #endif
     }

-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();

+    float summs = 0;
+
     // Main loop
     for (int i = 0; i < nb; ++i) {
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;
-        const float * m0 = &x[i].m;
+        //const float * m0 = &x[i].m;
+
+        summs += x[i].m * y[i].s;

         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );

         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );

         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
@@ -2676,15 +2713,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *

         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
-
-        // Compute sum of y values
-        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        const __m256  ysum  = _mm256_cvtepi32_ps( ysumi );
-
-        // Accumulate d1*m0*y
-        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }

     // Return horizontal sum of the acc vector
@@ -2693,7 +2721,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );

-    sumf = _mm_cvtss_f32( res );
+    sumf = _mm_cvtss_f32( res ) + summs;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {