def build_highway_flow_layer(width,
                             residual_fraction_initial_value=0.5,
                             activation_fn=False,
-                            gate_first_n=-1,
+                            gate_first_n=None,
                             seed=None):
  """Builds HighwayFlow making sure that all the requirements are satisfied.
@@ -49,8 +49,6 @@ def build_highway_flow_layer(width,
    `bias` is a randomly initialized vector of size `width`
  """

-  if gate_first_n == -1:
-    gate_first_n = width
  # TODO: add control that residual_fraction_initial_value is between 0 and 1
  residual_fraction_initial_value = tf.convert_to_tensor(
      residual_fraction_initial_value,
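The `# TODO` in the hunk above asks for a check that `residual_fraction_initial_value` lies in [0, 1]. A minimal sketch of such a guard, assuming one would lean on TensorFlow's standard debugging assertions (the helper below is illustrative and not part of this commit):

```python
import tensorflow as tf

def check_residual_fraction(value):
  # Hypothetical guard for the TODO: fail fast if value leaves [0, 1].
  value = tf.convert_to_tensor(value, dtype_hint=tf.float32)
  tf.debugging.assert_greater_equal(value, tf.zeros_like(value))
  tf.debugging.assert_less_equal(value, tf.ones_like(value))
  return value
```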
@@ -189,7 +187,10 @@ def __init__(self, residual_fraction, activation_fn, bias,
      self._upper_diagonal_weights_matrix = upper_diagonal_weights_matrix
      self._lower_diagonal_weights_matrix = lower_diagonal_weights_matrix
      self._activation_fn = activation_fn
-      self._gate_first_n = gate_first_n
+      # When gate_first_n is None (the new default), gate the full width.
+      self._gate_first_n = gate_first_n if gate_first_n else self.width
+
+

      super(HighwayFlow, self).__init__(
          validate_args=validate_args,
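Taken together, the hunks above replace the old `-1` sentinel with a `None` default and resolve it inside `__init__`. A minimal sketch of the intended resolution, assuming the default should still mean "gate every dimension" as the removed `if gate_first_n == -1` branch did (the helper name is illustrative, not part of the commit):

```python
def resolve_gate_first_n(gate_first_n, width):
  # None (the new default) behaves like the old -1 sentinel: gate all dims.
  return gate_first_n if gate_first_n else width

assert resolve_gate_first_n(None, 4) == 4  # default: gate the whole input
assert resolve_gate_first_n(2, 4) == 2     # gate only the first two dimensions
```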
@@ -362,244 +363,4 @@ def _inverse_log_det_jacobian(self, y):
    if 'ildj' not in cached:
      _, attrs = self._augmented_inverse(y)
      cached.update(attrs)
-    return cached['ildj']
-
-def build_highway_flow_layer(width, residual_fraction_initial_value=0.5,
-                             activation_fn=False, seed=None):
-  # TODO: add control that residual_fraction_initial_value is between 0 and 1
-  residual_fraction_initial_value = tf.convert_to_tensor(
-      residual_fraction_initial_value,
-      dtype_hint=tf.float32,
-      name='residual_fraction_initial_value')
-  dtype = residual_fraction_initial_value.dtype
-
-  bias_seed, upper_seed, lower_seed, diagonal_seed = samplers.split_seed(seed,
-                                                                         n=4)
-  return HighwayFlow(
-      residual_fraction=util.TransformedVariable(
-          initial_value=residual_fraction_initial_value,
-          bijector=tfb.Sigmoid(),
-          dtype=dtype),
-      activation_fn=activation_fn,
-      bias=tf.Variable(
-          samplers.normal((width,), mean=0., stddev=0.01, seed=bias_seed),
-          dtype=dtype),
-      upper_diagonal_weights_matrix=util.TransformedVariable(
-          initial_value=tf.experimental.numpy.tril(
-              samplers.normal((width, width), mean=0., stddev=1.,
-                              seed=upper_seed),
-              k=-1) + tf.linalg.diag(
-                  samplers.uniform((width,), minval=0., maxval=1.,
-                                   seed=diagonal_seed)),
-          bijector=tfb.FillScaleTriL(diag_bijector=tfb.Softplus(),
-                                     diag_shift=None),
-          dtype=dtype),
-      lower_diagonal_weights_matrix=util.TransformedVariable(
-          initial_value=samplers.normal((width, width), mean=0., stddev=1.,
-                                        seed=lower_seed),
-          bijector=tfb.Chain(
-              [tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
-               tfb.Pad(paddings=[(1, 0), (0, 1)]),
-               tfb.FillTriangular()]),
-          dtype=dtype)
-  )
-
-
-class HighwayFlow(tfb.Bijector):
-  """Implements a Highway Flow bijector [1], which interpolates the input
-  `X` with the transformations at each step of the bijector.
-  The Highway Flow can be used as a building block for a Cascading Flow [1]
-  or as a generic normalizing flow.
-
-  The transformation consists of a convex update between the input `X` and a
-  linear transformation of `X` followed by activation with the form `g(A @
-  X + b)`, where `g(.)` is a differentiable non-decreasing activation
-  function, and `A` and `b` are trainable weights.
-
-  The convex update is regulated by a trainable residual fraction `l`
-  constrained between 0 and 1, and can be
-  formalized as:
-  `Y = l * X + (1 - l) * g(A @ X + b)`.
-
-  To make this transformation invertible, the bijector is split into three
-  convex updates:
-  - `Y1 = l * X + (1 - l) * L @ X`, with `L` lower diagonal matrix with ones
-  on the diagonal;
-  - `Y2 = l * Y1 + (1 - l) * (U @ Y1 + b)`, with `U` upper diagonal matrix
-  with positive diagonal;
-  - `Y = l * Y2 + (1 - l) * g(Y2)`
-
-  The function `build_highway_flow_layer` helps initialize the bijector
-  with variables that respect the various constraints.
-
-  For more details on Highway Flow and Cascading Flows see [1].
-
-  #### Usage example:
-  ```python
-  tfd = tfp.distributions
-  tfb = tfp.bijectors
-
-  dim = 4  # last input dimension
-
-  bijector = build_highway_flow_layer(dim, activation_fn=True)
-  y = bijector.forward(x)  # forward mapping
-  x = bijector.inverse(y)  # inverse mapping
-  base = tfd.MultivariateNormalDiag(loc=tf.zeros(dim))  # Base distribution
-  transformed_distribution = tfd.TransformedDistribution(base, bijector)
-  ```
-
-  #### References
-
-  [1]: Ambrogioni, Luca, Gianluigi Silvestri, and Marcel van Gerven.
-  "Automatic variational inference with
-  cascading flows." arXiv preprint arXiv:2102.04801 (2021).
-  """
-
-  # Highway Flow simultaneously computes `forward` and `fldj`
-  # (and `inverse`/`ildj`), so we override the bijector cache to update the
-  # LDJ entries of attrs on forward/inverse calls (instead of
-  # updating them only when the LDJ methods themselves are called).
-
-  _cache = cache_util.BijectorCacheWithGreedyAttrs(
-      forward_name='_augmented_forward',
-      inverse_name='_augmented_inverse')
-
-  def __init__(self, residual_fraction, activation_fn, bias,
-               upper_diagonal_weights_matrix,
-               lower_diagonal_weights_matrix, validate_args=False,
-               name='highway_flow'):
-    '''
-    Args:
-      residual_fraction: scalar `Tensor` used for the convex update,
-        must be between 0 and 1.
-      activation_fn: bool to decide whether to use softplus (True)
-        activation or no activation (False).
-      bias: bias vector.
-      upper_diagonal_weights_matrix: lower diagonal matrix of size
-        (width, width) with positive diagonal
-        (is transposed to upper diagonal within the bijector).
-      lower_diagonal_weights_matrix: lower diagonal matrix with ones on
-        the main diagonal.
-    '''
-    parameters = dict(locals())
-    with tf.name_scope(name) as name:
-      self._width = tf.shape(bias)[-1]
-      self._bias = bias
-      self._residual_fraction = residual_fraction
-      # The upper matrix is still lower triangular; transpose is done in
-      # the _inverse and _forward methods, within matvec.
-      self._upper_diagonal_weights_matrix = upper_diagonal_weights_matrix
-      self._lower_diagonal_weights_matrix = lower_diagonal_weights_matrix
-      self._activation_fn = activation_fn
-
-      super(HighwayFlow, self).__init__(
-          validate_args=validate_args,
-          forward_min_event_ndims=1,
-          parameters=parameters,
-          name=name)
-
-  @property
-  def bias(self):
-    return self._bias
-
-  @property
-  def width(self):
-    return self._width
-
-  @property
-  def residual_fraction(self):
-    return self._residual_fraction
-
-  @property
-  def upper_diagonal_weights_matrix(self):
-    return self._upper_diagonal_weights_matrix
-
-  @property
-  def lower_diagonal_weights_matrix(self):
-    return self._lower_diagonal_weights_matrix
-
-  @property
-  def activation_fn(self):
-    return self._activation_fn
-
-  def _derivative_of_sigmoid(self, x):
-    return self.residual_fraction + (
-        1. - self.residual_fraction) * tf.math.sigmoid(x)
-
-  def _convex_update(self, weights_matrix):
-    return self.residual_fraction * tf.eye(self.width) + (
-        1. - self.residual_fraction) * weights_matrix
-
-  def _inverse_of_sigmoid(self, y, N=20):
-    # Inverse of the activation layer with softplus using Newton iteration.
-    x = tf.ones(y.shape)
-    for _ in range(N):
-      x = x - (self.residual_fraction * x + (
-          1. - self.residual_fraction) * tf.math.softplus(x) - y) / (
-              self._derivative_of_sigmoid(x))
-    return x
-
-  def _augmented_forward(self, x):
-    # Log determinant term from the upper matrix. Note that the log determinant
-    # of the lower matrix is zero.
-    fldj = tf.zeros(x.shape[:-1]) + tf.reduce_sum(
-        tf.math.log(self.residual_fraction + (
-            1. - self.residual_fraction) * tf.linalg.diag_part(
-                self.upper_diagonal_weights_matrix)))
-    x = tf.linalg.matvec(
-        self._convex_update(self.lower_diagonal_weights_matrix), x)
-    x = tf.linalg.matvec(tf.transpose(
-        self._convex_update(self.upper_diagonal_weights_matrix)), x) + (
-            1 - self.residual_fraction) * self.bias
-    if self.activation_fn:
-      fldj += tf.reduce_sum(tf.math.log(self._derivative_of_sigmoid(x)), -1)
-      x = self.residual_fraction * x + (
-          1. - self.residual_fraction) * self.activation_fn(x)
-    return x, {'ildj': -fldj, 'fldj': fldj}
-
-  def _augmented_inverse(self, y):
-    ildj = tf.zeros(y.shape[:-1]) - tf.reduce_sum(
-        tf.math.log(self.residual_fraction + (
-            1. - self.residual_fraction) * tf.linalg.diag_part(
-                self.upper_diagonal_weights_matrix)))
-    if self.activation_fn:
-      y = self._inverse_of_sigmoid(y)
-      ildj -= tf.reduce_sum(tf.math.log(self._derivative_of_sigmoid(y)), -1)
-
-    y = tf.linalg.triangular_solve(tf.transpose(
-        self._convex_update(self.upper_diagonal_weights_matrix)),
-        tf.linalg.matrix_transpose(y - (
-            1 - self.residual_fraction) * self.bias),
-        lower=False)
-    y = tf.linalg.triangular_solve(
-        self._convex_update(self.lower_diagonal_weights_matrix), y)
-    return tf.linalg.matrix_transpose(y), {'ildj': ildj, 'fldj': -ildj}
-
-  def _forward(self, x):
-    y, _ = self._augmented_forward(x)
-    return y
-
-  def _inverse(self, y):
-    x, _ = self._augmented_inverse(y)
-    return x
-
-  def _forward_log_det_jacobian(self, x):
-    cached = self._cache.forward_attributes(x)
-    # If LDJ isn't in the cache, call forward once.
-    if 'fldj' not in cached:
-      _, attrs = self._augmented_forward(x)
-      cached.update(attrs)
-    return cached['fldj']
-
-  def _inverse_log_det_jacobian(self, y):
-    cached = self._cache.inverse_attributes(y)
-    # If LDJ isn't in the cache, call inverse once.
-    if 'ildj' not in cached:
-      _, attrs = self._augmented_inverse(y)
-      cached.update(attrs)
-    return cached['ildj']
+    return cached['ildj']
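The docstring of the removed duplicate describes the forward pass as three convex updates. As a sanity sketch of that composition, using softplus as the activation `g` and small hand-made `L`, `U`, and `b` (all values below are illustrative, not taken from the commit):

```python
import tensorflow as tf

l = 0.6                             # residual fraction, in (0, 1)
L = tf.constant([[1.0, 0.0, 0.0],
                 [0.5, 1.0, 0.0],
                 [0.2, 0.3, 1.0]])  # lower triangular, ones on the diagonal
U = tf.constant([[0.7, 0.1, 0.4],
                 [0.0, 0.9, 0.2],
                 [0.0, 0.0, 1.1]])  # upper triangular, positive diagonal
b = tf.constant([0.1, -0.2, 0.05])
x = tf.constant([1.0, -0.5, 2.0])

y1 = l * x + (1. - l) * tf.linalg.matvec(L, x)           # Y1 = l*X + (1-l)*L@X
y2 = l * y1 + (1. - l) * (tf.linalg.matvec(U, y1) + b)   # Y2 = l*Y1 + (1-l)*(U@Y1 + b)
y = l * y2 + (1. - l) * tf.math.softplus(y2)             # Y  = l*Y2 + (1-l)*g(Y2)
```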