Skip to content

Commit fd80992

Browse files
committed
Delaying the softmax() to avoid the pow()
Now at 5x real-time, with all the low-hanging fruit done.
1 parent 432a4c2 commit fd80992

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

src/nnet.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
#include "nnet.h"
4040
#include "nnet_data.h"
4141

42+
#define SOFTMAX_HACK
43+
4244
#ifdef __AVX2__
4345
#include <immintrin.h>
4446
static __m256 exp8_approx(__m256 X)
@@ -340,6 +342,10 @@ void compute_activation(float *output, float *input, int N, int activation)
340342
for (i=0;i<N;i++)
341343
output[i] = relu(input[i]);
342344
} else if (activation == ACTIVATION_SOFTMAX) {
345+
#ifdef SOFTMAX_HACK
346+
for (i=0;i<N;i++)
347+
output[i] = input[i];
348+
#else
343349
float sum = 0;
344350
softmax(output, input, N);
345351
for (i=0;i<N;i++) {
@@ -348,6 +354,7 @@ void compute_activation(float *output, float *input, int N, int activation)
348354
sum = 1.f/(sum+1e-30);
349355
for (i=0;i<N;i++)
350356
output[i] = sum*output[i];
357+
#endif
351358
} else {
352359
celt_assert(activation == ACTIVATION_LINEAR);
353360
for (i=0;i<N;i++)
@@ -619,12 +626,24 @@ int sample_from_pdf(const float *pdf, int N, float exp_boost, float pdf_floor)
619626
float tmp[DUAL_FC_OUT_SIZE];
620627
celt_assert(N <= DUAL_FC_OUT_SIZE);
621628
sum = 0;
629+
#ifdef SOFTMAX_HACK
630+
for (i=0;i<N;i++)
631+
{
632+
tmp[i] = pdf[i] * (1.f+exp_boost);
633+
}
634+
softmax(tmp, tmp, N);
635+
for (i=0;i<N;i++)
636+
{
637+
sum += tmp[i];
638+
}
639+
#else
622640
/* Decrease the temperature of the sampling. */
623641
for (i=0;i<N;i++)
624642
{
625643
tmp[i] = pow(pdf[i], 1.f+exp_boost);
626644
sum += tmp[i];
627645
}
646+
#endif
628647
norm = 1.f/sum;
629648
/* Convert tmp to a CDF while subtracting the floor */
630649
tmp[0] = MAX16(0, norm*tmp[0] - pdf_floor);

0 commit comments

Comments
 (0)