@@ -132,7 +132,73 @@ static void vec_sigmoid(float *y, const float *x, int N)
132132 y [i ] = (ex )/(ex + 1 );
133133 }
134134}
135- #else
135+
136+ static void gemm_accum16 (float * out , const float * weights , int rows , int cols , int col_stride , const float * x )
137+ {
138+ int i , j ;
139+ for (i = 0 ;i < rows ;i += 16 )
140+ {
141+ float * restrict y ;
142+ __m256 vy0 , vy8 ;
143+ y = & out [i ];
144+ vy0 = _mm256_loadu_ps (& y [0 ]);
145+ vy8 = _mm256_loadu_ps (& y [8 ]);
146+ for (j = 0 ;j < cols ;j ++ )
147+ {
148+ __m256 vxj ;
149+ __m256 vw ;
150+ vxj = _mm256_broadcast_ss (& x [j ]);
151+
152+ vw = _mm256_loadu_ps (& weights [j * col_stride + i ]);
153+ vy0 = _mm256_fmadd_ps (vw , vxj , vy0 );
154+
155+ vw = _mm256_loadu_ps (& weights [j * col_stride + i + 8 ]);
156+ vy8 = _mm256_fmadd_ps (vw , vxj , vy8 );
157+ }
158+ _mm256_storeu_ps (& y [0 ], vy0 );
159+ _mm256_storeu_ps (& y [8 ], vy8 );
160+ }
161+ }
162+ static void sparse_gemm_accum16 (float * out , const float * weights , int rows , const int * idx , const float * x )
163+ {
164+ int i , j ;
165+ for (i = 0 ;i < rows ;i += 16 )
166+ {
167+ float * restrict y ;
168+ int cols ;
169+ __m256 vy0 , vy8 ;
170+ y = & out [i ];
171+ vy0 = _mm256_loadu_ps (& y [0 ]);
172+ vy8 = _mm256_loadu_ps (& y [8 ]);
173+ cols = * idx ++ ;
174+ for (j = 0 ;j < cols ;j ++ )
175+ {
176+ int id ;
177+ __m256 vxj ;
178+ __m256 vw ;
179+ id = * idx ++ ;
180+ vxj = _mm256_broadcast_ss (& x [id ]);
181+
182+ vw = _mm256_loadu_ps (& weights [0 ]);
183+ vy0 = _mm256_fmadd_ps (vw , vxj , vy0 );
184+
185+ vw = _mm256_loadu_ps (& weights [8 ]);
186+ vy8 = _mm256_fmadd_ps (vw , vxj , vy8 );
187+ weights += 16 ;
188+ }
189+ _mm256_storeu_ps (& y [0 ], vy0 );
190+ _mm256_storeu_ps (& y [8 ], vy8 );
191+ }
192+ }
193+
194+
195+ #else /* No AVX2/FMA support */
196+
197+
198+ #warning Compiling without any vectorization. This code will be very slow
199+ #warning Try adding -mavx2 -mfma
200+
201+
136202static float celt_exp2 (float x )
137203{
138204 int integer ;
@@ -211,81 +277,6 @@ static void vec_sigmoid(float *y, const float *x, int N)
211277 }
212278}
213279
214-
215- #endif
216-
217-
218-
219- static OPUS_INLINE float relu (float x )
220- {
221- return x < 0 ? 0 : x ;
222- }
223-
224- #ifdef __AVX2__
225- #include <immintrin.h>
226- static void gemm_accum16 (float * out , const float * weights , int rows , int cols , int col_stride , const float * x )
227- {
228- int i , j ;
229- for (i = 0 ;i < rows ;i += 16 )
230- {
231- float * restrict y ;
232- __m256 vy0 , vy8 ;
233- y = & out [i ];
234- vy0 = _mm256_loadu_ps (& y [0 ]);
235- vy8 = _mm256_loadu_ps (& y [8 ]);
236- for (j = 0 ;j < cols ;j ++ )
237- {
238- __m256 vxj ;
239- __m256 vw ;
240- vxj = _mm256_broadcast_ss (& x [j ]);
241-
242- vw = _mm256_loadu_ps (& weights [j * col_stride + i ]);
243- vy0 = _mm256_fmadd_ps (vw , vxj , vy0 );
244-
245- vw = _mm256_loadu_ps (& weights [j * col_stride + i + 8 ]);
246- vy8 = _mm256_fmadd_ps (vw , vxj , vy8 );
247- }
248- _mm256_storeu_ps (& y [0 ], vy0 );
249- _mm256_storeu_ps (& y [8 ], vy8 );
250- }
251- }
252- static void sparse_gemm_accum16 (float * out , const float * weights , int rows , const int * idx , const float * x )
253- {
254- int i , j ;
255- for (i = 0 ;i < rows ;i += 16 )
256- {
257- float * restrict y ;
258- int cols ;
259- __m256 vy0 , vy8 ;
260- y = & out [i ];
261- vy0 = _mm256_loadu_ps (& y [0 ]);
262- vy8 = _mm256_loadu_ps (& y [8 ]);
263- cols = * idx ++ ;
264- for (j = 0 ;j < cols ;j ++ )
265- {
266- int id ;
267- __m256 vxj ;
268- __m256 vw ;
269- id = * idx ++ ;
270- vxj = _mm256_broadcast_ss (& x [id ]);
271-
272- vw = _mm256_loadu_ps (& weights [0 ]);
273- vy0 = _mm256_fmadd_ps (vw , vxj , vy0 );
274-
275- vw = _mm256_loadu_ps (& weights [8 ]);
276- vy8 = _mm256_fmadd_ps (vw , vxj , vy8 );
277- weights += 16 ;
278- }
279- _mm256_storeu_ps (& y [0 ], vy0 );
280- _mm256_storeu_ps (& y [8 ], vy8 );
281- }
282- }
283-
284- #else
285-
286- #warning Compiling without any vectorization. This code will be very slow
287- #warning Try adding -mavx2 -mfma
288-
289280static void gemm_accum16 (float * out , const float * weights , int rows , int cols , int col_stride , const float * x )
290281{
291282 int i , j ;
@@ -354,6 +345,14 @@ static void sparse_gemm_accum16(float *out, const float *w, int rows, const int
354345}
355346#endif
356347
348+
349+
350+ static OPUS_INLINE float relu (float x )
351+ {
352+ return x < 0 ? 0 : x ;
353+ }
354+
355+
357356static void gemm_accum (float * out , const float * weights , int rows , int cols , int col_stride , const float * x )
358357{
359358 int i , j ;
0 commit comments