@@ -76,35 +76,38 @@ static OPUS_INLINE float relu(float x)
    return x < 0 ? 0 : x;
 }
 
+/* a := a + u*b for k-element vectors: an AXPY-style column accumulation
+ * that skips the whole column when the scalar u is zero. */
+static void faxpy(float *restrict a, const rnn_weight *restrict b, int k, float u)
+{
+   if (u == 0.0) return;
+   for (int idx = 0; idx < k; idx++)
+      a[idx] += b[idx]*u;
+}
+
 void compute_dense(const DenseLayer *layer, float *output, const float *input)
 {
    int i, j;
    int N, M;
-   int stride;
    M = layer->nb_inputs;
    N = layer->nb_neurons;
-   stride = N;
-   for (i=0;i<N;i++)
-   {
-      /* Compute update gate. */
-      float sum = layer->bias[i];
-      for (j=0;j<M;j++)
-         sum += layer->input_weights[j*stride + i]*input[j];
-      output[i] = WEIGHTS_SCALE*sum;
-   }
+   const rnn_weight *ip = layer->input_weights;
+   /* Accumulate bias plus weighted inputs, one weight column at a time. */
+   for (i=0;i<N;i++)
+      output[i] = layer->bias[i];
+   for (j=0;j<M;j++,ip+=N)
+      faxpy(output, ip, N, input[j]);
    switch (layer->activation) {
    case ACTIVATION_SIGMOID:
       for (i=0;i<N;i++)
-         output[i] = sigmoid_approx(output[i]);
+         output[i] = sigmoid_approx(WEIGHTS_SCALE*output[i]);
       break;
    case ACTIVATION_TANH:
       for (i=0;i<N;i++)
-         output[i] = tansig_approx(output[i]);
+         output[i] = tansig_approx(WEIGHTS_SCALE*output[i]);
       break;
    default:
    case ACTIVATION_RELU:
       for (i=0;i<N;i++)
-         output[i] = relu(output[i]);
+         output[i] = relu(WEIGHTS_SCALE*output[i]);
       break;
    }
 }
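
Note on the compute_dense hunk: the old loop computed one dot product per output neuron, reading the weight matrix with stride N. The new code initializes the outputs with the bias and then calls faxpy once per input, adding one contiguous weight column at a time; that gives unit-stride memory access, lets the inner loop vectorize, and skips any column whose input is exactly zero. A minimal standalone sketch of the same column-wise accumulation (the names axpy, NEURONS, and INPUTS, and the plain float weights, are illustrative only, not from this patch):

   #include <stdio.h>

   #define NEURONS 3
   #define INPUTS  2

   /* a := a + u*b, the same AXPY-style accumulation the patch introduces. */
   static void axpy(float *a, const float *b, int k, float u)
   {
      if (u == 0.0f) return;
      for (int i = 0; i < k; i++)
         a[i] += b[i]*u;
   }

   int main(void)
   {
      /* Column-major weights: column j (one per input) starts at w + j*NEURONS. */
      const float w[INPUTS*NEURONS] = {1, 2, 3, 4, 5, 6};
      const float bias[NEURONS] = {0.5f, 0.5f, 0.5f};
      const float x[INPUTS] = {1.0f, 0.0f}; /* the zero input costs nothing */
      float y[NEURONS];
      int i, j;

      for (i = 0; i < NEURONS; i++)
         y[i] = bias[i];
      for (j = 0; j < INPUTS; j++)
         axpy(y, w + j*NEURONS, NEURONS, x[j]);

      for (i = 0; i < NEURONS; i++)
         printf("%g\n", y[i]); /* prints 1.5 2.5 3.5 */
      return 0;
   }
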
@@ -120,44 +123,49 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
    M = gru->nb_inputs;
    N = gru->nb_neurons;
    stride = 3*N;
-   for (i=0;i<N;i++)
-   {
-      /* Compute update gate. */
-      float sum = gru->bias[i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[j*stride + i]*state[j];
-      z[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
-   }
-   for (i=0;i<N;i++)
-   {
-      /* Compute reset gate. */
-      float sum = gru->bias[N + i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[N + j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[N + j*stride + i]*state[j];
-      r[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
+   const rnn_weight *ip = gru->input_weights;
+   const rnn_weight *rp = gru->recurrent_weights;
+   /* Compute update gate. */
+   for (i=0;i<N;i++)
+      z[i] = gru->bias[i];
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(z, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(z, rp, N, state[j]);
+   for (i=0;i<N;i++)
+      z[i] = sigmoid_approx(WEIGHTS_SCALE*z[i]);
+   /* Compute reset gate. */
+   for (i=0;i<N;i++)
+      r[i] = gru->bias[N + i];
+   ip = gru->input_weights + N;
+   rp = gru->recurrent_weights + N;
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(r, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(r, rp, N, state[j]);
+   for (i=0;i<N;i++)
+      r[i] = sigmoid_approx(WEIGHTS_SCALE*r[i]);
+
+   /* Compute output. */
+   for (i=0;i<N;i++)
+      h[i] = gru->bias[2*N + i];
+   ip = gru->input_weights + 2*N;
+   rp = gru->recurrent_weights + 2*N;
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(h, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(h, rp, N, r[j]*state[j]);
+   for (i=0;i<N;i++) {
+      switch (gru->activation) {
+      case ACTIVATION_SIGMOID: h[i] = sigmoid_approx(WEIGHTS_SCALE*h[i]); break;
+      case ACTIVATION_TANH: h[i] = tansig_approx(WEIGHTS_SCALE*h[i]); break;
+      default:
+      case ACTIVATION_RELU: h[i] = relu(WEIGHTS_SCALE*h[i]); break;
+      }
+      h[i] = z[i]*state[i] + (1-z[i])*h[i];
    }
    for (i=0;i<N;i++)
-   {
-      /* Compute output. */
-      float sum = gru->bias[2*N + i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[2*N + j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[2*N + j*stride + i]*state[j]*r[j];
-      switch (gru->activation) {
-      case ACTIVATION_SIGMOID: sum = sigmoid_approx(WEIGHTS_SCALE*sum); break;
-      case ACTIVATION_TANH: sum = tansig_approx(WEIGHTS_SCALE*sum); break;
-      default:
-      case ACTIVATION_RELU: sum = relu(WEIGHTS_SCALE*sum); break;
-      }
-      h[i] = z[i]*state[i] + (1-z[i])*sum;
-   }
-   for (i=0;i<N;i++)
-      state[i] = h[i];
+      state[i] = h[i];
 }
 
 #define INPUT_SIZE 42
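
Note on the compute_gru hunk: the restructured code computes the standard GRU recurrence, just traversing the weights column-wise. With x = input, s = state, and W/U/b denoting the input weights, recurrent weights, and bias of each gate (these symbol names are mine, not the code's; "." is elementwise multiplication):

   z  = sigmoid(WEIGHTS_SCALE * (b_z + W_z*x + U_z*s))        update gate
   r  = sigmoid(WEIGHTS_SCALE * (b_r + W_r*x + U_r*s))        reset gate
   hc = act(WEIGHTS_SCALE * (b_h + W_h*x + U_h*(r . s)))      candidate state
   s' = z . s + (1 - z) . hc                                  interpolated new state

The three gates' weight columns are interleaved with stride 3*N, which is why ip and rp advance by stride and the reset and output passes start at offsets N and 2*N. The reset gate is folded into the recurrent accumulation by passing r[j]*state[j] as the scalar argument to faxpy, and WEIGHTS_SCALE (which presumably maps the quantized rnn_weight storage back to its effective float range) is now applied once per element at activation time instead of inside every partial sum; since the old code also scaled the complete bias-plus-products sum just before the activation, the two forms compute the same values.
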