Commit 7dca16b

Add AVX2 implementation of quantize_row_q4_1
1 parent 1972616 commit 7dca16b

1 file changed: +98 −2 lines changed

1 file changed

+98
-2
lines changed

Diff for: ggml.c (+98 −2)
@@ -697,7 +697,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
 // method 4
 // blocks of QK elements
 // represented with 2 floats (min + delta) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
-void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q4_1_reference(const float * restrict x, void * restrict y, int k) {
     assert(k % QK == 0);
 
     const int nb = k / QK;
@@ -745,6 +745,102 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
     }
 }
 
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+    assert(k % QK == 0);
+
+#if defined(__AVX2__)
+    const int nb = k / QK;
+    const size_t bs = 2*sizeof(float) + QK/2;
+
+    uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
+    uint8_t * restrict pm = ((uint8_t *)y + 0*bs + sizeof(float));
+    uint8_t * restrict pb = ((uint8_t *)y + 0*bs + 2*sizeof(float));
+
+    uint8_t pp[QK/2];
+
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max for the block
+        __m256 vmax;
+        vmax = _mm256_max_ps( v0, v1 );
+        vmax = _mm256_max_ps( vmax, v2 );
+        vmax = _mm256_max_ps( vmax, v3 );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Compute min for the block
+        __m256 vmin;
+        vmin = _mm256_min_ps( v0, v1 );
+        vmin = _mm256_min_ps( vmin, v2 );
+        vmin = _mm256_min_ps( vmin, v3 );
+
+        __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) );
+        min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+        min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+        const float minScalar = _mm_cvtss_f32( min4 );
+
+        // Quantize these floats
+        const float d = (maxScalar - minScalar) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        *(float *)pm = minScalar;
+        *(float *)pd = d;
+        pm += bs;
+        pd += bs;
+
+        // x = (x-min)*id
+        const __m256 mul = _mm256_set1_ps( id );
+        const __m256 off = _mm256_set1_ps( minScalar );
+        v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul );
+        v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul );
+        v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul );
+        v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        // Compress the vector into 4 bit/value, and store
+        __m128i res = packNibbles( i0 );
+        _mm_storeu_si128( ( __m128i* )pb, res );
+
+        pb += bs;
+    }
+#else
+    // scalar
+    quantize_row_q4_1_reference(x, y, k);
+#endif
+}
+
 // TODO: vectorize
 void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
     assert(k % QK == 0);
@@ -10398,7 +10494,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, i
         uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
         uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
 
-        quantize_row_q4_1(src + j, pd, k);
+        quantize_row_q4_1_reference(src + j, pd, k);
 
         for (int i = 0; i < nb; i++) {
             for (int l = 0; l < qk; l += 2) {
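
Note (not part of the commit): the AVX2 loop above is a vectorized form of the per-block Q4_1 math. The sketch below restates that math in plain scalar C for one block, assuming QK == 32 and the block layout used in the diff (one float delta d, one float min, then QK/2 bytes holding two 4-bit values each). The nibble order (low nibble = even element, high nibble = odd element) and the use of roundf instead of the intrinsics' round-to-nearest-even are illustrative assumptions, not taken from this diff.

// Hypothetical scalar sketch of quantizing one Q4_1 block (illustration only).
#include <math.h>
#include <stdint.h>

#define QK 32

static void quantize_q4_1_block_sketch(const float * x, float * d_out, float * min_out, uint8_t * nibbles) {
    // Find the per-block min and max, as the vmin/vmax reductions do above.
    float min = x[0];
    float max = x[0];
    for (int l = 1; l < QK; l++) {
        if (x[l] < min) min = x[l];
        if (x[l] > max) max = x[l];
    }

    // 16 quantization levels between min and max.
    const float d  = (max - min) / ((1 << 4) - 1);
    const float id = d ? 1.0f/d : 0.0f;   // guard against a constant block (d == 0)

    *d_out   = d;
    *min_out = min;

    // Map each value to 0..15 and pack two 4-bit values per byte.
    // The nibble order is assumed here; the commit relies on packNibbles (defined
    // elsewhere in ggml.c, not shown in this diff) for this step.
    for (int l = 0; l < QK; l += 2) {
        const uint8_t v0 = (uint8_t) roundf((x[l + 0] - min) * id);
        const uint8_t v1 = (uint8_t) roundf((x[l + 1] - min) * id);
        nibbles[l/2] = v0 | (v1 << 4);
    }
}

The AVX2 path computes the same mapping for all 32 values of a block at once; the _mm256_permutevar8x32_epi32 shuffle is only there to undo the lane-wise interleaving of the _mm256_packs_* instructions, as the comments in the diff note.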

0 commit comments
