Skip to content

Commit 1d294b8

Browse files
committed
Removed the explicit transpose and the added_batch reshaping logic
1 parent 2e8e8bd commit 1d294b8

File tree

1 file changed

+21
-32
lines changed

1 file changed

+21
-32
lines changed

tensorflow_probability/python/experimental/bijectors/highway_flow.py

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,6 @@ def _augmented_forward(self, x):
271271
# Log determinant term from the upper matrix. Note that the log determinant
272272
# of the lower matrix is zero.
273273

274-
added_batch = False
275-
if len(x.shape) <= 1:
276-
if len(x.shape) == 0:
277-
x = tf.reshape(x, -1)
278-
added_batch = True
279-
x = tf.expand_dims(x, 0)
280274
fldj = tf.zeros(x.shape[:-1]) + tf.reduce_sum(
281275
tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
282276
self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
@@ -286,27 +280,26 @@ def _augmented_forward(self, x):
286280
tf.ones(self.width - self.gate_first_n)],
287281
axis=0)) * tf.linalg.diag_part(
288282
self.upper_diagonal_weights_matrix)))
283+
x = x[tf.newaxis, ...]
289284
x = tf.linalg.matvec(
290285
self._convex_update(self.lower_diagonal_weights_matrix), x)
291-
x = tf.linalg.matvec(tf.transpose(
292-
self._convex_update(self.upper_diagonal_weights_matrix)),
293-
x)
294-
x += tf.concat([(1. - self.residual_fraction) * tf.ones(
286+
x = tf.linalg.matvec(self._convex_update(self.upper_diagonal_weights_matrix),
287+
x, transpose_a=True)
288+
x += (tf.concat([(1. - self.residual_fraction) * tf.ones(
295289
self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
296-
axis=0) * self.bias
290+
axis=0) * self.bias)[tf.newaxis, ...]
297291

298292
if self.activation_fn:
299-
fldj += tf.reduce_sum(tf.math.log(self._derivative_of_softplus(x)),
293+
fldj += tf.reduce_sum(tf.math.log(self._derivative_of_softplus(x[0])),
300294
-1)
301295
x = tf.concat([(self.residual_fraction) * tf.ones(
302296
self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
303297
axis=0) * x + tf.concat(
304298
[(1. - self.residual_fraction) * tf.ones(
305299
self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
306300
axis=0) * tf.nn.softplus(x)
307-
if added_batch:
308-
x = tf.squeeze(x, 0)
309-
return x, {'ildj': -fldj, 'fldj': fldj}
301+
302+
return tf.squeeze(x, 0), {'ildj': -fldj, 'fldj': fldj}
310303

311304
def _augmented_inverse(self, y):
312305
"""Computes inverse and inverse_log_det_jacobian transformations.
@@ -319,12 +312,6 @@ def _augmented_inverse(self, y):
319312
determinant of the jacobian.
320313
"""
321314

322-
added_batch = False
323-
if len(y.shape) <= 1:
324-
if len(y.shape) == 0:
325-
y = tf.reshape(y, -1)
326-
added_batch = True
327-
y = tf.expand_dims(y, 0)
328315
ildj = tf.zeros(y.shape[:-1]) - tf.reduce_sum(
329316
tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
330317
self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
@@ -333,23 +320,25 @@ def _augmented_inverse(self, y):
333320
self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
334321
axis=0) * tf.linalg.diag_part(
335322
self.upper_diagonal_weights_matrix)))
323+
324+
336325
if self.activation_fn:
337326
y = self._inverse_of_softplus(y)
338327
ildj -= tf.reduce_sum(tf.math.log(self._derivative_of_softplus(y)),
339328
-1)
340329

341-
y = y - tf.concat([(1. - self.residual_fraction) * tf.ones(
330+
y = y[..., tf.newaxis]
331+
332+
y = y - (tf.concat([(1. - self.residual_fraction) * tf.ones(
342333
self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
343-
axis=0) * self.bias
344-
y = tf.linalg.triangular_solve(tf.transpose(
345-
self._convex_update(self.upper_diagonal_weights_matrix)),
346-
tf.linalg.matrix_transpose(y),
347-
lower=False)
348-
y = tf.linalg.matrix_transpose(tf.linalg.triangular_solve(
349-
self._convex_update(self.lower_diagonal_weights_matrix), y))
350-
if added_batch:
351-
y = tf.squeeze(y, 0)
352-
return y, {'ildj': ildj, 'fldj': -ildj}
334+
axis=0) * self.bias)[..., tf.newaxis]
335+
y = tf.linalg.triangular_solve(
336+
self._convex_update(self.upper_diagonal_weights_matrix), y,
337+
lower=True, adjoint=True)
338+
y = tf.linalg.triangular_solve(
339+
self._convex_update(self.lower_diagonal_weights_matrix), y)
340+
341+
return tf.squeeze(y, -1), {'ildj': ildj, 'fldj': -ildj}
353342

354343
def _forward(self, x):
355344
y, _ = self._augmented_forward(x)

0 commit comments

Comments (0)