Skip to content

Commit f11ee2d

Browse files
committed
added missing keyword arguments, changed shape to ps.shape, changed gate_first_n default to None
1 parent f5855eb commit f11ee2d

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

tensorflow_probability/python/experimental/bijectors/highway_flow.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ class HighwayFlow(tfb.Bijector):
162162
def __init__(self, residual_fraction, activation_fn, bias,
163163
upper_diagonal_weights_matrix,
164164
lower_diagonal_weights_matrix,
165-
gate_first_n,
165+
gate_first_n=None,
166166
validate_args=False,
167167
name=None):
168168
"""Initializes the HighwayFlow.
@@ -258,8 +258,7 @@ def _convex_update(self, weights_matrix):
258258
num_columns=self.width,
259259
dtype=self.dtype),
260260
tf.zeros([self.num_ungated, self.width], dtype=self.dtype)],
261-
axis=0) + tf.concat([(
262-
1. - self.residual_fraction) * tf.ones(
261+
axis=0) + tf.concat([(1. - self.residual_fraction) * tf.ones(
263262
self.gate_first_n, dtype=self.dtype),
264263
tf.ones(self.num_ungated, dtype=self.dtype)],
265264
axis=0) * weights_matrix
@@ -294,7 +293,7 @@ def _augmented_forward(self, x):
294293
# Log determinant term from the upper matrix. Note that the log determinant
295294
# of the lower matrix is zero.
296295

297-
fldj = tf.zeros(x.shape[:-1], dtype=self.dtype) + tf.reduce_sum(
296+
fldj = tf.zeros(ps.shape(x)[:-1], dtype=self.dtype) + tf.reduce_sum(
298297
tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
299298
self.gate_first_n, dtype=self.dtype),
300299
tf.zeros(self.num_ungated, dtype=self.dtype)],
@@ -317,7 +316,7 @@ def _augmented_forward(self, x):
317316

318317
if self.activation_fn:
319318
fldj += tf.reduce_sum(tf.math.log(self._derivative_of_softplus(x[0])),
320-
-1)
319+
axis=-1)
321320
x = tf.concat([(self.residual_fraction) * tf.ones(
322321
self.gate_first_n, dtype=self.dtype),
323322
tf.zeros(self.num_ungated, dtype=self.dtype)],
@@ -340,7 +339,7 @@ def _augmented_inverse(self, y):
340339
determinant of the jacobian.
341340
"""
342341

343-
ildj = tf.zeros(y.shape[:-1], dtype=self.dtype) - tf.reduce_sum(
342+
ildj = tf.zeros(ps.shape(y)[:-1], dtype=self.dtype) - tf.reduce_sum(
344343
tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
345344
self.gate_first_n, dtype=self.dtype),
346345
tf.zeros(self.num_ungated, dtype=self.dtype)],
@@ -354,7 +353,7 @@ def _augmented_inverse(self, y):
354353
if self.activation_fn:
355354
y = self._inverse_of_softplus(y)
356355
ildj -= tf.reduce_sum(tf.math.log(self._derivative_of_softplus(y)),
357-
-1)
356+
axis=-1)
358357

359358
y = y[..., tf.newaxis]
360359

@@ -368,7 +367,7 @@ def _augmented_inverse(self, y):
368367
y = tf.linalg.triangular_solve(
369368
self._convex_update(self.lower_diagonal_weights_matrix), y)
370369

371-
return tf.squeeze(y, -1), {'ildj': ildj, 'fldj': -ildj}
370+
return tf.squeeze(y, axis=-1), {'ildj': ildj, 'fldj': -ildj}
372371

373372
def _forward(self, x):
374373
y, _ = self._augmented_forward(x)

tensorflow_probability/python/experimental/bijectors/highway_flow_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
tfb = tfp.bijectors
2323
tfd = tfp.distributions
2424

25-
2625
@test_util.test_all_tf_execution_regimes
2726
class HighwayFlowTests(test_util.TestCase):
2827

0 commit comments

Comments
 (0)