Skip to content

Commit 3fcd879

Browse files
committed
Simpler GRU implementation just for reset_after.
1 parent c74488b commit 3fcd879

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

src/lpcnet.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,10 @@ void run_sample_network(NNetState *net, float *pdf, const float *condition, int
121121
compute_embedding(&embed_sig, &in_a[EMBED_SIG_OUT_SIZE], pred);
122122
compute_embedding(&embed_exc, &in_a[2*EMBED_SIG_OUT_SIZE], last_exc);
123123
RNN_COPY(&in_a[2*EMBED_SIG_OUT_SIZE + EMBED_EXC_OUT_SIZE], condition, FEATURE_DENSE2_OUT_SIZE);
124-
compute_gru(&gru_a, net->gru_a_state, in_a);
124+
compute_gru2(&gru_a, net->gru_a_state, in_a);
125125
RNN_COPY(in_b, net->gru_a_state, GRU_A_STATE_SIZE);
126126
RNN_COPY(&in_b[GRU_A_STATE_SIZE], condition, FEATURE_DENSE2_OUT_SIZE);
127-
compute_gru(&gru_b, net->gru_b_state, in_b);
127+
compute_gru2(&gru_b, net->gru_b_state, in_b);
128128
compute_mdense(&dual_fc, pdf, net->gru_b_state);
129129
}
130130

src/nnet.c

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,44 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
218218
state[i] = h[i];
219219
}
220220

221+
void compute_gru2(const GRULayer *gru, float *state, const float *input)
222+
{
223+
int i;
224+
int N, M;
225+
int stride;
226+
float zrh[3*MAX_RNN_NEURONS];
227+
float recur[3*MAX_RNN_NEURONS];
228+
float *z;
229+
float *r;
230+
float *h;
231+
M = gru->nb_inputs;
232+
N = gru->nb_neurons;
233+
z = zrh;
234+
r = &zrh[N];
235+
h = &zrh[2*N];
236+
celt_assert(gru->nb_neurons <= MAX_RNN_NEURONS);
237+
celt_assert(input != state);
238+
celt_assert(gru->reset_after);
239+
stride = 3*N;
240+
/* Compute update gate. */
241+
for (i=0;i<3*N;i++)
242+
zrh[i] = gru->bias[i];
243+
gemm_accum(zrh, gru->input_weights, 3*N, M, stride, input);
244+
for (i=0;i<3*N;i++)
245+
recur[i] = gru->bias[3*N + i];
246+
gemm_accum(recur, gru->recurrent_weights, 3*N, N, stride, state);
247+
for (i=0;i<2*N;i++)
248+
zrh[i] += recur[i];
249+
compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID);
250+
for (i=0;i<N;i++)
251+
h[i] += recur[2*N+i]*r[i];
252+
compute_activation(h, h, N, gru->activation);
253+
for (i=0;i<N;i++)
254+
h[i] = z[i]*state[i] + (1-z[i])*h[i];
255+
for (i=0;i<N;i++)
256+
state[i] = h[i];
257+
}
258+
221259
void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
222260
{
223261
int i;

src/nnet.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ void compute_mdense(const MDenseLayer *layer, float *output, const float *input)
8585

8686
void compute_gru(const GRULayer *gru, float *state, const float *input);
8787

88+
void compute_gru2(const GRULayer *gru, float *state, const float *input);
89+
8890
void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input);
8991

9092
void compute_embedding(const EmbeddingLayer *layer, float *output, int input);

0 commit comments

Comments
 (0)