from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.bijectors import blockwise
from tensorflow_probability.python.bijectors import chain
+from tensorflow_probability.python.bijectors import composition
from tensorflow_probability.python.bijectors import exp
from tensorflow_probability.python.bijectors import identity
from tensorflow_probability.python.bijectors import invert
]


-class Glow(chain.Chain):
+class Glow(composition.Composition):
  r"""Implements the Glow Bijector from Kingma & Dhariwal (2018)[1].

  Overview: `Glow` is a chain of bijectors which transforms a rank-1 tensor
@@ -319,6 +320,7 @@ def __init__(self,
      name: Python `str`, name given to ops managed by this object.
        Default value: `'glow'`.
    """
+    parameters = dict(locals())
    # Make sure that the input shape is fully defined.
    if not tensorshape_util.is_fully_defined(output_shape):
      raise ValueError('Shape must be fully defined.')
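The `parameters = dict(locals())` line added at the top of each constructor in this change snapshots the constructor arguments before any other local variables exist, so the base class can record how the bijector was built. A minimal sketch of the idiom outside TFP (the `Example` class is hypothetical):

class Example(object):

  def __init__(self, a, b=2):
    parameters = dict(locals())  # {'self': <Example>, 'a': 1, 'b': 2}
    self._parameters = {k: v for k, v in parameters.items() if k != 'self'}


print(Example(1)._parameters)  # {'a': 1, 'b': 2}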
@@ -441,10 +443,13 @@ def __init__(self,
        ]))

    glow_chain = glow_chain[::-1]
-    # To finish off, we initialize the bijector with the chain we've built
-    # This way, the rest of the model attributes are taken care of for us.
+    # To finish off, we build a bijector that chains the components together
+    # sequentially.
    super(Glow, self).__init__(
-        bijectors=glow_chain, validate_args=validate_args, name=name)
+        bijectors=chain.Chain(glow_chain, validate_args=validate_args),
+        validate_args=validate_args,
+        parameters=parameters,
+        name=name)

  @classmethod
  def _parameter_properties(cls, dtype):
@@ -491,7 +496,7 @@ def blockwise_splits(self):
    return self._blockwise_splits


-class ExitBijector(blockwise.Blockwise):
+class ExitBijector(composition.Composition):
  """The spatial coupling bijector used in Glow.

  This bijector consists of a blockwise bijector of a realNVP bijector. It is
@@ -524,7 +529,7 @@ def __init__(self,
        tensor with event shape `input_shape`, and returns a tensor with shape
        `input_shape`.
    """
-
+    parameters = dict(locals())
    nleave, ngrab, npass = blockwise_splits

    new_input_shape = input_shape[:-1] + (nleave,)
@@ -551,7 +556,10 @@ def __init__(self,
        bijector_fn=exit_bijector_fn)

    super(ExitBijector, self).__init__(
-        [shift_distribution, identity.Identity()], [nleave + ngrab, npass])
+        blockwise.Blockwise(
+            [shift_distribution, identity.Identity()], [nleave + ngrab, npass]),
+        parameters=parameters,
+        name='exit_bijector')

  @staticmethod
  def make_bijector_fn(layer, target_shape, scale_fn=tf.nn.sigmoid):
@@ -601,7 +609,7 @@ def bijector_fn(inputs, ignored_input):
    return bijector_fn


-class GlowBlock(chain.Chain):
+class GlowBlock(composition.Composition):
  """Single block for a glow model.

  This bijector contains `num_steps` steps of the flow, each consisting of an
@@ -613,7 +621,7 @@ class GlowBlock(chain.Chain):

  def __init__(self, input_shape, num_steps, coupling_bijector_fn,
               use_actnorm, seedstream):
-
+    parameters = dict(locals())
    rnvp_block = [identity.Identity()]
    this_nchan = input_shape[-1]

@@ -646,7 +654,8 @@ def __init__(self, input_shape, num_steps, coupling_bijector_fn,

    # Note that we reverse the list since Chain applies bijectors in reverse
    # order.
-    super(GlowBlock, self).__init__(rnvp_block[::-1])
+    super(GlowBlock, self).__init__(
+        chain.Chain(rnvp_block[::-1]), parameters=parameters, name='glow_block')

  @staticmethod
  def make_bijector_fn(layer, scale_fn=tf.nn.sigmoid):
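As the comment in the hunk above says, `Chain` applies its list of bijectors in reverse (right-to-left) order, which is why the step list is reversed before wrapping it. A small sketch using the public API (the Shift/Scale pair is only for illustration):

import tensorflow_probability as tfp

tfb = tfp.bijectors

# Chain([f, g]).forward(x) computes f(g(x)): the last bijector in the list runs first.
b = tfb.Chain([tfb.Shift(1.), tfb.Scale(2.)])
print(b.forward(3.))  # 2 * 3 + 1 = 7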
@@ -809,7 +818,7 @@ def _init():
    return tf.cond(self._initialized, tf.no_op, _init)


-class Expand(chain.Chain):
+class Expand(composition.Composition):
  """A bijector to transform channels into spatial pixels."""

  def __init__(self, input_shape, block_size=2, validate_args=False, name=None):
@@ -827,8 +836,10 @@ def __init__(self, input_shape, block_size=2, validate_args=False, name=None):
            event_shape_in=[h, w, c],
            event_shape_out=[h, w, c // n**2, n, n]),
    ]
-    super(Expand, self).__init__(b, name=name or 'Expand',
-                                 parameters=parameters)
+    super(Expand, self).__init__(
+        bijectors=chain.Chain(b, validate_args=validate_args),
+        name=name or 'Expand',
+        parameters=parameters)


class GlowDefaultNetwork(tfk.Sequential):
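For context, a typical way to construct and use the `Glow` bijector, roughly following its docstring example; the 32x32x4 output shape and the default coupling/exit networks below are illustrative assumptions, not part of this change:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

# Height and width should be divisible by 2**num_glow_blocks (3 by default).
image_shape = (32, 32, 4)
glow = tfb.Glow(output_shape=image_shape,
                coupling_bijector_fn=tfb.GlowDefaultNetwork,
                exit_bijector_fn=tfb.GlowDefaultExitNetwork)

# Glow transforms a rank-1 base sample into a tensor with shape `image_shape`.
pz = tfd.Sample(tfd.Normal(0., 1.), tf.reduce_prod(image_shape))
px = glow(pz)  # A TransformedDistribution over image-shaped tensors.
images = px.sample(4)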