@@ -1122,13 +1122,36 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
1122
1122
#define GGML_F16_EPR 8
1123
1123
1124
1124
// F16 arithmetic is not supported by AVX, so we use F32 instead
1125
- // we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32
1126
1125
1127
1126
#define GGML_F32Cx8 __m256
1128
1127
#define GGML_F32Cx8_ZERO _mm256_setzero_ps()
1129
1128
#define GGML_F32Cx8_SET1 (x ) _mm256_set1_ps(x)
1129
+
1130
+ #if defined(__F16C__ )
1131
+ // the _mm256_cvt intrinsics require F16C
1130
1132
#define GGML_F32Cx8_LOAD (x ) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
1131
1133
#define GGML_F32Cx8_STORE (x , y ) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
1134
+ #else
1135
+ static inline __m256 __avx_f32cx8_load (ggml_fp16_t * x ) {
1136
+ float tmp [8 ];
1137
+
1138
+ for (int i = 0 ; i < 8 ; i ++ )
1139
+ tmp [i ] = GGML_FP16_TO_FP32 (x [i ]);
1140
+
1141
+ return _mm256_loadu_ps (tmp );
1142
+ }
1143
+ static inline void __avx_f32cx8_store (ggml_fp16_t * x , __m256 y ) {
1144
+ float arr [8 ];
1145
+
1146
+ _mm256_storeu_ps (arr , y );
1147
+
1148
+ for (int i = 0 ; i < 8 ; i ++ )
1149
+ x [i ] = GGML_FP16_TO_FP32 (arr [i ]);
1150
+ }
1151
+ #define GGML_F32Cx8_LOAD (x ) __avx_f32cx8_load(x)
1152
+ #define GGML_F32Cx8_STORE (x , y ) __avx_f32cx8_store(x, y)
1153
+ #endif
1154
+
1132
1155
#define GGML_F32Cx8_FMA GGML_F32x8_FMA
1133
1156
#define GGML_F32Cx8_ADD _mm256_add_ps
1134
1157
#define GGML_F32Cx8_MUL _mm256_mul_ps
0 commit comments