Skip to content

Commit 98e3b51

Browse files
committed
moving code around
1 parent 08dfb6c commit 98e3b51

File tree

1 file changed

+75
-76
lines changed

1 file changed

+75
-76
lines changed

src/nnet.c

Lines changed: 75 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,73 @@ static void vec_sigmoid(float *y, const float *x, int N)
132132
y[i] = (ex)/(ex+1);
133133
}
134134
}
135-
#else
135+
136+
/* AVX2/FMA dense matrix-vector multiply-accumulate: out += weights * x.
 *
 * out        - accumulator vector; read, updated and written back in place.
 * weights    - matrix stored so that the element for output row r and input
 *              column c lives at weights[c*col_stride + r].
 * rows       - number of output rows; walked in blocks of 16.
 * cols       - number of input columns (entries of x that are used).
 * col_stride - distance, in floats, between consecutive columns of weights.
 * x          - input vector, indexed 0..cols-1.
 *
 * NOTE(review): each outer iteration reads and writes a full block of 16
 * floats of out and weights, so rows is presumably expected to be a multiple
 * of 16 - confirm against callers.
 */
static void gemm_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
137+
{
138+
int i, j;
139+
for (i=0;i<rows;i+=16) /* one block of 16 output rows per iteration */
140+
{
141+
float * restrict y;
142+
__m256 vy0, vy8;
143+
y = &out[i];
144+
vy0 = _mm256_loadu_ps(&y[0]); /* running sums for rows i .. i+7 */
145+
vy8 = _mm256_loadu_ps(&y[8]); /* running sums for rows i+8 .. i+15 */
146+
for (j=0;j<cols;j++)
147+
{
148+
__m256 vxj;
149+
__m256 vw;
150+
vxj = _mm256_broadcast_ss(&x[j]); /* x[j] replicated into all 8 lanes */
151+
152+
vw = _mm256_loadu_ps(&weights[j*col_stride + i]); /* column j, rows i .. i+7 */
153+
vy0 = _mm256_fmadd_ps(vw, vxj, vy0); /* vy0 += vw * x[j] */
154+
155+
vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]); /* column j, rows i+8 .. i+15 */
156+
vy8 = _mm256_fmadd_ps(vw, vxj, vy8); /* vy8 += vw * x[j] */
157+
}
158+
_mm256_storeu_ps (&y[0], vy0); /* write the updated block back to out */
159+
_mm256_storeu_ps (&y[8], vy8);
160+
}
161+
}
162+
/* AVX2/FMA sparse matrix-vector multiply-accumulate in blocks of 16 rows.
 *
 * out     - accumulator vector; read, updated and written back in place.
 * weights - packed non-zero weights, consumed sequentially: 16 floats for
 *           every (row block, listed column) pair, in the order idx lists them.
 * rows    - number of output rows; walked in blocks of 16.
 * idx     - index stream, consumed sequentially. For each block of 16 rows it
 *           holds one count followed by that many column indices into x.
 * x       - input vector, addressed through the indices read from idx.
 *
 * NOTE(review): each outer iteration reads and writes a full block of 16
 * floats of out, so rows is presumably expected to be a multiple of 16 -
 * confirm against callers.
 */
static void sparse_gemm_accum16(float *out, const float *weights, int rows, const int *idx, const float *x)
163+
{
164+
int i, j;
165+
for (i=0;i<rows;i+=16) /* one block of 16 output rows per iteration */
166+
{
167+
float * restrict y;
168+
int cols;
169+
__m256 vy0, vy8;
170+
y = &out[i];
171+
vy0 = _mm256_loadu_ps(&y[0]); /* running sums for rows i .. i+7 */
172+
vy8 = _mm256_loadu_ps(&y[8]); /* running sums for rows i+8 .. i+15 */
173+
cols = *idx++; /* number of column entries stored for this row block */
174+
for (j=0;j<cols;j++)
175+
{
176+
int id;
177+
__m256 vxj;
178+
__m256 vw;
179+
id = *idx++; /* index into x for this entry */
180+
vxj = _mm256_broadcast_ss(&x[id]); /* x[id] replicated into all 8 lanes */
181+
182+
vw = _mm256_loadu_ps(&weights[0]); /* first 8 weights of this entry */
183+
vy0 = _mm256_fmadd_ps(vw, vxj, vy0); /* vy0 += vw * x[id] */
184+
185+
vw = _mm256_loadu_ps(&weights[8]); /* last 8 weights of this entry */
186+
vy8 = _mm256_fmadd_ps(vw, vxj, vy8); /* vy8 += vw * x[id] */
187+
weights += 16; /* advance to the next packed 16-float entry */
188+
}
189+
_mm256_storeu_ps (&y[0], vy0); /* write the updated block back to out */
190+
_mm256_storeu_ps (&y[8], vy8);
191+
}
192+
}
193+
194+
195+
#else /* No AVX2/FMA support */
196+
197+
198+
#warning Compiling without any vectorization. This code will be very slow
199+
#warning Try adding -mavx2 -mfma
200+
201+
136202
static float celt_exp2(float x)
137203
{
138204
int integer;
@@ -211,81 +277,6 @@ static void vec_sigmoid(float *y, const float *x, int N)
211277
}
212278
}
213279

