@@ -77,13 +77,81 @@ static OPUS_INLINE float relu(float x)
7777 return x < 0 ? 0 : x ;
7878}
7979
80- static void gemm_accum (float * out , const float * weights , int rows , int cols , int col_stride , const float * x )
#ifdef __AVX2__
#include <immintrin.h>

/*
 * Multiply-accumulate a*b + c on eight packed floats.
 *
 * _mm256_fmadd_ps belongs to the FMA extension, which compilers enable and
 * advertise (__FMA__) independently of AVX2.  Only use the fused intrinsic
 * when FMA is actually enabled; otherwise fall back to a separate multiply
 * and add so that an AVX2-only build still compiles.
 */
#ifdef __FMA__
#define GEMM_ACCUM16_MADD(a, b, c) _mm256_fmadd_ps((a), (b), (c))
#else
#define GEMM_ACCUM16_MADD(a, b, c) _mm256_add_ps(_mm256_mul_ps((a), (b)), (c))
#endif

/*
 * SIMD variant: out[i] += sum over j of weights[j*col_stride + i] * x[j].
 *
 * Rows are processed in groups of 16 (two 8-float vectors), so `rows` must
 * be a multiple of 16; each group reads and writes out[i..i+15] and reads
 * weights[j*col_stride + i .. + 15] for every j in [0, cols).
 * Unaligned loads/stores are used, so no alignment is required.
 * `out` must not overlap `weights` or `x` (it is accessed through a
 * restrict-qualified pointer).
 */
static void gemm_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int i, j;
   for (i = 0; i < rows; i += 16)
   {
      float * restrict y;
      __m256 vy0, vy8;
      y = &out[i];
      /* Keep the 16 running sums in registers across the whole column loop. */
      vy0 = _mm256_loadu_ps(&y[0]);
      vy8 = _mm256_loadu_ps(&y[8]);
      for (j = 0; j < cols; j++)
      {
         __m256 vxj;
         __m256 vw;
         /* Replicate x[j] into all eight lanes. */
         vxj = _mm256_broadcast_ss(&x[j]);

         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
         vy0 = GEMM_ACCUM16_MADD(vw, vxj, vy0);

         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
         vy8 = GEMM_ACCUM16_MADD(vw, vxj, vy8);
      }
      _mm256_storeu_ps(&y[0], vy0);
      _mm256_storeu_ps(&y[8], vy8);
   }
}
#else
/*
 * Portable variant: out[i] += sum over j of weights[j*col_stride + i] * x[j].
 *
 * Rows are processed in groups of 16, so `rows` must be a multiple of 16;
 * each group reads and writes out[i..i+15] and reads
 * weights[j*col_stride + i .. + 15] for every j in [0, cols).
 * `out` must not overlap `weights` or `x` (it is accessed through a
 * restrict-qualified pointer).
 */
static void gemm_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int i, j;
   for (i = 0; i < rows; i += 16)
   {
      /* The output group does not depend on j, so resolve it once per group. */
      float * restrict y;
      y = &out[i];
      for (j = 0; j < cols; j++)
      {
         const float * restrict w;
         float xj;
         w = &weights[j*col_stride + i];
         xj = x[j];
         /* Unrolled by hand: one multiply-add per row of the group. */
         y[0] += w[0]*xj;
         y[1] += w[1]*xj;
         y[2] += w[2]*xj;
         y[3] += w[3]*xj;
         y[4] += w[4]*xj;
         y[5] += w[5]*xj;
         y[6] += w[6]*xj;
         y[7] += w[7]*xj;
         y[8] += w[8]*xj;
         y[9] += w[9]*xj;
         y[10] += w[10]*xj;
         y[11] += w[11]*xj;
         y[12] += w[12]*xj;
         y[13] += w[13]*xj;
         y[14] += w[14]*xj;
         y[15] += w[15]*xj;
      }
   }
}
#endif
142+
/*
 * Accumulate a matrix-vector product into `out`:
 *    out[i] += sum over j of weights[j*col_stride + i] * x[j]
 * for i in [0, rows) and j in [0, cols).
 *
 * When both dimensions are multiples of 16 the work is handed to
 * gemm_accum16(); every other shape goes through the plain scalar loops.
 */
static void gemm_accum(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int r, c;

   /* Fast path: shapes accepted by the blocked kernel. */
   if (rows % 16 == 0 && cols % 16 == 0)
   {
      gemm_accum16(out, weights, rows, cols, col_stride, x);
      return;
   }

   /* Generic path: one output row at a time, walking across the columns. */
   for (r = 0; r < rows; r++)
   {
      for (c = 0; c < cols; c++)
      {
         out[r] += weights[c*col_stride + r]*x[c];
      }
   }
}
89157
0 commit comments