Commit a931441: "using self.num_ungated" (parent: 1ad515a)

File tree: 1 file changed (+20, -14)

1 file changed

+20
-14
lines changed

tensorflow_probability/python/experimental/bijectors/highway_flow.py

Lines changed: 20 additions & 14 deletions
@@ -190,6 +190,8 @@ def __init__(self, residual_fraction, activation_fn, bias,
     if gate_first_n:
       self._gate_first_n = gate_first_n if gate_first_n else self.width
 
+    self._num_ungated = self.width - self.gate_first_n
+
 
 
     super(HighwayFlow, self).__init__(
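
Note: the cached `_num_ungated` is simply `width - gate_first_n`, the number of trailing dimensions that bypass the residual gate. A minimal standalone sketch of the recurring mask pattern this commit refactors (toy values; the names `gated_mask` and `free_mask` are illustrative, not from the file):

import tensorflow as tf

width, gate_first_n = 5, 3
num_ungated = width - gate_first_n  # the quantity cached by this commit
residual_fraction = tf.constant(0.6)

# Gated dims get a convex combination; ungated dims pass straight through.
gated_mask = tf.concat(
    [residual_fraction * tf.ones(gate_first_n), tf.zeros(num_ungated)], axis=0)
free_mask = tf.concat(
    [(1. - residual_fraction) * tf.ones(gate_first_n), tf.ones(num_ungated)],
    axis=0)
# gated_mask = [0.6 0.6 0.6 0.0 0.0], free_mask = [0.4 0.4 0.4 1.0 1.0]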
@@ -226,33 +228,37 @@ def activation_fn(self):
   def gate_first_n(self):
     return self._gate_first_n
 
+  @property
+  def num_ungated(self):
+    return self._num_ungated
+
   def _derivative_of_softplus(self, x):
     return tf.concat([(self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
+        self.gate_first_n), tf.zeros(self.num_ungated)],
                      axis=0) + (
         tf.concat([(1. - self.residual_fraction) * tf.ones(
-            self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+            self.gate_first_n), tf.ones(self.num_ungated)],
                   axis=0)) * tf.math.sigmoid(x)
 
   def _convex_update(self, weights_matrix):
     return tf.concat(
         [self.residual_fraction * tf.eye(num_rows=self.gate_first_n,
                                          num_columns=self.width),
-         tf.zeros([self.width - self.gate_first_n, self.width])],
+         tf.zeros([self.num_ungated, self.width])],
         axis=0) + tf.concat([(
             1. - self.residual_fraction) * tf.ones(
-                self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+                self.gate_first_n), tf.ones(self.num_ungated)],
         axis=0) * weights_matrix
 
   def _inverse_of_softplus(self, y, n=20):
     """Inverse of the activation layer with softplus using Newton iteration."""
     x = tf.ones(y.shape)
     for _ in range(n):
       x = x - (tf.concat([(self.residual_fraction) * tf.ones(
-          self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
+          self.gate_first_n), tf.zeros(self.num_ungated)],
           axis=0) * x + tf.concat(
           [(1. - self.residual_fraction) * tf.ones(
-              self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+              self.gate_first_n), tf.ones(self.num_ungated)],
           axis=0) * tf.math.softplus(
           x) - y) / (
           self._derivative_of_softplus(x))
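
For context: `_inverse_of_softplus` inverts the gated activation `a * x + b * softplus(x)` elementwise with Newton's method, using `_derivative_of_softplus` (`a + b * sigmoid(x)`) as the slope. A self-contained sketch of the same iteration with scalar coefficients (hypothetical helper name and toy values, not the class method):

import tensorflow as tf

def inverse_gated_softplus(y, a, b, n=20):
  # Newton iteration for x solving a*x + b*softplus(x) = y, elementwise.
  x = tf.ones_like(y)  # same starting point as the method above
  for _ in range(n):
    residual = a * x + b * tf.math.softplus(x) - y
    slope = a + b * tf.math.sigmoid(x)  # d/dx softplus(x) = sigmoid(x)
    x = x - residual / slope
  return x

y = tf.constant([0.3, 1.5, 2.0])
x = inverse_gated_softplus(y, a=0.6, b=0.4)
# Round trip: 0.6 * x + 0.4 * softplus(x) recovers y to float precision.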
@@ -274,11 +280,11 @@ def _augmented_forward(self, x):
 
     fldj = tf.zeros(x.shape[:-1]) + tf.reduce_sum(
         tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
-            self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
+            self.gate_first_n), tf.zeros(self.num_ungated)],
             axis=0) + (
             tf.concat([(1. - self.residual_fraction) * tf.ones(
                 self.gate_first_n),
-                tf.ones(self.width - self.gate_first_n)],
+                tf.ones(self.num_ungated)],
                 axis=0)) * tf.linalg.diag_part(
             self.upper_diagonal_weights_matrix)))
     x = x[tf.newaxis, ...]
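
A note on the `fldj` term: `_convex_update` keeps the weight matrix triangular, so the forward log-det Jacobian of the linear step is just the sum of the logs of its diagonal, which is `gated_mask + free_mask * diag(U)`. A toy check of that diagonal formula (illustrative values, not the trained weights):

import tensorflow as tf

u_diag = tf.constant([2.0, 1.5, 0.5])      # diag of an upper-triangular U
gated_mask = tf.constant([0.6, 0.6, 0.0])  # residual_fraction on gated dims
free_mask = tf.constant([0.4, 0.4, 1.0])   # complement, ones on ungated dims
fldj = tf.reduce_sum(tf.math.log(gated_mask + free_mask * u_diag))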
@@ -288,17 +294,17 @@ def _augmented_forward(self, x):
         self._convex_update(self.upper_diagonal_weights_matrix),
         x, transpose_a=True)
     x += (tf.concat([(1. - self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+        self.gate_first_n), tf.ones(self.num_ungated)],
         axis=0) * self.bias)[tf.newaxis, ...]
 
     if self.activation_fn:
       fldj += tf.reduce_sum(tf.math.log(self._derivative_of_softplus(x[0])),
                             -1)
       x = tf.concat([(self.residual_fraction) * tf.ones(
-          self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
+          self.gate_first_n), tf.zeros(self.num_ungated)],
           axis=0) * x + tf.concat(
           [(1. - self.residual_fraction) * tf.ones(
-              self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+              self.gate_first_n), tf.ones(self.num_ungated)],
           axis=0) * tf.nn.softplus(x)
 
     return tf.squeeze(x, 0), {'ildj': -fldj, 'fldj': fldj}
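
The affine step above multiplies by the transposed gated upper-triangular matrix and adds a gated bias. A toy sketch of that step's structure (illustrative 3x3 values standing in for `_convex_update(U)`; `transpose_a=True` means the matrix actually applied is lower triangular):

import tensorflow as tf

U = tf.constant([[2.0, 1.0, 3.0],
                 [0.0, 1.5, 2.0],
                 [0.0, 0.0, 0.5]])          # stand-in for convex_update(U)
bias = tf.constant([0.1, -0.2, 0.3])
free_mask = tf.constant([0.4, 0.4, 1.0])    # (1 - residual_fraction), then ones
x = tf.constant([0.3, 1.2, -0.7])
y = tf.linalg.matvec(U, x, transpose_a=True) + free_mask * bias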
@@ -316,10 +322,10 @@ def _augmented_inverse(self, y):
 
     ildj = tf.zeros(y.shape[:-1]) - tf.reduce_sum(
         tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
-            self.gate_first_n), tf.zeros(self.width - self.gate_first_n)],
+            self.gate_first_n), tf.zeros(self.num_ungated)],
             axis=0) + tf.concat(
             [(1. - self.residual_fraction) * tf.ones(
-                self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+                self.gate_first_n), tf.ones(self.num_ungated)],
             axis=0) * tf.linalg.diag_part(
             self.upper_diagonal_weights_matrix)))
 
@@ -331,7 +337,7 @@ def _augmented_inverse(self, y):
     y = y[..., tf.newaxis]
 
     y = y - (tf.concat([(1. - self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.ones(self.width - self.gate_first_n)],
+        self.gate_first_n), tf.ones(self.num_ungated)],
         axis=0) * self.bias)[..., tf.newaxis]
     y = tf.linalg.triangular_solve(
         self._convex_update(self.upper_diagonal_weights_matrix), y,
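
The inverse undoes the forward affine step in reverse order: subtract the gated bias, then solve against the same `_convex_update` matrix with `tf.linalg.triangular_solve`. Continuing the toy sketch above (same illustrative values):

import tensorflow as tf

U = tf.constant([[2.0, 1.0, 3.0],
                 [0.0, 1.5, 2.0],
                 [0.0, 0.0, 0.5]])
bias = tf.constant([0.1, -0.2, 0.3])
free_mask = tf.constant([0.4, 0.4, 1.0])
y = tf.constant([0.5, 1.8, -0.1])
# Solve U^T @ x = y - free_mask * bias; lower=True since U^T is lower triangular.
x = tf.linalg.triangular_solve(
    tf.transpose(U), (y - free_mask * bias)[:, tf.newaxis], lower=True)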
