Skip to content

Commit 6696d2c

Browse files
committed
Add AVX versions of exp(), tanh() and sigmoid()
Now 3x faster than real-time
1 parent 7df3f9c commit 6696d2c

File tree

1 file changed

+146
-6
lines changed

1 file changed

+146
-6
lines changed

src/nnet.c

Lines changed: 146 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,119 @@
3939
#include "nnet.h"
4040
#include "nnet_data.h"
4141

42-
static OPUS_INLINE float tansig_approx(float x)
42+
#ifdef __AVX2__
43+
#include <immintrin.h>
44+
/* Approximate exp() on eight packed single-precision values.
 *
 * The input is multiplied by log2(e) and clamped to [-50, 50].  The clamped
 * value is split into its floor and its fractional remainder; a cubic
 * polynomial in the remainder supplies the mantissa, and the floor is shifted
 * to bit 23 and added to the raw bit pattern of that polynomial so that it
 * lands in the exponent field.  Bit 31 of every lane is cleared before the
 * result is returned. */
static __m256 exp8_approx(__m256 X)
{
   const __m256 c0 = _mm256_set1_ps(0.99992522f);
   const __m256 c1 = _mm256_set1_ps(0.69583354f);
   const __m256 c2 = _mm256_set1_ps(0.22606716f);
   const __m256 c3 = _mm256_set1_ps(0.078024523f);
   const __m256 log2_e = _mm256_set1_ps(1.44269504);
   const __m256 upper = _mm256_set1_ps(50.f);
   const __m256 lower = _mm256_set1_ps(-50.f);
   const __m256i clear_sign = _mm256_set1_epi32(0x7fffffff);
   __m256 scaled, whole, frac, poly;
   __m256i exponent, bits;

   /* Move to base 2 and keep the exponent within a safe range. */
   scaled = _mm256_mul_ps(X, log2_e);
   scaled = _mm256_max_ps(lower, _mm256_min_ps(upper, scaled));

   /* Integer part goes to the exponent, remainder feeds the polynomial. */
   whole = _mm256_floor_ps(scaled);
   frac = _mm256_sub_ps(scaled, whole);
   exponent = _mm256_slli_epi32(_mm256_cvtps_epi32(whole), 23);

   /* Horner evaluation of c0 + c1*f + c2*f^2 + c3*f^3. */
   poly = _mm256_fmadd_ps(c3, frac, c2);
   poly = _mm256_fmadd_ps(poly, frac, c1);
   poly = _mm256_fmadd_ps(poly, frac, c0);

   /* Integer add on the float bit pattern, then drop the sign bit. */
   bits = _mm256_add_epi32(exponent, _mm256_castps_si256(poly));
   bits = _mm256_and_si256(clear_sign, bits);
   return _mm256_castsi256_ps(bits);
}
66+
67+
68+
static float celt_exp(float x)
69+
{
70+
float out[8];
71+
__m256 X, Y;
72+
X = _mm256_set1_ps(x);
73+
Y = exp8_approx(X);
74+
_mm256_storeu_ps(out, Y);
75+
return out[0];
76+
}
77+
78+
static void softmax(float *y, const float *x, int N)
79+
{
80+
int i;
81+
for (i=0;i<N-7;i+=8)
82+
{
83+
__m256 X, Y;
84+
X = _mm256_loadu_ps(&x[i]);
85+
Y = exp8_approx(X);
86+
_mm256_storeu_ps(&y[i], Y);
87+
}
88+
for (;i<N;i++)
89+
y[i] = celt_exp(x[i]);
90+
}
91+
92+
static void vec_tanh(float *y, const float *x, int N)
93+
{
94+
int i;
95+
for (i=0;i<N-7;i+=8)
96+
{
97+
const __m256 two = _mm256_set1_ps(2.f);
98+
const __m256 one = _mm256_set1_ps(1.f);
99+
__m256 X, Y;
100+
X = _mm256_loadu_ps(&x[i]);
101+
X = _mm256_mul_ps(X, two);
102+
Y = exp8_approx(X);
103+
Y = _mm256_mul_ps(_mm256_sub_ps(Y, one), _mm256_rcp_ps(_mm256_add_ps(Y, one)));
104+
_mm256_storeu_ps(&y[i], Y);
105+
}
106+
for (;i<N;i++)
107+
{
108+
float ex2;
109+
ex2 = celt_exp(2*x[i]);
110+
y[i] = (ex2-1)/(ex2+1);
111+
}
112+
}
113+
114+
static void vec_sigmoid(float *y, const float *x, int N)
115+
{
116+
int i;
117+
for (i=0;i<N-7;i+=8)
118+
{
119+
const __m256 one = _mm256_set1_ps(1.f);
120+
__m256 X, Y;
121+
X = _mm256_loadu_ps(&x[i]);
122+
Y = exp8_approx(X);
123+
Y = _mm256_mul_ps(Y, _mm256_rcp_ps(_mm256_add_ps(Y, one)));
124+
_mm256_storeu_ps(&y[i], Y);
125+
}
126+
for (;i<N;i++)
127+
{
128+
float ex;
129+
ex = celt_exp(x[i]);
130+
y[i] = (ex)/(ex+1);
131+
}
132+
}
133+
#else
134+
static float celt_exp2(float x)
135+
{
136+
int integer;
137+
float frac;
138+
union {
139+
float f;
140+
opus_uint32 i;
141+
} res;
142+
integer = floor(x);
143+
if (integer < -50)
144+
return 0;
145+
frac = x-integer;
146+
/* K0 = 1, K1 = log(2), K2 = 3-4*log(2), K3 = 3*log(2) - 2 */
147+
res.f = 0.99992522f + frac * (0.69583354f
148+
+ frac * (0.22606716f + 0.078024523f*frac));
149+
res.i = (res.i + (integer<<23)) & 0x7fffffff;
150+
return res.f;
151+
}
152+
#define celt_exp(x) celt_exp2((x)*1.44269504f)
153+
154+
static float tansig_approx(float x)
43155
{
44156
int i;
45157
float y, dy;
@@ -72,6 +184,36 @@ static OPUS_INLINE float sigmoid_approx(float x)
72184
return .5f + .5f*tansig_approx(.5f*x);
73185
}
74186

