@@ -271,12 +271,6 @@ def _augmented_forward(self, x):
271
271
# Log determinant term from the upper matrix. Note that the log determinant
272
272
# of the lower matrix is zero.
273
273
274
- added_batch = False
275
- if len (x .shape ) <= 1 :
276
- if len (x .shape ) == 0 :
277
- x = tf .reshape (x , - 1 )
278
- added_batch = True
279
- x = tf .expand_dims (x , 0 )
280
274
fldj = tf .zeros (x .shape [:- 1 ]) + tf .reduce_sum (
281
275
tf .math .log (tf .concat ([(self .residual_fraction ) * tf .ones (
282
276
self .gate_first_n ), tf .zeros (self .width - self .gate_first_n )],
@@ -286,27 +280,26 @@ def _augmented_forward(self, x):
286
280
tf .ones (self .width - self .gate_first_n )],
287
281
axis = 0 )) * tf .linalg .diag_part (
288
282
self .upper_diagonal_weights_matrix )))
283
+ x = x [tf .newaxis , ...]
289
284
x = tf .linalg .matvec (
290
285
self ._convex_update (self .lower_diagonal_weights_matrix ), x )
291
- x = tf .linalg .matvec (tf .transpose (
292
- self ._convex_update (self .upper_diagonal_weights_matrix )),
293
- x )
294
- x += tf .concat ([(1. - self .residual_fraction ) * tf .ones (
286
+ x = tf .linalg .matvec (self ._convex_update (self .upper_diagonal_weights_matrix ),
287
+ x , transpose_a = True )
288
+ x += (tf .concat ([(1. - self .residual_fraction ) * tf .ones (
295
289
self .gate_first_n ), tf .ones (self .width - self .gate_first_n )],
296
- axis = 0 ) * self .bias
290
+ axis = 0 ) * self .bias )[ tf . newaxis , ...]
297
291
298
292
if self .activation_fn :
299
- fldj += tf .reduce_sum (tf .math .log (self ._derivative_of_softplus (x )),
293
+ fldj += tf .reduce_sum (tf .math .log (self ._derivative_of_softplus (x [ 0 ] )),
300
294
- 1 )
301
295
x = tf .concat ([(self .residual_fraction ) * tf .ones (
302
296
self .gate_first_n ), tf .zeros (self .width - self .gate_first_n )],
303
297
axis = 0 ) * x + tf .concat (
304
298
[(1. - self .residual_fraction ) * tf .ones (
305
299
self .gate_first_n ), tf .ones (self .width - self .gate_first_n )],
306
300
axis = 0 ) * tf .nn .softplus (x )
307
- if added_batch :
308
- x = tf .squeeze (x , 0 )
309
- return x , {'ildj' : - fldj , 'fldj' : fldj }
301
+
302
+ return tf .squeeze (x , 0 ), {'ildj' : - fldj , 'fldj' : fldj }
310
303
311
304
def _augmented_inverse (self , y ):
312
305
"""Computes inverse and inverse_log_det_jacobian transformations.
@@ -319,12 +312,6 @@ def _augmented_inverse(self, y):
319
312
determinant of the jacobian.
320
313
"""
321
314
322
- added_batch = False
323
- if len (y .shape ) <= 1 :
324
- if len (y .shape ) == 0 :
325
- y = tf .reshape (y , - 1 )
326
- added_batch = True
327
- y = tf .expand_dims (y , 0 )
328
315
ildj = tf .zeros (y .shape [:- 1 ]) - tf .reduce_sum (
329
316
tf .math .log (tf .concat ([(self .residual_fraction ) * tf .ones (
330
317
self .gate_first_n ), tf .zeros (self .width - self .gate_first_n )],
@@ -333,23 +320,25 @@ def _augmented_inverse(self, y):
333
320
self .gate_first_n ), tf .ones (self .width - self .gate_first_n )],
334
321
axis = 0 ) * tf .linalg .diag_part (
335
322
self .upper_diagonal_weights_matrix )))
323
+
324
+
336
325
if self .activation_fn :
337
326
y = self ._inverse_of_softplus (y )
338
327
ildj -= tf .reduce_sum (tf .math .log (self ._derivative_of_softplus (y )),
339
328
- 1 )
340
329
341
- y = y - tf .concat ([(1. - self .residual_fraction ) * tf .ones (
330
+ y = y [..., tf .newaxis ]
331
+
332
+ y = y - (tf .concat ([(1. - self .residual_fraction ) * tf .ones (
342
333
self .gate_first_n ), tf .ones (self .width - self .gate_first_n )],
343
- axis = 0 ) * self .bias
344
- y = tf .linalg .triangular_solve (tf .transpose (
345
- self ._convex_update (self .upper_diagonal_weights_matrix )),
346
- tf .linalg .matrix_transpose (y ),
347
- lower = False )
348
- y = tf .linalg .matrix_transpose (tf .linalg .triangular_solve (
349
- self ._convex_update (self .lower_diagonal_weights_matrix ), y ))
350
- if added_batch :
351
- y = tf .squeeze (y , 0 )
352
- return y , {'ildj' : ildj , 'fldj' : - ildj }
334
+ axis = 0 ) * self .bias )[..., tf .newaxis ]
335
+ y = tf .linalg .triangular_solve (
336
+ self ._convex_update (self .upper_diagonal_weights_matrix ), y ,
337
+ lower = True , adjoint = True )
338
+ y = tf .linalg .triangular_solve (
339
+ self ._convex_update (self .lower_diagonal_weights_matrix ), y )
340
+
341
+ return tf .squeeze (y , - 1 ), {'ildj' : ildj , 'fldj' : - ildj }
353
342
354
343
def _forward (self , x ):
355
344
y , _ = self ._augmented_forward (x )
0 commit comments