 from tensorflow_probability.python import bijectors as tfb
 from tensorflow_probability.python import util
 from tensorflow_probability.python.internal import cache_util
+from tensorflow_probability.python.internal import dtype_util
+from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import samplers
+from tensorflow_probability.python.internal import tensor_util


 def build_highway_flow_layer(width,
@@ -178,26 +181,33 @@ def __init__(self, residual_fraction, activation_fn, bias,
     """
     parameters = dict(locals())
     name = name or 'highway_flow'
+    dtype = dtype_util.common_dtype(
+        [residual_fraction, bias, upper_diagonal_weights_matrix,
+         lower_diagonal_weights_matrix], dtype_hint=tf.float32)
     with tf.name_scope(name) as name:
-      self._width = tf.shape(bias)[-1]
-      self._bias = bias
-      self._residual_fraction = residual_fraction
+      self._width = ps.shape(bias)[-1]
+      self._bias = tensor_util.convert_nonref_to_tensor(bias, dtype=dtype,
+                                                        name='bias')
+      self._residual_fraction = tensor_util.convert_nonref_to_tensor(
+          residual_fraction, dtype=dtype, name='residual_fraction')
       # The upper matrix is still lower triangular; the transpose is done in
       # the _inverse and _forward methods, within matvec.
-      self._upper_diagonal_weights_matrix = upper_diagonal_weights_matrix
-      self._lower_diagonal_weights_matrix = lower_diagonal_weights_matrix
+      self._upper_diagonal_weights_matrix = tensor_util.convert_nonref_to_tensor(
+          upper_diagonal_weights_matrix, dtype=dtype,
+          name='upper_diagonal_weights_matrix')
+      self._lower_diagonal_weights_matrix = tensor_util.convert_nonref_to_tensor(
+          lower_diagonal_weights_matrix, dtype=dtype,
+          name='lower_diagonal_weights_matrix')
       self._activation_fn = activation_fn
-      if gate_first_n:
-        self._gate_first_n = gate_first_n if gate_first_n else self.width
-
-      self._num_ungated = self.self.width - self.gate_first_n
-
+      self._gate_first_n = gate_first_n if gate_first_n else self.width

+      self._num_ungated = self.width - self.gate_first_n

       super(HighwayFlow, self).__init__(
           validate_args=validate_args,
           forward_min_event_ndims=1,
           parameters=parameters,
+          dtype=dtype,
           name=name)

   @property
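Note on the constructor change above: it derives a single dtype from the parameters with `dtype_util.common_dtype` and wraps them with `tensor_util.convert_nonref_to_tensor`, which converts plain values to tensors of that dtype while leaving `tf.Variable` references untouched, so deferred/trainable parameters keep working. A minimal sketch of that pattern in isolation; the parameter values below are illustrative, not from this commit:

import tensorflow as tf
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import tensor_util

# One deferred (trainable) parameter and one concrete parameter.
residual_fraction = tf.Variable(0.5, dtype=tf.float64)
bias = tf.zeros([3], dtype=tf.float64)

# common_dtype inspects every parameter; dtype_hint is only the fallback.
dtype = dtype_util.common_dtype([residual_fraction, bias],
                                dtype_hint=tf.float32)  # -> float64

# Plain values become Tensors of `dtype`; the tf.Variable is passed through
# unchanged, so gradients still reach it.
residual_fraction = tensor_util.convert_nonref_to_tensor(
    residual_fraction, dtype=dtype, name='residual_fraction')
bias = tensor_util.convert_nonref_to_tensor(bias, dtype=dtype, name='bias')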
@@ -234,31 +244,37 @@ def num_ungated(self):

   def _derivative_of_softplus(self, x):
     return tf.concat([(self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.zeros(self.num_ungated)],
+        self.gate_first_n, dtype=self.dtype),
+        tf.zeros(self.num_ungated, dtype=self.dtype)],
         axis=0) + (
             tf.concat([(1. - self.residual_fraction) * tf.ones(
-                self.gate_first_n), tf.ones(self.num_ungated)],
+                self.gate_first_n, dtype=self.dtype),
+                tf.ones(self.num_ungated, dtype=self.dtype)],
                 axis=0)) * tf.math.sigmoid(x)

   def _convex_update(self, weights_matrix):
     return tf.concat(
         [self.residual_fraction * tf.eye(num_rows=self.gate_first_n,
-                                         num_columns=self.width),
-         tf.zeros([self.num_ungated, self.width])],
+                                         num_columns=self.width,
+                                         dtype=self.dtype),
+         tf.zeros([self.num_ungated, self.width], dtype=self.dtype)],
         axis=0) + tf.concat([(
             1. - self.residual_fraction) * tf.ones(
-                self.gate_first_n), tf.ones(self.num_ungated)],
+                self.gate_first_n, dtype=self.dtype),
+                tf.ones(self.num_ungated, dtype=self.dtype)],
                 axis=0) * weights_matrix

   def _inverse_of_softplus(self, y, n=20):
     """Inverse of the activation layer with softplus using Newton iteration."""
-    x = tf.ones(y.shape)
+    x = tf.ones_like(y, dtype=self.dtype)
     for _ in range(n):
       x = x - (tf.concat([(self.residual_fraction) * tf.ones(
-          self.gate_first_n), tf.zeros(self.num_ungated)],
+          self.gate_first_n, dtype=self.dtype),
+          tf.zeros(self.num_ungated, dtype=self.dtype)],
           axis=0) * x + tf.concat(
               [(1. - self.residual_fraction) * tf.ones(
-                  self.gate_first_n), tf.ones(self.num_ungated)],
+                  self.gate_first_n, dtype=self.dtype),
+                  tf.ones(self.num_ungated, dtype=self.dtype)],
                   axis=0) * tf.math.softplus(
                       x) - y) / (
                           self._derivative_of_softplus(x))
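`_inverse_of_softplus` above inverts the gated activation g(x) = r*x + (1 - r)*softplus(x), with r = residual_fraction on the gated dimensions and r = 0 on the ungated ones, by Newton's method x <- x - (g(x) - y) / g'(x), where g'(x) = r + (1 - r)*sigmoid(x) is exactly `_derivative_of_softplus`. A standalone sketch of the same iteration for a single scalar gate, written here only to illustrate the numerics:

import tensorflow as tf

def inverse_of_gated_softplus(y, residual_fraction, n=20):
  """Newton iteration solving r*x + (1 - r)*softplus(x) = y for x."""
  r = residual_fraction
  x = tf.ones_like(y)
  for _ in range(n):
    g = r * x + (1. - r) * tf.math.softplus(x)
    dg = r + (1. - r) * tf.math.sigmoid(x)
    x = x - (g - y) / dg
  return x

y = tf.constant([0.3, 1.0, 2.5])
x = inverse_of_gated_softplus(y, residual_fraction=0.4)
# Applying the forward map to x should recover y (up to numerical error).
print(0.4 * x + 0.6 * tf.math.softplus(x))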
@@ -278,13 +294,14 @@ def _augmented_forward(self, x):
     # Log determinant term from the upper matrix. Note that the log determinant
     # of the lower matrix is zero.

-    fldj = tf.zeros(x.shape[:-1]) + tf.reduce_sum(
+    fldj = tf.zeros(x.shape[:-1], dtype=self.dtype) + tf.reduce_sum(
         tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
-            self.gate_first_n), tf.zeros(self.num_ungated)],
+            self.gate_first_n, dtype=self.dtype),
+            tf.zeros(self.num_ungated, dtype=self.dtype)],
             axis=0) + (
                 tf.concat([(1. - self.residual_fraction) * tf.ones(
-                    self.gate_first_n),
-                    tf.ones(self.num_ungated)],
+                    self.gate_first_n, dtype=self.dtype),
+                    tf.ones(self.num_ungated, dtype=self.dtype)],
                     axis=0)) * tf.linalg.diag_part(
                         self.upper_diagonal_weights_matrix)))
     x = x[tf.newaxis, ...]
@@ -294,17 +311,20 @@ def _augmented_forward(self, x):
         self._convex_update(self.upper_diagonal_weights_matrix),
         x, transpose_a=True)
     x += (tf.concat([(1. - self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.ones(self.num_ungated)],
+        self.gate_first_n, dtype=self.dtype),
+        tf.ones(self.num_ungated, dtype=self.dtype)],
         axis=0) * self.bias)[tf.newaxis, ...]

     if self.activation_fn:
       fldj += tf.reduce_sum(tf.math.log(self._derivative_of_softplus(x[0])),
                             -1)
       x = tf.concat([(self.residual_fraction) * tf.ones(
-          self.gate_first_n), tf.zeros(self.num_ungated)],
+          self.gate_first_n, dtype=self.dtype),
+          tf.zeros(self.num_ungated, dtype=self.dtype)],
           axis=0) * x + tf.concat(
               [(1. - self.residual_fraction) * tf.ones(
-                  self.gate_first_n), tf.ones(self.num_ungated)],
+                  self.gate_first_n, dtype=self.dtype),
+                  tf.ones(self.num_ungated, dtype=self.dtype)],
                   axis=0) * tf.nn.softplus(x)

     return tf.squeeze(x, 0), {'ildj': -fldj, 'fldj': fldj}
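For the fully gated case (gate_first_n == width), `_convex_update(W)` reduces to residual_fraction * I + (1 - residual_fraction) * W, and since the stored upper matrix is kept lower triangular (the transpose happens inside matvec), the upper-matrix log-det term above is just the sum of log(residual_fraction + (1 - residual_fraction) * diag(U)). A small numerical check of that identity; the tensors here are made up for illustration and are not part of this commit:

import tensorflow as tf

lam = tf.constant(0.3)
width = 4
# Lower-triangular "upper" matrix with a positive diagonal, as stored above.
u = tf.linalg.band_part(tf.random.normal([width, width]), -1, 0)
u = tf.linalg.set_diag(u, tf.math.softplus(tf.linalg.diag_part(u)))

u_conv = lam * tf.eye(width) + (1. - lam) * u  # fully gated convex update
fldj_from_diag = tf.reduce_sum(
    tf.math.log(lam + (1. - lam) * tf.linalg.diag_part(u)))
fldj_from_det = tf.math.log(tf.abs(tf.linalg.det(u_conv)))
print(fldj_from_diag.numpy(), fldj_from_det.numpy())  # should agree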
@@ -320,12 +340,14 @@ def _augmented_inverse(self, y):
       determinant of the jacobian.
     """

-    ildj = tf.zeros(y.shape[:-1]) - tf.reduce_sum(
+    ildj = tf.zeros(y.shape[:-1], dtype=self.dtype) - tf.reduce_sum(
         tf.math.log(tf.concat([(self.residual_fraction) * tf.ones(
-            self.gate_first_n), tf.zeros(self.num_ungated)],
+            self.gate_first_n, dtype=self.dtype),
+            tf.zeros(self.num_ungated, dtype=self.dtype)],
             axis=0) + tf.concat(
                 [(1. - self.residual_fraction) * tf.ones(
-                    self.gate_first_n), tf.ones(self.num_ungated)],
+                    self.gate_first_n, dtype=self.dtype),
+                    tf.ones(self.num_ungated, dtype=self.dtype)],
                     axis=0) * tf.linalg.diag_part(
                         self.upper_diagonal_weights_matrix)))

@@ -337,7 +359,8 @@ def _augmented_inverse(self, y):
     y = y[..., tf.newaxis]

     y = y - (tf.concat([(1. - self.residual_fraction) * tf.ones(
-        self.gate_first_n), tf.ones(self.num_ungated)],
+        self.gate_first_n, dtype=self.dtype),
+        tf.ones(self.num_ungated, dtype=self.dtype)],
         axis=0) * self.bias)[..., tf.newaxis]
     y = tf.linalg.triangular_solve(
         self._convex_update(self.upper_diagonal_weights_matrix), y,
@@ -369,4 +392,4 @@ def _inverse_log_det_jacobian(self, y):
     if 'ildj' not in cached:
       _, attrs = self._augmented_inverse(y)
       cached.update(attrs)
-    return cached['ildj']
+    return cached['ildj']