4141
4242#define SOFTMAX_HACK
4343
44- #ifdef __AVX2__
44+ #ifdef __AVX__
4545#include <immintrin.h>
46+
47+
48+ #ifdef __AVX2__
4649static __m256 exp8_approx (__m256 X )
4750{
4851 const __m256 K0 = _mm256_set1_ps (0.99992522f );
@@ -65,7 +68,44 @@ static __m256 exp8_approx(__m256 X)
6568 Y = _mm256_castsi256_ps (_mm256_and_si256 (mask , _mm256_add_epi32 (I , _mm256_castps_si256 (Y ))));
6669 return Y ;
6770}
68-
71+ #else
/* Multiply-add fallbacks for the non-AVX2 path: a * b + c computed as a
   separate multiply followed by an add (two roundings, not a fused op).
   There must be no whitespace between the macro name and '(' -- otherwise
   the preprocessor treats the definition as an object-like macro and every
   call site expands incorrectly. */
#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(_mm256_mul_ps((a), (b)), (c))
#define _mm_fmadd_ps(a, b, c) _mm_add_ps(_mm_mul_ps((a), (b)), (c))
74+ static __m128 exp4_approx (__m128 X )
75+ {
76+ const __m128 K0 = _mm_set1_ps (0.99992522f );
77+ const __m128 K1 = _mm_set1_ps (0.69583354f );
78+ const __m128 K2 = _mm_set1_ps (0.22606716f );
79+ const __m128 K3 = _mm_set1_ps (0.078024523f );
80+ const __m128 log2_E = _mm_set1_ps (1.44269504 );
81+ const __m128 max_in = _mm_set1_ps (50.f );
82+ const __m128 min_in = _mm_set1_ps (-50.f );
83+ const __m128i mask = _mm_set1_epi32 (0x7fffffff );
84+ __m128 XF , Y ;
85+ __m128i I ;
86+ X = _mm_mul_ps (X , log2_E );
87+ X = _mm_max_ps (min_in , _mm_min_ps (max_in , X ));
88+ XF = _mm_floor_ps (X );
89+ I = _mm_cvtps_epi32 (XF );
90+ X = _mm_sub_ps (X , XF );
91+ Y = _mm_fmadd_ps (_mm_fmadd_ps (_mm_fmadd_ps (K3 , X , K2 ), X , K1 ), X , K0 );
92+ I = _mm_slli_epi32 (I , 23 );
93+ Y = _mm_castsi128_ps (_mm_and_si128 (mask , _mm_add_epi32 (I , _mm_castps_si128 (Y ))));
94+ return Y ;
95+ }
96+ static __m256 exp8_approx (__m256 X )
97+ {
98+ __m256 Y ;
99+ __m128 Xhi , Xlo , Yhi , Ylo ;
100+ Xhi = _mm256_extractf128_ps (X , 1 );
101+ Xlo = _mm256_extractf128_ps (X , 0 );
102+ Yhi = exp4_approx (Xhi );
103+ Ylo = exp4_approx (Xlo );
104+ Y = _mm256_insertf128_ps (_mm256_setzero_ps (), Yhi , 1 );
105+ Y = _mm256_insertf128_ps (Y , Ylo , 0 );
106+ return Y ;
107+ }
108+ #endif
69109
70110static float celt_exp (float x )
71111{
0 commit comments