
Commit 1bfc153

ikawrakow and Kawrakow authored
ggml : a faster version for Q4_1 x Q8_0 dot products (#1083)
* A faster version for Q4_1 x Q8_0 dot products

The idea behind it is that the Q8_0 quantized values get used many times in the matrix multiplications where they are involved. In the current implementation, when we evaluate the dot products, we need to compute the sum of the quants in the Q8_0 vector, so the same operation is repeated many times. Here we pre-compute the sum during Q8_0 quantization, store it in the now modified block_q8_0 struct, and then reuse this result in the subsequent dot products.

In a synthetic benchmark (just computing a bunch of dot products), this change speeds up the Q4_1 * Q8_0 dot product by 80%, making its performance identical to Q4_0 * Q8_0. In practical use, I see a ~15% gain in token-prediction speed on M2, and a ~5% gain on a Ryzen 7950X. The speed gain in prompt evaluation is much bigger (around 50%).

I have only made the change for the scalar version, ARM_NEON, and AVX2, so we still need an AVX implementation.

* Cleaning up

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
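For readers skimming the diff below, here is a minimal scalar sketch of the idea (illustrative only; the struct layouts and the names blk_q4_1, blk_q8_0, dot_block are simplified, not the exact ggml definitions): Q8_0 quantization stores s = d * sum(qs) once per block, and a Q4_1 x Q8_0 block dot product then folds the Q4_1 offset m into the result via m * s instead of re-summing the Q8_0 quants on every call.

    #include <stdint.h>

    #define QK 32

    // Simplified block layouts (illustrative, not the exact ggml structs).
    typedef struct { float d, m; uint8_t qs[QK/2]; } blk_q4_1;  // value = d*q + m, q packed two per byte
    typedef struct { float d, s; int8_t  qs[QK];   } blk_q8_0;  // s = d * sum(qs), precomputed once

    // One-block Q4_1 x Q8_0 dot product that reuses the precomputed sum.
    static float dot_block(const blk_q4_1 *x, const blk_q8_0 *y) {
        int sumi = 0;
        for (int j = 0; j < QK/2; ++j) {
            const int q0 = x->qs[j] & 0x0F;  // low nibble
            const int q1 = x->qs[j] >> 4;    // high nibble
            sumi += q0 * y->qs[2*j + 0] + q1 * y->qs[2*j + 1];
        }
        // sum((d*q + m) * d_y*q_y) = d*d_y*sum(q*q_y) + m*(d_y*sum(q_y)) = d*d_y*sumi + m*y->s
        return x->d * y->d * sumi + x->m * y->s;
    }

Without the precomputed s, every dot product would need an extra pass over y->qs just to form sum(q_y); with it, the offset contribution costs a single multiply-add per block.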
1 parent 3d59769 commit 1bfc153

File tree: 3 files changed (+251, -46 lines)


Diff for: ggml.c

+74 -46
@@ -657,9 +657,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
+    float   s;          // d * sum(qs[i])
     int8_t  qs[QK8_0];  // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");


 // reference implementation for deterministic creation of model files
@@ -1299,12 +1300,38 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r

         y[i].d = d;

+        int sum = 0;
         for (int l = 0; l < QK8_0; ++l) {
             const float v = x[i*QK8_0 + l]*id;
             y[i].qs[l] = roundf(v);
-        }
-    }
+            sum += y[i].qs[l];
+        }
+        y[i].s = d * sum;
+    }
+}
+
+#ifdef __AVX2__
+// There is no better way of doing this?
+// I guess not, AVX is not very good at horizontal sums.
+// The commented solution for a horizontal sum was suggested by @pubby as being slightly
+// faster than the solution below. As I don't have an AVX2 system handy right now to test,
+// keeping the original.
+// TODO: Please try and, if it does make a difference, uncomment and remove the implementation below.
+//static inline float horizontal_sum(__m256i a) {
+//    __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
+//    __m256i sum = _mm256_add_epi32(a, b);
+//    __m256i hi = _mm256_unpackhi_epi64(sum, sum);
+//    sum = _mm256_add_epi32(sum, hi);
+//    return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
+//}
+static inline float horizontal_sum(__m256i a) {
+    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
+    __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
 }
+#endif

 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK8_0 == 0);
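For reference, the new horizontal_sum reduces the eight 32-bit lanes of a __m256i to a single value: it adds the high and low 128-bit halves, then the two 64-bit halves, then the two remaining 32-bit lanes. A scalar model of what it computes (a sketch for illustration, not code from the commit):

    #include <stdint.h>

    // Add the eight int32 lanes of a 256-bit vector, modelled here as an array.
    static inline int32_t horizontal_sum_scalar(const int32_t lanes[8]) {
        int32_t sum = 0;
        for (int k = 0; k < 8; ++k) {
            sum += lanes[k];
        }
        return sum;
    }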
@@ -1332,6 +1359,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int

         y[i].d = d;

+        int32x4_t accv = vdupq_n_s32(0);
+
         for (int l = 0; l < 8; l++) {
             const float32x4_t v = vmulq_n_f32(srcv[l], id);
             const int32x4_t vi = vcvtnq_s32_f32(v);
@@ -1340,7 +1369,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
         }
+        int32_t sum = vaddvq_s32(accv);
+        y[i].s = d * sum;
     }
 #elif defined(__AVX2__) || defined(__AVX__)
     for (int i = 0; i < nb; i++) {
@@ -1388,6 +1421,10 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m256i i3 = _mm256_cvtps_epi32( v3 );

 #if defined(__AVX2__)
+
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
         i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
@@ -1430,6 +1467,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
+#if defined __AVX__
+    // TODO: vectorize this
+    for (int i=0; i<nb; ++i) {
+        int sum = 0;
+        for (int l=0; l<QK8_0; ++l) sum += y[i].qs[l];
+        y[i].s = y[i].d * sum;
+    }
+#endif
 }

 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
@@ -2372,14 +2417,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);

+    float sum8 = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];

+        sum8 += x0->d * y0->s + x1->d * y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);
-        const int8x16_t  s8b = vdupq_n_s8(0x8);

         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2390,12 +2438,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

-        // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2410,21 +2452,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *

 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);

         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));

-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));

         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2436,7 +2478,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 #endif
     }

-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
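Why the final line can subtract 8 * sum8 (a short derivation added here for clarity; it is not part of the diff): a Q4_0 block stores unsigned nibbles q in [0, 15] that dequantize as d*(q - 8), while y->s = d_y * sum(q_y) is the value precomputed during Q8_0 quantization. Per block,

    sum_l d*(q_l - 8) * d_y*q_y_l = d*d_y * sum_l q_l*q_y_l - 8*d * (d_y * sum_l q_y_l)
                                  = d*d_y * sum_l q_l*q_y_l - 8*d * y->s

sum8 accumulates x->d * y->s across blocks, so the constant term is applied once at the end and the per-block "sub 8" vector instructions removed above are no longer needed.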
@@ -2569,12 +2611,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);

+    float summs = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];

+        summs += x0->m * y0->s + x1->m * y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);

         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2598,17 +2644,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);

-        const int16x8_t s0i = vaddq_s16(
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
-
-        const int16x8_t s1i = vaddq_s16(
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
@@ -2637,24 +2672,26 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 #endif
     }

-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();

+    float summs = 0;
+
     // Main loop
     for (int i = 0; i < nb; ++i) {
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;
-        const float * m0 = &x[i].m;
+        //const float * m0 = &x[i].m;
+
+        summs += x[i].m * y[i].s;

         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );

         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );

         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
@@ -2676,15 +2713,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *

         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
-
-        // Compute sum of y values
-        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
-
-        // Accumulate d1*m0*y
-        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }

     // Return horizontal sum of the acc vector
@@ -2693,7 +2721,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );

-    sumf = _mm_cvtss_f32( res );
+    sumf = _mm_cvtss_f32( res ) + summs;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
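The same bookkeeping explains the Q4_1 changes above (again a derivation added for clarity, not part of the diff): a Q4_1 value dequantizes as d*q + m, so per block

    sum_l (d*q_l + m) * d_y*q_y_l = d*d_y * sum_l q_l*q_y_l + m * (d_y * sum_l q_y_l)
                                  = d*d_y * sum_l q_l*q_y_l + m * y->s

summs accumulates x->m * y->s across blocks, which replaces the per-block reductions over the Q8_0 quants that were removed above (the s0i/s1i sums on NEON and the y16_l/y16_h/ysum path on AVX2).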

Diff for: pocs/vdot/CMakeLists.txt

+5 -0
@@ -2,3 +2,8 @@ set(TARGET vdot)
 add_executable(${TARGET} vdot.cpp)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET q8dot)
+add_executable(${TARGET} q8dot.cpp)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
