@@ -162,7 +162,7 @@ class HighwayFlow(tfb.Bijector):
162
162
def __init__ (self , residual_fraction , activation_fn , bias ,
163
163
upper_diagonal_weights_matrix ,
164
164
lower_diagonal_weights_matrix ,
165
- gate_first_n ,
165
+ gate_first_n = None ,
166
166
validate_args = False ,
167
167
name = None ):
168
168
"""Initializes the HighwayFlow.
@@ -258,8 +258,7 @@ def _convex_update(self, weights_matrix):
258
258
num_columns = self .width ,
259
259
dtype = self .dtype ),
260
260
tf .zeros ([self .num_ungated , self .width ], dtype = self .dtype )],
261
- axis = 0 ) + tf .concat ([(
262
- 1. - self .residual_fraction ) * tf .ones (
261
+ axis = 0 ) + tf .concat ([(1. - self .residual_fraction ) * tf .ones (
263
262
self .gate_first_n , dtype = self .dtype ),
264
263
tf .ones (self .num_ungated , dtype = self .dtype )],
265
264
axis = 0 ) * weights_matrix
@@ -294,7 +293,7 @@ def _augmented_forward(self, x):
294
293
# Log determinant term from the upper matrix. Note that the log determinant
295
294
# of the lower matrix is zero.
296
295
297
- fldj = tf .zeros (x .shape [:- 1 ], dtype = self .dtype ) + tf .reduce_sum (
296
+ fldj = tf .zeros (ps .shape ( x ) [:- 1 ], dtype = self .dtype ) + tf .reduce_sum (
298
297
tf .math .log (tf .concat ([(self .residual_fraction ) * tf .ones (
299
298
self .gate_first_n , dtype = self .dtype ),
300
299
tf .zeros (self .num_ungated , dtype = self .dtype )],
@@ -317,7 +316,7 @@ def _augmented_forward(self, x):
317
316
318
317
if self .activation_fn :
319
318
fldj += tf .reduce_sum (tf .math .log (self ._derivative_of_softplus (x [0 ])),
320
- - 1 )
319
+ axis = - 1 )
321
320
x = tf .concat ([(self .residual_fraction ) * tf .ones (
322
321
self .gate_first_n , dtype = self .dtype ),
323
322
tf .zeros (self .num_ungated , dtype = self .dtype )],
@@ -340,7 +339,7 @@ def _augmented_inverse(self, y):
340
339
determinant of the jacobian.
341
340
"""
342
341
343
- ildj = tf .zeros (y .shape [:- 1 ], dtype = self .dtype ) - tf .reduce_sum (
342
+ ildj = tf .zeros (ps .shape ( y ) [:- 1 ], dtype = self .dtype ) - tf .reduce_sum (
344
343
tf .math .log (tf .concat ([(self .residual_fraction ) * tf .ones (
345
344
self .gate_first_n , dtype = self .dtype ),
346
345
tf .zeros (self .num_ungated , dtype = self .dtype )],
@@ -354,7 +353,7 @@ def _augmented_inverse(self, y):
354
353
if self .activation_fn :
355
354
y = self ._inverse_of_softplus (y )
356
355
ildj -= tf .reduce_sum (tf .math .log (self ._derivative_of_softplus (y )),
357
- - 1 )
356
+ axis = - 1 )
358
357
359
358
y = y [..., tf .newaxis ]
360
359
@@ -368,7 +367,7 @@ def _augmented_inverse(self, y):
368
367
y = tf .linalg .triangular_solve (
369
368
self ._convex_update (self .lower_diagonal_weights_matrix ), y )
370
369
371
- return tf .squeeze (y , - 1 ), {'ildj' : ildj , 'fldj' : - ildj }
370
+ return tf .squeeze (y , axis = - 1 ), {'ildj' : ildj , 'fldj' : - ildj }
372
371
373
372
def _forward (self , x ):
374
373
y , _ = self ._augmented_forward (x )
0 commit comments