@@ -697,7 +697,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
 // method 4
 // blocks of QK elements
 // represented with 2 floats (min + delta) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
-void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q4_1_reference(const float * restrict x, void * restrict y, int k) {
     assert(k % QK == 0);
 
     const int nb = k / QK;
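For reference, the renamed scalar path implements the mapping described in the comment above: each block of QK values is stored as its minimum, a step size d = (max - min) / 15, and QK unsigned 4-bit factors. A minimal standalone sketch of that per-block mapping (the function name and exact rounding here are illustrative, not the ggml.c code):

#include <float.h>
#include <math.h>
#include <stdint.h>

// Illustrative only: quantize one block of qk floats to q4_1-style data
// (block minimum, step size, and qk 4-bit factors, two per byte).
static void q4_1_block_sketch(const float * x, int qk,
                              float * out_min, float * out_d, uint8_t * out_nibbles) {
    float min =  FLT_MAX;
    float max = -FLT_MAX;
    for (int l = 0; l < qk; l++) {
        if (x[l] < min) min = x[l];
        if (x[l] > max) max = x[l];
    }

    const float d  = (max - min) / ((1 << 4) - 1);  // 15 steps between min and max
    const float id = d ? 1.0f/d : 0.0f;

    *out_min = min;
    *out_d   = d;

    // each value becomes q = round((x - min)/d) in [0, 15]; two values share one byte
    for (int l = 0; l < qk; l += 2) {
        const uint8_t q0 = (uint8_t)roundf((x[l + 0] - min)*id);
        const uint8_t q1 = (uint8_t)roundf((x[l + 1] - min)*id);
        out_nibbles[l/2] = q0 | (q1 << 4);
    }
}

The AVX2 version added below computes the same quantities with vector instructions and writes them into the same interleaved block layout (delta, min, nibbles).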
@@ -745,6 +745,102 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
     }
 }
 
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+    assert(k % QK == 0);
+
+#if defined(__AVX2__)
+    const int nb = k / QK;
+    const size_t bs = 2*sizeof(float) + QK/2;
+
+    uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
+    uint8_t * restrict pm = ((uint8_t *)y + 0*bs + sizeof(float));
+    uint8_t * restrict pb = ((uint8_t *)y + 0*bs + 2*sizeof(float));
+
+    uint8_t pp[QK/2];
+
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max for the block
+        __m256 vmax;
+        vmax = _mm256_max_ps( v0, v1 );
+        vmax = _mm256_max_ps( vmax, v2 );
+        vmax = _mm256_max_ps( vmax, v3 );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Compute min for the block
+        __m256 vmin;
+        vmin = _mm256_min_ps( v0, v1 );
+        vmin = _mm256_min_ps( vmin, v2 );
+        vmin = _mm256_min_ps( vmin, v3 );
+
+        __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) );
+        min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+        min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+        const float minScalar = _mm_cvtss_f32( min4 );
+
+        // Quantize these floats
+        const float d = (maxScalar - minScalar) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        *(float *)pm = minScalar;
+        *(float *)pd = d;
+        pm += bs;
+        pd += bs;
+
+        // x = (x-min)*id
+        const __m256 mul = _mm256_set1_ps( id );
+        const __m256 off = _mm256_set1_ps( minScalar );
+        v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul );
+        v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul );
+        v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul );
+        v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        // Compress the vector into 4 bit/value, and store
+        __m128i res = packNibbles( i0 );
+        _mm_storeu_si128( ( __m128i* )pb, res );
+
+        pb += bs;
+    }
+#else
+    // scalar
+    quantize_row_q4_1_reference(x, y, k);
+#endif
+}
+
 // TODO: vectorize
 void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
     assert(k % QK == 0);
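The only non-obvious step in the AVX2 body is the permute after the two pack instructions: _mm256_packs_epi32 and _mm256_packs_epi16 operate on each 128-bit lane independently, so the 32 bytes come out grouped as 0-3, 8-11, 16-19, 24-27, 4-7, 12-15, 20-23, 28-31, and the (0, 4, 1, 5, 2, 6, 3, 7) index pattern restores sequential order. A small scalar model of that reordering (illustrative only, not part of ggml.c):

#include <stdio.h>

int main(void) {
    // first byte index held by each 32-bit group after the two packs
    const int group[8] = { 0, 8, 16, 24, 4, 12, 20, 28 };
    // index pattern passed to _mm256_permutevar8x32_epi32 in the diff
    const int perm [8] = { 0, 4, 1, 5, 2, 6, 3, 7 };

    // result group g takes its data from source group perm[g]
    for (int g = 0; g < 8; g++) {
        for (int b = 0; b < 4; b++) {
            printf("%d ", group[perm[g]] + b);  // prints 0 1 2 ... 31 in order
        }
    }
    printf("\n");
    return 0;
}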
@@ -10398,7 +10494,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, i
         uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
         uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
 
-        quantize_row_q4_1(src + j, pd, k);
+        quantize_row_q4_1_reference(src + j, pd, k);
 
         for (int i = 0; i < nb; i++) {
             for (int l = 0; l < qk; l += 2) {
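Since ggml_quantize_q4_1 now always goes through the reference path, one way to sanity-check that the AVX2 path above produces equivalent blocks (up to rounding) is to decode both outputs and compare. A decode sketch for the block layout written above, with the delta at offset 0, the minimum at sizeof(float), and the nibbles after both floats; the helper name is made up for illustration:

#include <stdint.h>
#include <string.h>

// Illustrative only: decode one q4_1 block of qk values back to floats.
static void q4_1_block_decode_sketch(const uint8_t * block, float * y, int qk) {
    float d, min;
    memcpy(&d,   block,                 sizeof(float));
    memcpy(&min, block + sizeof(float), sizeof(float));
    const uint8_t * pb = block + 2*sizeof(float);

    for (int l = 0; l < qk; l += 2) {
        const uint8_t b = pb[l/2];
        y[l + 0] = min + d*(float)(b & 0x0F);  // low nibble holds the even element
        y[l + 1] = min + d*(float)(b >> 4);    // high nibble holds the odd element
    }
}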