214-
215-
#endif
216-
217-
218-
219-
/* Rectified linear unit: returns 0 when x is negative, otherwise x unchanged. */
static OPUS_INLINE float relu(float x)
220-
{
221-
return x < 0 ? 0 : x;
222-
}
223-
224-
#ifdef __AVX2__
225-
#include <immintrin.h>
226-
/* AVX2/FMA dense matrix-vector multiply-accumulate: out += weights * x.
 *
 * out        - accumulator vector; read, updated and written back in place.
 * weights    - matrix stored so that the element for output row r and input
 *              column c lives at weights[c*col_stride + r].
 * rows       - number of output rows; walked in blocks of 16.
 * cols       - number of input columns (entries of x that are used).
 * col_stride - distance, in floats, between consecutive columns of weights.
 * x          - input vector, indexed 0..cols-1.
 *
 * NOTE(review): each outer iteration reads and writes a full block of 16
 * floats of out and weights, so rows is presumably expected to be a multiple
 * of 16 - confirm against callers.
 */
static void gemm_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
227-
{
228-
int i, j;
229-
for (i=0;i<rows;i+=16) /* one block of 16 output rows per iteration */
230-
{
231-
float * restrict y;
232-
__m256 vy0, vy8;
233-
y = &out[i];
234-
vy0 = _mm256_loadu_ps(&y[0]); /* running sums for rows i .. i+7 */
235-
vy8 = _mm256_loadu_ps(&y[8]); /* running sums for rows i+8 .. i+15 */
236-
for (j=0;j<cols;j++)
237-
{
238-
__m256 vxj;
239-
__m256 vw;
240-
vxj = _mm256_broadcast_ss(&x[j]); /* x[j] replicated into all 8 lanes */
241-
242-
vw = _mm256_loadu_ps(&weights[j*col_stride + i]); /* column j, rows i .. i+7 */
243-
vy0 = _mm256_fmadd_ps(vw, vxj, vy0); /* vy0 += vw * x[j] */
244-
245-
vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]); /* column j, rows i+8 .. i+15 */
246-
vy8 = _mm256_fmadd_ps(vw, vxj, vy8); /* vy8 += vw * x[j] */
247-
}
248-
_mm256_storeu_ps (&y[0], vy0); /* write the updated block back to out */
249-
_mm256_storeu_ps (&y[8], vy8);
250-
}
251-
}
252-
/* AVX2/FMA sparse matrix-vector multiply-accumulate in blocks of 16 rows.
 *
 * out     - accumulator vector; read, updated and written back in place.
 * weights - packed non-zero weights, consumed sequentially: 16 floats for
 *           every (row block, listed column) pair, in the order idx lists them.
 * rows    - number of output rows; walked in blocks of 16.
 * idx     - index stream, consumed sequentially. For each block of 16 rows it
 *           holds one count followed by that many column indices into x.
 * x       - input vector, addressed through the indices read from idx.
 *
 * NOTE(review): each outer iteration reads and writes a full block of 16
 * floats of out, so rows is presumably expected to be a multiple of 16 -
 * confirm against callers.
 */
static void sparse_gemm_accum16(float *out, const float *weights, int rows, const int *idx, const float *x)
253-
{
254-
int i, j;
255-
for (i=0;i<rows;i+=16) /* one block of 16 output rows per iteration */
256-
{
257-
float * restrict y;
258-
int cols;
259-
__m256 vy0, vy8;
260-
y = &out[i];
261-
vy0 = _mm256_loadu_ps(&y[0]); /* running sums for rows i .. i+7 */
262-
vy8 = _mm256_loadu_ps(&y[8]); /* running sums for rows i+8 .. i+15 */
263-
cols = *idx++; /* number of column entries stored for this row block */
264-
for (j=0;j<cols;j++)
265-
{
266-
int id;
267-
__m256 vxj;
268-
__m256 vw;
269-
id = *idx++; /* index into x for this entry */
270-
vxj = _mm256_broadcast_ss(&x[id]); /* x[id] replicated into all 8 lanes */
271-
272-
vw = _mm256_loadu_ps(&weights[0]); /* first 8 weights of this entry */
273-
vy0 = _mm256_fmadd_ps(vw, vxj, vy0); /* vy0 += vw * x[id] */
274-
275-
vw = _mm256_loadu_ps(&weights[8]); /* last 8 weights of this entry */
276-
vy8 = _mm256_fmadd_ps(vw, vxj, vy8); /* vy8 += vw * x[id] */
277-
weights += 16; /* advance to the next packed 16-float entry */
278-
}
279-
_mm256_storeu_ps (&y[0], vy0); /* write the updated block back to out */
280-
_mm256_storeu_ps (&y[8], vy8);
281-
}
282-
}
283-
284-
#else
285-
286-
#warning Compiling without any vectorization. This code will be very slow
287-
#warning Try adding -mavx2 -mfma
288-
289280
static void gemm_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
290281
{
291282
int i, j;
@@ -354,6 +345,14 @@ static void sparse_gemm_accum16(float *out, const float *w, int rows, const int
354345
}
355346
#endif
356347

348+
349+
350+
/* Rectified linear unit: returns 0 when x is negative, otherwise x unchanged. */
static OPUS_INLINE float relu(float x)
351+
{
352+
return x < 0 ? 0 : x;
353+
}
354+
355+
357356
static void gemm_accum(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
358357
{
359358
int i, j;

0 commit comments

Comments
 (0)