@@ -190,6 +190,8 @@ def __init__(self, residual_fraction, activation_fn, bias,
       self._gate_first_n = gate_first_n if gate_first_n else self.width

+      self._num_ungated = self.width - self.gate_first_n
+


       super(HighwayFlow, self).__init__(
@@ -226,33 +228,37 @@ def activation_fn(self):
   def gate_first_n(self):
     return self._gate_first_n

+  @property
+  def num_ungated(self):
+    return self._num_ungated
+
   def _derivative_of_softplus(self, x):
     return tf.concat([(self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
+        self.gate_first_n), tf.zeros(self.num_ungated)],
         axis=0) + (
             tf.concat([(1. - self.residual_fraction) * tf.ones(
-                self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+                self.gate_first_n), tf.ones(self.num_ungated)],
                 axis=0)) * tf.math.sigmoid(x)

   def _convex_update(self, weights_matrix):
     return tf.concat(
         [self.residual_fraction * tf.eye(num_rows=self.gate_first_n,
                                          num_columns=self.width),
-         tf.zeros([self.width - self.gate_first_n, self.width])],
+         tf.zeros([self.num_ungated, self.width])],
         axis=0) + tf.concat([(
             1. - self.residual_fraction) * tf.ones(
-                self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+                self.gate_first_n), tf.ones(self.num_ungated)],
             axis=0) * weights_matrix

   def _inverse_of_softplus(self, y, n=20):
     """Inverse of the activation layer with softplus using Newton iteration."""
     x = tf.ones(y.shape)
     for _ in range(n):
       x = x - (tf.concat([(self.residual_fraction) * tf.ones(
-          self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
+          self.gate_first_n), tf.zeros(self.num_ungated)],
           axis=0) * x + tf.concat(
               [(1. - self.residual_fraction) * tf.ones(
-                  self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+                  self.gate_first_n), tf.ones(self.num_ungated)],
               axis=0) * tf.math.softplus(
                   x) - y) / (
                       self._derivative_of_softplus(x))
@@ -274,11 +280,11 @@ def _augmented_forward(self, x):

     fldj = tf.zeros(x.shape[:-1]) + tf.reduce_sum(
         tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
-            self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
+            self.gate_first_n), tf.zeros(self.num_ungated)],
             axis=0) + (
                 tf.concat([(1. - self.residual_fraction) * tf.ones(
                     self.gate_first_n),
-                    tf.ones(self.width - self.gate_first_n)],
+                    tf.ones(self.num_ungated)],
                     axis=0)) * tf.linalg.diag_part(
                         self.upper_diagonal_weights_matrix)))
     x = x[tf.newaxis, ...]
@@ -288,17 +294,17 @@ def _augmented_forward(self, x):
         self._convex_update(self.upper_diagonal_weights_matrix),
         x, transpose_a=True)
     x += (tf.concat([(1. - self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+        self.gate_first_n), tf.ones(self.num_ungated)],
         axis=0) * self.bias)[tf.newaxis, ...]

     if self.activation_fn:
       fldj += tf.reduce_sum(tf.math.log(self._derivative_of_softplus(x[0])),
                             -1)
       x = tf.concat([(self.residual_fraction) * tf.ones(
-          self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
+          self.gate_first_n), tf.zeros(self.num_ungated)],
           axis=0) * x + tf.concat(
               [(1. - self.residual_fraction) * tf.ones(
-                  self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+                  self.gate_first_n), tf.ones(self.num_ungated)],
               axis=0) * tf.nn.softplus(x)

     return tf.squeeze(x, 0), {'ildj': -fldj, 'fldj': fldj}
@@ -316,10 +322,10 @@ def _augmented_inverse(self, y):

     ildj = tf.zeros(y.shape[:-1]) - tf.reduce_sum(
         tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
-            self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
+            self.gate_first_n), tf.zeros(self.num_ungated)],
             axis=0) + tf.concat(
                 [(1. - self.residual_fraction) * tf.ones(
-                    self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+                    self.gate_first_n), tf.ones(self.num_ungated)],
                 axis=0) * tf.linalg.diag_part(
                     self.upper_diagonal_weights_matrix)))

@@ -331,7 +337,7 @@ def _augmented_inverse(self, y):
     y = y[..., tf.newaxis]

     y = y - (tf.concat([(1. - self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+        self.gate_first_n), tf.ones(self.num_ungated)],
         axis=0) * self.bias)[..., tf.newaxis]
     y = tf.linalg.triangular_solve(
         self._convex_update(self.upper_diagonal_weights_matrix), y,
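Each method above rebuilds the same gating vector, and the new num_ungated property simply caches width - gate_first_n so that difference is not recomputed at every call site. A minimal standalone sketch of that pattern in plain TensorFlow (the variable names mirror the bijector's attributes for readability; the snippet itself is illustrative and not part of the PR):

import tensorflow as tf

width = 5
gate_first_n = 3
num_ungated = width - gate_first_n  # cached once, as in the new property
residual_fraction = tf.constant(0.3)

# First `gate_first_n` entries are gated by residual_fraction;
# the remaining `num_ungated` entries pass through ungated.
gated = tf.concat(
    [residual_fraction * tf.ones(gate_first_n), tf.zeros(num_ungated)], axis=0)
ungated = tf.concat(
    [(1. - residual_fraction) * tf.ones(gate_first_n), tf.ones(num_ungated)],
    axis=0)

x = tf.ones(width)
y = gated * x + ungated * tf.math.softplus(x)  # same shape as the activation update above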