187+
/* Scalar fallback: writes celt_exp(x[i]) into y[i] for every i in [0, N).
 * No normalisation is performed in this function. */
static void softmax(float *y, const float *x, int N)
{
   int k = 0;
   while (k < N)
   {
      y[k] = celt_exp(x[k]);
      k++;
   }
}
193+
194+
/* Scalar fallback: applies tansig_approx() to each of the N inputs in x and
 * stores the results in y. */
static void vec_tanh(float *y, const float *x, int N)
{
   int k = 0;
   while (k < N)
   {
      y[k] = tansig_approx(x[k]);
      k++;
   }
}
202+
203+
/* Scalar fallback: applies sigmoid_approx() to each of the N inputs in x and
 * stores the results in y. */
static void vec_sigmoid(float *y, const float *x, int N)
{
   int k = 0;
   while (k < N)
   {
      y[k] = sigmoid_approx(x[k]);
      k++;
   }
}
211+
212+
213+
#endif
214+
215+
216+
75217
static OPUS_INLINE float relu(float x)
76218
{
77219
return x < 0 ? 0 : x;
@@ -191,18 +333,16 @@ void compute_activation(float *output, float *input, int N, int activation)
191333
{
192334
int i;
193335
if (activation == ACTIVATION_SIGMOID) {
194-
for (i=0;i<N;i++)
195-
output[i] = sigmoid_approx(input[i]);
336+
vec_sigmoid(output, input, N);
196337
} else if (activation == ACTIVATION_TANH) {
197-
for (i=0;i<N;i++)
198-
output[i] = tansig_approx(input[i]);
338+
vec_tanh(output, input, N);
199339
} else if (activation == ACTIVATION_RELU) {
200340
for (i=0;i<N;i++)
201341
output[i] = relu(input[i]);
202342
} else if (activation == ACTIVATION_SOFTMAX) {
203343
float sum = 0;
344+
softmax(output, input, N);
204345
for (i=0;i<N;i++) {
205-
output[i] = exp(input[i]);
206346
sum += output[i];
207347
}
208348
sum = 1.f/(sum+1e-30);

0 commit comments

Comments
 (0)