Skip to content

Commit d16a111

Browse files
committed
Quick and dirty AVX2 implementation of gemm_accum
Brings us very close to real-time
1 parent 792c5ec commit d16a111

File tree

1 file changed

+71
-3
lines changed

1 file changed

+71
-3
lines changed

src/nnet.c

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,81 @@ static OPUS_INLINE float relu(float x)
7777
return x < 0 ? 0 : x;
7878
}
7979

80-
static void gemm_accum(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
80+
/* gemm_accum16(): accumulate a matrix-vector product into out, 16 rows at a
 * time:
 *
 *     out[i] += sum over j in [0, cols) of weights[j*col_stride + i] * x[j]
 *
 * out        : accumulator, read and updated in place (at least `rows` floats).
 * weights    : coefficient for output i / input j lives at
 *              weights[j*col_stride + i].
 * rows       : number of outputs. Each outer iteration touches 16 consecutive
 *              outputs, so the caller must pass a multiple of 16 (gemm_accum()
 *              checks this before dispatching here).
 * cols       : number of inputs (length of x).
 * col_stride : distance, in floats, between consecutive columns of weights.
 * x          : input vector.
 */
#ifdef __AVX2__
#include <immintrin.h>

/* _mm256_fmadd_ps is an FMA3 instruction and is only available when the
 * compiler targets FMA (__FMA__), which __AVX2__ alone does not guarantee.
 * Without FMA, fall back to a separate multiply and add so the file still
 * builds with e.g. -mavx2 but no -mfma. */
#ifdef __FMA__
#define GEMM_ACCUM16_MADD(a, b, c) _mm256_fmadd_ps((a), (b), (c))
#else
#define GEMM_ACCUM16_MADD(a, b, c) _mm256_add_ps(_mm256_mul_ps((a), (b)), (c))
#endif

static void gemm_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int i, j;
   for (i=0;i<rows;i+=16)
   {
      float * restrict y;
      __m256 vy0, vy8;
      y = &out[i];
      /* Keep the 16 running sums in two 8-float registers for the whole
         column loop; they are written back once per block of rows. */
      vy0 = _mm256_loadu_ps(&y[0]);
      vy8 = _mm256_loadu_ps(&y[8]);
      for (j=0;j<cols;j++)
      {
         __m256 vxj;
         __m256 vw;
         /* Replicate x[j] across all 8 lanes. */
         vxj = _mm256_broadcast_ss(&x[j]);

         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
         vy0 = GEMM_ACCUM16_MADD(vw, vxj, vy0);

         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
         vy8 = GEMM_ACCUM16_MADD(vw, vxj, vy8);
      }
      _mm256_storeu_ps (&y[0], vy0);
      _mm256_storeu_ps (&y[8], vy8);
   }
}

#undef GEMM_ACCUM16_MADD
#else
static void gemm_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int i, j, k;
   for (i=0;i<rows;i+=16)
   {
      /* The destination of this block of rows does not depend on j. */
      float * restrict y = &out[i];
      for (j=0;j<cols;j++)
      {
         const float * restrict w = &weights[j*col_stride + i];
         const float xj = x[j];
         /* Fixed trip count of 16: same per-element operation order as a
            hand-unrolled sequence, left to the compiler to unroll. */
         for (k=0;k<16;k++)
         {
            y[k] += w[k]*xj;
         }
      }
   }
}
#endif
142+
143+
/* gemm_accum(): out[i] += sum over j of weights[j*col_stride + i] * x[j]
 * for i in [0, rows) and j in [0, cols).
 *
 * When both rows and cols are multiples of 16 the work is handed to
 * gemm_accum16(); every other shape goes through the plain nested loop
 * below. out is read and updated in place.
 */
static void gemm_accum(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int r, c;
   const int use_blocked = (rows % 16 == 0) && (cols % 16 == 0);

   if (use_blocked)
   {
      gemm_accum16(out, weights, rows, cols, col_stride, x);
      return;
   }

   /* Generic path: one output at a time, accumulating column by column. */
   for (r=0;r<rows;r++)
   {
      const float *w_r = &weights[r];
      for (c=0;c<cols;c++)
      {
         out[r] += w_r[c*col_stride]*x[c];
      }
   }
}
89157

0 commit comments

Comments
 (0)