# Note: `tf`, `tfb` (TFP bijectors), and `util` (TFP util) used below are
# assumed to be imported earlier in the file, above the lines shown here.
from tensorflow_probability.python.internal import cache_util
from tensorflow_probability.python.internal import samplers

def build_highway_flow_layer(width, residual_fraction_initial_value=0.5,
                             activation_fn=False, seed=None):
  """Builds a `HighwayFlow` with variables that respect its constraints."""
  # TODO: add a check that residual_fraction_initial_value is between 0 and 1.
  residual_fraction_initial_value = tf.convert_to_tensor(
      residual_fraction_initial_value,
      dtype_hint=tf.float32,
      name='residual_fraction_initial_value')
  dtype = residual_fraction_initial_value.dtype

  bias_seed, upper_seed, lower_seed, diagonal_seed = samplers.split_seed(
      seed, n=4)
  return HighwayFlow(
      # Sigmoid-transformed variable keeps the residual fraction in (0, 1).
      residual_fraction=util.TransformedVariable(
          initial_value=residual_fraction_initial_value,
          bijector=tfb.Sigmoid(),
          dtype=dtype),
      activation_fn=activation_fn,
      bias=tf.Variable(
          samplers.normal((width,), mean=0., stddev=0.01, seed=bias_seed),
          dtype=dtype),
      # Stored as lower triangular with a positive (softplus-constrained)
      # diagonal; it is transposed to upper triangular inside the bijector.
      upper_diagonal_weights_matrix=util.TransformedVariable(
          initial_value=tf.experimental.numpy.tril(
              samplers.normal((width, width), mean=0., stddev=1.,
                              seed=upper_seed),
              k=-1) + tf.linalg.diag(
                  samplers.uniform((width,), minval=0., maxval=1.,
                                   seed=diagonal_seed)),
          bijector=tfb.FillScaleTriL(diag_bijector=tfb.Softplus(),
                                     diag_shift=None),
          dtype=dtype),
      # Constrained to be lower triangular with ones on the main diagonal.
      lower_diagonal_weights_matrix=util.TransformedVariable(
          initial_value=samplers.normal((width, width), mean=0., stddev=1.,
                                        seed=lower_seed),
          bijector=tfb.Chain(
              [tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
               tfb.Pad(paddings=[(1, 0), (0, 1)]),
               tfb.FillTriangular()]),
          dtype=dtype))

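# The sketch below is an editor's illustration, not part of the original
# change: it shows one way to exercise a built layer with a forward/inverse
# round trip plus the log-det-Jacobian identity `ildj == -fldj`. The helper
# name `_round_trip_check` and the batch size of 5 are hypothetical.
def _round_trip_check(width=3, seed=(0, 42)):
  bijector = build_highway_flow_layer(width, activation_fn=True, seed=seed)
  x = samplers.normal((5, width), mean=0., stddev=1., seed=(1, 2))
  y = bijector.forward(x)
  # With `activation_fn=True` the inverse uses Newton iterations, so the
  # reconstruction matches `x` only up to a small numerical error.
  x_reconstructed = bijector.inverse(y)
  fldj = bijector.forward_log_det_jacobian(x, event_ndims=1)
  ildj = bijector.inverse_log_det_jacobian(y, event_ndims=1)
  return x, x_reconstructed, fldj, ildj
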
class HighwayFlow(tfb.Bijector):
  """Implements a Highway Flow bijector [1], which interpolates the input
  `X` with the transformations at each step of the bijector.

  The Highway Flow can be used as a building block for a Cascading Flow [1]
  or as a generic normalizing flow.

  The transformation consists of a convex update between the input `X` and a
  linear transformation of `X` followed by an activation, of the form
  `g(A @ X + b)`, where `g(.)` is a differentiable non-decreasing activation
  function and `A` and `b` are trainable weights.

  The convex update is regulated by a trainable residual fraction `l`
  constrained between 0 and 1 and can be formalized as:
  `Y = l * X + (1 - l) * g(A @ X + b)`.

  To make this transformation invertible, the bijector is split into three
  convex updates:
  - `Y1 = l * X + (1 - l) * L @ X`, with `L` a lower triangular matrix with
    ones on the diagonal;
  - `Y2 = l * Y1 + (1 - l) * (U @ Y1 + b)`, with `U` an upper triangular
    matrix with positive diagonal;
  - `Y = l * Y2 + (1 - l) * g(Y2)`.

  The function `build_highway_flow_layer` helps initialize the bijector with
  variables that respect the various constraints.

  For more details on Highway Flow and Cascading Flows see [1].

  #### Usage example:
  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  dim = 4  # last input dimension

  bijector = build_highway_flow_layer(dim, activation_fn=True)
  x = tf.ones((3, dim))  # example input batch
  y = bijector.forward(x)  # forward mapping
  x = bijector.inverse(y)  # inverse mapping
  base = tfd.MultivariateNormalDiag(loc=tf.zeros(dim))  # Base distribution
  transformed_distribution = tfd.TransformedDistribution(base, bijector)
  ```

  #### References

  [1]: Ambrogioni, Luca, Gianluigi Silvestri, and Marcel van Gerven.
       "Automatic variational inference with cascading flows." arXiv preprint
       arXiv:2102.04801 (2021).
  """

  # Highway Flow simultaneously computes `forward` and `fldj`
  # (and `inverse`/`ildj`), so we override the bijector cache to update the
  # LDJ entries of attrs on forward/inverse calls (instead of updating them
  # only when the LDJ methods themselves are called). A later call to
  # `forward_log_det_jacobian`/`inverse_log_det_jacobian` can then reuse the
  # cached value without recomputing the transformation.

  _cache = cache_util.BijectorCacheWithGreedyAttrs(
      forward_name='_augmented_forward',
      inverse_name='_augmented_inverse')

  def __init__(self, residual_fraction, activation_fn, bias,
               upper_diagonal_weights_matrix,
               lower_diagonal_weights_matrix, validate_args=False,
               name='highway_flow'):
    """Initializes the HighwayFlow bijector.

    Args:
      residual_fraction: Scalar `Tensor` used for the convex update; must be
        between 0 and 1.
      activation_fn: Python `bool`; whether to use the softplus activation
        (`True`) or no activation (`False`).
      bias: Bias vector of shape `(width,)`.
      upper_diagonal_weights_matrix: Lower triangular matrix of size
        `(width, width)` with positive diagonal (it is transposed to upper
        triangular within the bijector).
      lower_diagonal_weights_matrix: Lower triangular matrix of size
        `(width, width)` with ones on the main diagonal.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name for ops created by this object.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      self._width = tf.shape(bias)[-1]
      self._bias = bias
      self._residual_fraction = residual_fraction
      # The "upper" weights matrix is stored as lower triangular; it is
      # transposed inside `_augmented_forward` and `_augmented_inverse`
      # (within `matvec`/`triangular_solve`).
      self._upper_diagonal_weights_matrix = upper_diagonal_weights_matrix
      self._lower_diagonal_weights_matrix = lower_diagonal_weights_matrix
      self._activation_fn = activation_fn

      super(HighwayFlow, self).__init__(
          validate_args=validate_args,
          forward_min_event_ndims=1,
          parameters=parameters,
          name=name)

  @property
  def bias(self):
    return self._bias

  @property
  def width(self):
    return self._width

  @property
  def residual_fraction(self):
    return self._residual_fraction

  @property
  def upper_diagonal_weights_matrix(self):
    return self._upper_diagonal_weights_matrix

  @property
  def lower_diagonal_weights_matrix(self):
    return self._lower_diagonal_weights_matrix

  @property
  def activation_fn(self):
    return self._activation_fn

  def _derivative_of_sigmoid(self, x):
    # Derivative of the convex update with softplus activation,
    # i.e. d/dx [l * x + (1 - l) * softplus(x)] = l + (1 - l) * sigmoid(x).
    return self.residual_fraction + (
        1. - self.residual_fraction) * tf.math.sigmoid(x)

  def _convex_update(self, weights_matrix):
    # Convex combination of the identity and the weights matrix:
    # l * I + (1 - l) * W.
    return self.residual_fraction * tf.eye(self.width) + (
        1. - self.residual_fraction) * weights_matrix

  def _inverse_of_sigmoid(self, y, n=20):
    # Inverts the activation layer (convex update with softplus) with `n`
    # steps of Newton's method.
    x = tf.ones_like(y)
    for _ in range(n):
      x = x - (self.residual_fraction * x + (
          1. - self.residual_fraction) * tf.math.softplus(x) - y) / (
              self._derivative_of_sigmoid(x))
    return x

  def _augmented_forward(self, x):
    # Log-determinant term from the upper matrix. Note that the log
    # determinant of the lower matrix is zero.
    fldj = tf.zeros(x.shape[:-1]) + tf.reduce_sum(
        tf.math.log(self.residual_fraction + (
            1. - self.residual_fraction) * tf.linalg.diag_part(
                self.upper_diagonal_weights_matrix)))
    x = tf.linalg.matvec(
        self._convex_update(self.lower_diagonal_weights_matrix), x)
    x = tf.linalg.matvec(
        tf.transpose(self._convex_update(self.upper_diagonal_weights_matrix)),
        x) + (1. - self.residual_fraction) * self.bias
    if self.activation_fn:
      fldj += tf.reduce_sum(
          tf.math.log(self._derivative_of_sigmoid(x)), -1)
      # `activation_fn=True` selects the softplus activation (the inverse in
      # `_inverse_of_sigmoid` assumes softplus as well).
      x = self.residual_fraction * x + (
          1. - self.residual_fraction) * tf.math.softplus(x)
    return x, {'ildj': -fldj, 'fldj': fldj}

  def _augmented_inverse(self, y):
    ildj = tf.zeros(y.shape[:-1]) - tf.reduce_sum(
        tf.math.log(self.residual_fraction + (
            1. - self.residual_fraction) * tf.linalg.diag_part(
                self.upper_diagonal_weights_matrix)))
    if self.activation_fn:
      y = self._inverse_of_sigmoid(y)
      ildj -= tf.reduce_sum(
          tf.math.log(self._derivative_of_sigmoid(y)), -1)

    # Invert the two linear convex updates with triangular solves; the
    # "upper" weights matrix is stored as lower triangular, hence the
    # transpose with `lower=False`.
    y = tf.linalg.triangular_solve(
        tf.transpose(self._convex_update(self.upper_diagonal_weights_matrix)),
        tf.linalg.matrix_transpose(
            y - (1. - self.residual_fraction) * self.bias),
        lower=False)
    y = tf.linalg.triangular_solve(
        self._convex_update(self.lower_diagonal_weights_matrix), y)
    return tf.linalg.matrix_transpose(y), {'ildj': ildj, 'fldj': -ildj}

  def _forward(self, x):
    y, _ = self._augmented_forward(x)
    return y

  def _inverse(self, y):
    x, _ = self._augmented_inverse(y)
    return x

  def _forward_log_det_jacobian(self, x):
    cached = self._cache.forward_attributes(x)
    # If the LDJ isn't in the cache, call forward once.
    if 'fldj' not in cached:
      _, attrs = self._augmented_forward(x)
      cached.update(attrs)
    return cached['fldj']

  def _inverse_log_det_jacobian(self, y):
    cached = self._cache.inverse_attributes(y)
    # If the LDJ isn't in the cache, call inverse once.
    if 'ildj' not in cached:
      _, attrs = self._augmented_inverse(y)
      cached.update(attrs)
    return cached['ildj']
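

# Illustrative sketch (an editor's addition, not part of the original change):
# the class docstring notes that Highway Flow layers can serve as building
# blocks of a generic normalizing flow. One way to do that is to chain several
# independently initialized layers; the helper name `make_deep_highway_flow`
# and the `num_layers` parameter are hypothetical.
def make_deep_highway_flow(width, num_layers=3, seed=None):
  layer_seeds = samplers.split_seed(seed, n=num_layers)
  return tfb.Chain([
      build_highway_flow_layer(width, activation_fn=True, seed=layer_seeds[i])
      for i in range(num_layers)
  ])
# Such a chain can then be used like the single layer in the usage example
# above, e.g. wrapped in `tfd.TransformedDistribution` with a
# `tfd.MultivariateNormalDiag(loc=tf.zeros(width))` base distribution.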