Commit 5699521

added dtype and fixed self._gate_first_n

1 parent a931441 commit 5699521

1 file changed: +53 −30 lines changed

tensorflow_probability/python/experimental/bijectors/highway_flow.py

Lines changed: 53 additions & 30 deletions
@@ -20,7 +20,10 @@
 from tensorflow_probability.python import bijectors as tfb
 from tensorflow_probability.python import util
 from tensorflow_probability.python.internal import cache_util
+from tensorflow_probability.python.internal import dtype_util
+from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import samplers
+from tensorflow_probability.python.internal import tensor_util


 def build_highway_flow_layer(width,
@@ -178,26 +181,33 @@ def __init__(self, residual_fraction, activation_fn, bias,
     """
     parameters = dict(locals())
     name = name or 'highway_flow'
+    dtype = dtype_util.common_dtype(
+        [residual_fraction, bias, upper_diagonal_weights_matrix,
+         lower_diagonal_weights_matrix], dtype_hint=tf.float32)
     with tf.name_scope(name) as name:
-      self._width = tf.shape(bias)[-1]
-      self._bias = bias
-      self._residual_fraction = residual_fraction
+      self._width = ps.shape(bias)[-1]
+      self._bias = tensor_util.convert_nonref_to_tensor(bias, dtype=dtype,
+                                                        name='bias')
+      self._residual_fraction = tensor_util.convert_nonref_to_tensor(
+          residual_fraction, dtype=dtype, name='residual_fraction')
       # The upper matrix is still lower triangular; the transpose is done in
       # the _inverse and _forward methods, within matvec.
-      self._upper_diagonal_weights_matrix = upper_diagonal_weights_matrix
-      self._lower_diagonal_weights_matrix = lower_diagonal_weights_matrix
+      self._upper_diagonal_weights_matrix = tensor_util.convert_nonref_to_tensor(
+          upper_diagonal_weights_matrix, dtype=dtype,
+          name='upper_diagonal_weights_matrix')
+      self._lower_diagonal_weights_matrix = tensor_util.convert_nonref_to_tensor(
+          lower_diagonal_weights_matrix, dtype=dtype,
+          name='lower_diagonal_weights_matrix')
       self._activation_fn = activation_fn
-      if gate_first_n:
-        self._gate_first_n = gate_first_n if gate_first_n else self.width
-
-      self._num_ungated = self.self.width - self.gate_first_n
-
+      self._gate_first_n = gate_first_n if gate_first_n else self.width

+      self._num_ungated = self.width - self.gate_first_n

       super(HighwayFlow, self).__init__(
           validate_args=validate_args,
           forward_min_event_ndims=1,
           parameters=parameters,
+          dtype=dtype,
           name=name)

   @property
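
Context for the self._gate_first_n fix above: the removed block only assigned the attribute inside `if gate_first_n:`, so passing `gate_first_n=None` left it unset, and `self._num_ungated` was then computed from the typo `self.self.width`. The replacement always assigns it, falling back to gating the full width. A minimal sketch of the fallback semantics, with hypothetical values:

width = 4
gate_first_n = None                                       # caller did not set it
gate_first_n = gate_first_n if gate_first_n else width    # falls back to 4
num_ungated = width - gate_first_n                        # 0: everything gated
assert (gate_first_n, num_ungated) == (4, 0)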
@@ -234,31 +244,37 @@ def num_ungated(self):

   def _derivative_of_softplus(self, x):
     return tf.concat([(self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.zeros(self.num_ungated)],
+        self.gate_first_n, dtype=self.dtype),
+        tf.zeros(self.num_ungated, dtype=self.dtype)],
                      axis=0) + (
         tf.concat([(1. - self.residual_fraction) * tf.ones(
-            self.gate_first_n), tf.ones(self.num_ungated)],
+            self.gate_first_n, dtype=self.dtype),
+            tf.ones(self.num_ungated, dtype=self.dtype)],
             axis=0)) * tf.math.sigmoid(x)

   def _convex_update(self, weights_matrix):
     return tf.concat(
         [self.residual_fraction * tf.eye(num_rows=self.gate_first_n,
-                                         num_columns=self.width),
-         tf.zeros([self.num_ungated, self.width])],
+                                         num_columns=self.width,
+                                         dtype=self.dtype),
+         tf.zeros([self.num_ungated, self.width], dtype=self.dtype)],
         axis=0) + tf.concat([(
             1. - self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.ones(self.num_ungated)],
+        self.gate_first_n, dtype=self.dtype),
+        tf.ones(self.num_ungated, dtype=self.dtype)],
         axis=0) * weights_matrix

   def _inverse_of_softplus(self, y, n=20):
     """Inverse of the activation layer with softplus using Newton iteration."""
-    x = tf.ones(y.shape)
+    x = tf.ones_like(y, dtype=self.dtype)
     for _ in range(n):
       x = x - (tf.concat([(self.residual_fraction) * tf.ones(
-          self.gate_first_n), tf.zeros(self.num_ungated)],
+          self.gate_first_n, dtype=self.dtype),
+          tf.zeros(self.num_ungated, dtype=self.dtype)],
           axis=0) * x + tf.concat(
           [(1. - self.residual_fraction) * tf.ones(
-              self.gate_first_n), tf.ones(self.num_ungated)],
+              self.gate_first_n, dtype=self.dtype),
+              tf.ones(self.num_ungated, dtype=self.dtype)],
          axis=0) * tf.math.softplus(
          x) - y) / (
              self._derivative_of_softplus(x))
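
For reference, _inverse_of_softplus is ordinary Newton iteration, x ← x − f(x)/f'(x), applied elementwise to f(x) = λ·x + (1 − λ)·softplus(x) − y on the gated dimensions (λ being `residual_fraction`; ungated dimensions reduce to inverting a plain softplus), with `_derivative_of_softplus` supplying f'(x). A scalar NumPy sketch of one gated dimension, with hypothetical λ and y:

import numpy as np

def softplus(x):
  return np.log1p(np.exp(x))

lam, y = 0.3, 1.5          # hypothetical residual_fraction and target value
x = 1.0                    # matches the tf.ones initialization
for _ in range(20):        # n=20, as in the diff
  f = lam * x + (1. - lam) * softplus(x) - y
  df = lam + (1. - lam) * (1. / (1. + np.exp(-x)))   # sigmoid derivative term
  x -= f / df
assert abs(lam * x + (1. - lam) * softplus(x) - y) < 1e-9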
@@ -278,13 +294,14 @@ def _augmented_forward(self, x):
     # Log determinant term from the upper matrix. Note that the log determinant
     # of the lower matrix is zero.

-    fldj = tf.zeros(x.shape[:-1]) + tf.reduce_sum(
+    fldj = tf.zeros(x.shape[:-1], dtype=self.dtype) + tf.reduce_sum(
         tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
-            self.gate_first_n), tf.zeros(self.num_ungated)],
+            self.gate_first_n, dtype=self.dtype),
+            tf.zeros(self.num_ungated, dtype=self.dtype)],
             axis=0) + (
             tf.concat([(1. - self.residual_fraction) * tf.ones(
-                self.gate_first_n),
-                tf.ones(self.num_ungated)],
+                self.gate_first_n, dtype=self.dtype),
+                tf.ones(self.num_ungated, dtype=self.dtype)],
                 axis=0)) * tf.linalg.diag_part(
             self.upper_diagonal_weights_matrix)))
     x = x[tf.newaxis, ...]
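
To spell out the fldj term above: `_convex_update(U)` is triangular, so its log determinant is the sum of the logs of its diagonal, which the code assembles directly from `diag_part(U)`: log(λ + (1 − λ)·U[i, i]) for gated dimensions i < gate_first_n and log(U[i, i]) for ungated ones, with λ = `residual_fraction`. A small NumPy sketch with hypothetical values:

import numpy as np

lam, gate_first_n = 0.3, 2
U_diag = np.array([0.5, 1.2, 0.8])        # width 3; positive diagonal
mask = np.arange(3) < gate_first_n        # [True, True, False]
fldj = np.sum(np.log(np.where(mask, lam + (1. - lam) * U_diag, U_diag)))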
@@ -294,17 +311,20 @@ def _augmented_forward(self, x):
         self._convex_update(self.upper_diagonal_weights_matrix),
         x, transpose_a=True)
     x += (tf.concat([(1. - self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.ones(self.num_ungated)],
+        self.gate_first_n, dtype=self.dtype),
+        tf.ones(self.num_ungated, dtype=self.dtype)],
         axis=0) * self.bias)[tf.newaxis, ...]

     if self.activation_fn:
       fldj += tf.reduce_sum(tf.math.log(self._derivative_of_softplus(x[0])),
                             -1)
       x = tf.concat([(self.residual_fraction) * tf.ones(
-          self.gate_first_n), tf.zeros(self.num_ungated)],
+          self.gate_first_n, dtype=self.dtype),
+          tf.zeros(self.num_ungated, dtype=self.dtype)],
           axis=0) * x + tf.concat(
           [(1. - self.residual_fraction) * tf.ones(
-              self.gate_first_n), tf.ones(self.num_ungated)],
+              self.gate_first_n, dtype=self.dtype),
+              tf.ones(self.num_ungated, dtype=self.dtype)],
           axis=0) * tf.nn.softplus(x)

     return tf.squeeze(x, 0), {'ildj': -fldj, 'fldj': fldj}
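
The activation branch above applies the same convex gating elementwise: gated dimensions produce λ·x + (1 − λ)·softplus(x), whose inverse is exactly the Newton routine sketched earlier, while ungated dimensions get a plain softplus. A NumPy sketch with hypothetical values:

import numpy as np

def softplus(x):
  return np.log1p(np.exp(x))

lam, gate_first_n, num_ungated = 0.3, 2, 1
x = np.array([0.5, -1.0, 2.0])            # width 3; last dim ungated
gate = np.concatenate([lam * np.ones(gate_first_n), np.zeros(num_ungated)])
carry = np.concatenate([(1. - lam) * np.ones(gate_first_n),
                        np.ones(num_ungated)])
y = gate * x + carry * softplus(x)        # gated: lam*x + (1-lam)*softplus(x)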
@@ -320,12 +340,14 @@ def _augmented_inverse(self, y):
       determinant of the jacobian.
     """

-    ildj = tf.zeros(y.shape[:-1]) - tf.reduce_sum(
+    ildj = tf.zeros(y.shape[:-1], dtype=self.dtype) - tf.reduce_sum(
         tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
-            self.gate_first_n), tf.zeros(self.num_ungated)],
+            self.gate_first_n, dtype=self.dtype),
+            tf.zeros(self.num_ungated, dtype=self.dtype)],
             axis=0) + tf.concat(
             [(1. - self.residual_fraction) * tf.ones(
-                self.gate_first_n), tf.ones(self.num_ungated)],
+                self.gate_first_n, dtype=self.dtype),
+                tf.ones(self.num_ungated, dtype=self.dtype)],
             axis=0) * tf.linalg.diag_part(
             self.upper_diagonal_weights_matrix)))

@@ -337,7 +359,8 @@ def _augmented_inverse(self, y):
     y = y[..., tf.newaxis]

     y = y - (tf.concat([(1. - self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.ones(self.num_ungated)],
+        self.gate_first_n, dtype=self.dtype),
+        tf.ones(self.num_ungated, dtype=self.dtype)],
         axis=0) * self.bias)[..., tf.newaxis]
     y = tf.linalg.triangular_solve(
         self._convex_update(self.upper_diagonal_weights_matrix), y,
@@ -369,4 +392,4 @@ def _inverse_log_det_jacobian(self, y):
     if 'ildj' not in cached:
       _, attrs = self._augmented_inverse(y)
       cached.update(attrs)
-    return cached['ildj']
+    return cached['ildj']
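
Taken together, the dtype changes make the parameters drive the bijector's dtype through `dtype_util.common_dtype(..., dtype_hint=tf.float32)`, so float64 parameters no longer collide with the previously implicit float32 ones/zeros/eye constants. A rough sketch of the inference rule, using the same helper the diff imports (the tensors here are hypothetical):

import tensorflow as tf
from tensorflow_probability.python.internal import dtype_util

bias = tf.constant([0., 0.], dtype=tf.float64)
w = tf.eye(2, dtype=tf.float64)
dtype = dtype_util.common_dtype([0.5, bias, w, w], dtype_hint=tf.float32)
assert dtype == tf.float64   # inferred from the parameters, not the hint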
