@@ -173,13 +173,23 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
    /* Compute update gate. */
    for (i=0;i<N;i++)
       z[i] = gru->bias[i];
+   if (gru->reset_after)
+   {
+      for (i=0;i<N;i++)
+         z[i] += gru->bias[3*N + i];
+   }
    gemm_accum(z, gru->input_weights, N, M, stride, input);
    gemm_accum(z, gru->recurrent_weights, N, N, stride, state);
    compute_activation(z, z, N, ACTIVATION_SIGMOID);
 
    /* Compute reset gate. */
    for (i=0;i<N;i++)
       r[i] = gru->bias[N + i];
+   if (gru->reset_after)
+   {
+      for (i=0;i<N;i++)
+         r[i] += gru->bias[4*N + i];
+   }
    gemm_accum(r, &gru->input_weights[N], N, M, stride, input);
    gemm_accum(r, &gru->recurrent_weights[N], N, N, stride, state);
    compute_activation(r, r, N, ACTIVATION_SIGMOID);
@@ -189,8 +199,8 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
       h[i] = gru->bias[2*N + i];
    if (gru->reset_after)
    {
-      /* WARNING: The reset_after version was never tested. */
-      RNN_CLEAR(tmp, N);
+      for (i=0;i<N;i++)
+         tmp[i] = gru->bias[5*N + i];
       gemm_accum(tmp, &gru->recurrent_weights[2*N], N, N, stride, state);
       for (i=0;i<N;i++)
          h[i] += tmp[i]*r[i];
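
For anyone tracing the bias indices: with `reset_after` set, the patch treats `gru->bias` as 6*N floats, the first 3*N being the input-side z/r/h biases and the last 3*N the recurrent-side ones, which matches how e.g. Keras stores a GRU trained with `reset_after=True`. Below is a minimal self-contained sketch of the same step; `gru_step()`, `MAX_NEURONS` and the row-major weight layout are assumptions of the sketch only, since the actual code accumulates through `gemm_accum()` with a column stride.

```c
/* Sketch: one full reset_after GRU step with the 6*N bias layout the
 * diff above assumes (bias[0..3N) = input-side z/r/h biases,
 * bias[3N..6N) = recurrent-side z/r/h biases).  gru_step(), MAX_NEURONS
 * and the row-major weight layout are assumptions of this sketch; the
 * patched code runs through gemm_accum() with a stride instead. */
#include <math.h>

#define MAX_NEURONS 128  /* assumption: N <= MAX_NEURONS */

static float sigmoid_approx(float x) { return 1.f/(1.f + expf(-x)); }

static void gru_step(int N, int M,
                     const float *wi,   /* input weights: 3*N rows of M */
                     const float *wr,   /* recurrent weights: 3*N rows of N */
                     const float *bias, /* 6*N entries, layout above */
                     float *state,      /* N entries, updated in place */
                     const float *in)   /* M entries */
{
   int i, j;
   float z[MAX_NEURONS], r[MAX_NEURONS], h[MAX_NEURONS];
   for (i=0;i<N;i++)
   {
      /* Update and reset gates sum BOTH bias halves, as in the diff. */
      float zi = bias[i]     + bias[3*N + i];
      float ri = bias[N + i] + bias[4*N + i];
      for (j=0;j<M;j++)
      {
         zi += wi[i*M + j]*in[j];
         ri += wi[(N + i)*M + j]*in[j];
      }
      for (j=0;j<N;j++)
      {
         zi += wr[i*N + j]*state[j];
         ri += wr[(N + i)*N + j]*state[j];
      }
      z[i] = sigmoid_approx(zi);
      r[i] = sigmoid_approx(ri);
   }
   for (i=0;i<N;i++)
   {
      /* reset_after: the recurrent candidate term keeps its own bias
       * (the 5*N block) and is gated by r AFTER the matrix product --
       * exactly what the tmp[] accumulator does in the patch. */
      float hin  = bias[2*N + i];
      float hrec = bias[5*N + i];
      for (j=0;j<M;j++)
         hin += wi[(2*N + i)*M + j]*in[j];
      for (j=0;j<N;j++)
         hrec += wr[(2*N + i)*N + j]*state[j];
      h[i] = tanhf(hin + r[i]*hrec);
   }
   /* Blend in a separate pass so every h[i] saw the old state. */
   for (i=0;i<N;i++)
      state[i] = z[i]*state[i] + (1.f - z[i])*h[i];
}
```

The final blend runs as a separate pass on purpose: updating state[i] inside the candidate loop would feed already-updated values into the recurrent sums of later rows.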