@@ -931,6 +931,101 @@ inline static float vaddvq_f32(float32x4_t v) {
    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
#endif

934+ #elif defined(__AVX512F__)
935+ 
936+ #define GGML_SIMD
937+ 
938+ // F32 AVX512
939+ 
940+ #define GGML_F32_STEP 64
941+ #define GGML_F32_EPR  16
942+ 
943+ #define GGML_F32x16         __m512
944+ #define GGML_F32x16_ZERO    _mm512_setzero_ps()
945+ #define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
946+ #define GGML_F32x16_LOAD    _mm512_loadu_ps
947+ #define GGML_F32x16_STORE   _mm512_storeu_ps
948+ // _mm512_fmadd_ps is defined in AVX512F so no guard is required
949+ #define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
950+ #define GGML_F32x16_ADD     _mm512_add_ps
951+ #define GGML_F32x16_MUL     _mm512_mul_ps
952+ #define GGML_F32x16_REDUCE(res, x)                                    \
953+ do {                                                                  \
954+     int offset = GGML_F32_ARR >> 1;                                   \
955+     for (int i = 0; i < offset; ++i) {                                \
956+         x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
957+     }                                                                 \
958+     offset >>= 1;                                                     \
959+     for (int i = 0; i < offset; ++i) {                                \
960+         x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
961+     }                                                                 \
962+     offset >>= 1;                                                     \
963+     for (int i = 0; i < offset; ++i) {                                \
964+         x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
965+     }                                                                 \
966+     res = _mm512_reduce_add_ps(x[0]);                                 \
967+ } while (0)
968+ 
969+ // TODO: is this optimal ?
970+ 
971+ #define GGML_F32_VEC        GGML_F32x16
972+ #define GGML_F32_VEC_ZERO   GGML_F32x16_ZERO
973+ #define GGML_F32_VEC_SET1   GGML_F32x16_SET1
974+ #define GGML_F32_VEC_LOAD   GGML_F32x16_LOAD
975+ #define GGML_F32_VEC_STORE  GGML_F32x16_STORE
976+ #define GGML_F32_VEC_FMA    GGML_F32x16_FMA
977+ #define GGML_F32_VEC_ADD    GGML_F32x16_ADD
978+ #define GGML_F32_VEC_MUL    GGML_F32x16_MUL
979+ #define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
980+ 
// F16 AVX512

#define GGML_F16_STEP 64
#define GGML_F16_EPR  16

// AVX512F has no native f16 arithmetic (that requires the AVX512_FP16
// extension), so the f16 path widens to f32, computes in f32 registers,
// and narrows back to f16 on store.

#define GGML_F32Cx16             __m512
#define GGML_F32Cx16_ZERO        _mm512_setzero_ps()
#define GGML_F32Cx16_SET1(x)     _mm512_set1_ps(x)

// unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
// so F16C guard isn't required
#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
// rounding mode 0 == _MM_FROUND_TO_NEAREST_INT (round to nearest even)
#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))

#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
#define GGML_F32Cx16_ADD         _mm512_add_ps
#define GGML_F32Cx16_MUL         _mm512_mul_ps
// Pairwise tree-reduce the GGML_F16_ARR accumulator registers in x[] into
// x[0], then horizontally sum x[0] into the scalar res.
// (GGML_F16_ARR is used here for consistency with the other GGML_F16
// sections; STEP/EPR is 64/16 for both F16 and F32, so the value is the
// same as GGML_F32_ARR in this configuration.)
#define GGML_F32Cx16_REDUCE(res, x)                               \
do {                                                              \
    int offset = GGML_F16_ARR >> 1;                               \
    for (int i = 0; i < offset; ++i) {                            \
        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
    }                                                             \
    offset >>= 1;                                                 \
    for (int i = 0; i < offset; ++i) {                            \
        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
    }                                                             \
    offset >>= 1;                                                 \
    for (int i = 0; i < offset; ++i) {                            \
        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
    }                                                             \
    res = _mm512_reduce_add_ps(x[0]);                             \
} while (0)

// generic GGML_F16_VEC interface, mapped onto the F32Cx16 implementation
// (the index argument of LOAD is unused; STORE writes register r[i])
#define GGML_F16_VEC                GGML_F32Cx16
#define GGML_F16_VEC_ZERO           GGML_F32Cx16_ZERO
#define GGML_F16_VEC_SET1           GGML_F32Cx16_SET1
#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx16_LOAD(p)
#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
#define GGML_F16_VEC_FMA            GGML_F32Cx16_FMA
#define GGML_F16_VEC_ADD            GGML_F32Cx16_ADD
#define GGML_F16_VEC_MUL            GGML_F32Cx16_MUL
#define GGML_F16_VEC_REDUCE         GGML_F32Cx16_REDUCE

#elif defined(__AVX__)

#define GGML_SIMD