Skip to content

Commit 2fb5069

Browse files
committed
Fix reset_after GRU
1 parent f2d28b9 commit 2fb5069

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/nnet.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,23 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
173173
/* Compute update gate. */
174174
for (i=0;i<N;i++)
175175
z[i] = gru->bias[i];
176+
if (gru->reset_after)
177+
{
178+
for (i=0;i<N;i++)
179+
z[i] += gru->bias[3*N + i];
180+
}
176181
gemm_accum(z, gru->input_weights, N, M, stride, input);
177182
gemm_accum(z, gru->recurrent_weights, N, N, stride, state);
178183
compute_activation(z, z, N, ACTIVATION_SIGMOID);
179184

180185
/* Compute reset gate. */
181186
for (i=0;i<N;i++)
182187
r[i] = gru->bias[N + i];
188+
if (gru->reset_after)
189+
{
190+
for (i=0;i<N;i++)
191+
r[i] += gru->bias[4*N + i];
192+
}
183193
gemm_accum(r, &gru->input_weights[N], N, M, stride, input);
184194
gemm_accum(r, &gru->recurrent_weights[N], N, N, stride, state);
185195
compute_activation(r, r, N, ACTIVATION_SIGMOID);
@@ -189,8 +199,8 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
189199
h[i] = gru->bias[2*N + i];
190200
if (gru->reset_after)
191201
{
192-
/* WARNING: The reset_after version was never tested. */
193-
RNN_CLEAR(tmp, N);
202+
for (i=0;i<N;i++)
203+
tmp[i] = gru->bias[5*N + i];
194204
gemm_accum(tmp, &gru->recurrent_weights[2*N], N, N, stride, state);
195205
for (i=0;i<N;i++)
196206
h[i] += tmp[i] * r[i];

0 commit comments

Comments
 (0)