@@ -530,40 +530,8 @@ def assign_sub(self, value, use_locking=False, name=None, read_value=True):
530
530
read_value = read_value )
531
531
532
532
533
- @type_spec .register ('tfp.util.TransformedVariableSpec' )
534
- class _TransformedVariableSpec (type_spec .BatchableTypeSpec ):
535
- """`tf.TypeSpec` for `tfp.util.TransformedVariable`."""
536
-
537
- __slots__ = ('_input_spec' , '_transform_or_spec' , '_dtype' , '_name' , '_specs' ,
538
- '_unique_id_params' , '_transform_is_composite' )
539
-
540
- def __init__ (self , input_spec , transform_or_spec , dtype , name ):
541
- """Initializes a new `_TransformedVariableSpec`.
542
-
543
- Args:
544
- input_spec: `tf.TypeSpec` instance describing the `TransformedVariable`s
545
- `pretransformed_input` attribute.
546
- transform_or_spec: The `bijector` passed to the `TransformedVariable`'s
547
- constructor, or `bijector._type_spec` if `bijector` is a
548
- `CompositeTensor`.
549
- dtype: `tf.DType`, `dtype` property of the `TransformedVariable`.
550
- name: `str`, name of the `TransformedVariable`.
551
- """
552
- self ._input_spec = input_spec
553
- self ._transform_or_spec = transform_or_spec
554
- self ._dtype = dtype
555
- self ._name = name
556
-
557
- self ._unique_id_params = {'dtype' : dtype }
558
- self ._transform_is_composite = isinstance (transform_or_spec , tf .TypeSpec )
559
-
560
- self ._specs = {'input_spec' : input_spec }
561
- if self ._transform_is_composite :
562
- self ._specs ['transform_or_spec' ] = transform_or_spec
563
-
564
- @property
565
- def value_type (self ):
566
- return TransformedVariable
533
+ class _DeferredTensorSpecBase (object ):
534
+ """Common methods for '_DeferredTensorSpec' and '_TransformedVariableSpec."""
567
535
568
536
@property
569
537
def name (self ):
@@ -637,58 +605,17 @@ def relax(value):
637
605
else :
638
606
return value
639
607
640
- transform_or_spec = self ._specs .pop (
608
+ specs = self ._specs .copy ()
609
+ transform_or_spec = specs .pop (
641
610
'transform_or_spec' , self .transform_or_spec )
642
611
return type (self )(
643
612
** tf .nest .map_structure (
644
613
relax ,
645
- dict (self . _specs ,
614
+ dict (specs ,
646
615
transform_or_spec = transform_or_spec ,
647
616
** self ._unique_id_params ,
648
617
name = self .name )))
649
618
650
- def _to_components (self , value ):
651
- """Encodes `value` as a nested structure of Tensor/CompositeTensor."""
652
- components = dict (pretransformed_input = value .pretransformed_input )
653
- if isinstance (value .bijector , tf .__internal__ .CompositeTensor ):
654
- components ['bijector' ] = value .bijector
655
- return components
656
-
657
- def _from_components (self , components ):
658
- """Reconstructs a value from a structure of Tensor/CompositeTensor."""
659
- bijector = components .pop ('bijector' , self .transform_or_spec )
660
- return TransformedVariable (
661
- ** components , initial_value = None , bijector = bijector ,
662
- dtype = self .dtype , name = self .name )
663
-
664
- @property
665
- def _component_specs (self ):
666
- """A nested structure of TypeSpecs for the DeferredTensor's components."""
667
- specs = dict (pretransformed_input = self ._input_spec )
668
- if self ._transform_is_composite :
669
- specs ['bijector' ] = self .transform_or_spec
670
- return specs
671
-
672
- def _batch (self , batch_size ):
673
- """Returns a TypeSpec representing a batch of DeferredTensors."""
674
- transform_or_spec = self ._specs .pop (
675
- 'transform_or_spec' , self .transform_or_spec )
676
- return type (self )(
677
- self ._get_batched_input_spec (batch_size ),
678
- transform_or_spec = transform_or_spec ,
679
- dtype = self .dtype ,
680
- name = self .name )
681
-
682
- def _unbatch (self ):
683
- """Returns a TypeSpec representing a single DeferredTensor."""
684
- transform_or_spec = self ._specs .pop (
685
- 'transform_or_spec' , self .transform_or_spec )
686
- return type (self )(
687
- self ._get_unbatched_input_spec (),
688
- transform_or_spec = transform_or_spec ,
689
- dtype = self .dtype ,
690
- name = self .name )
691
-
692
619
def _get_batched_input_spec (self , batch_size ):
693
620
"""Returns the batched `input_spec` for the given `batch_size`."""
694
621
if isinstance (self ._input_spec , type_spec .BatchableTypeSpec ):
@@ -759,3 +686,184 @@ def __ne__(self, other):
759
686
def __hash__ (self ):
760
687
return hash (self .__get_cmp_key ())
761
688
689
+
690
+ @type_spec .register ('tfp.util.DeferredTensorSpec' )
691
+ class _DeferredTensorSpec (_DeferredTensorSpecBase , type_spec .BatchableTypeSpec ):
692
+ """`tf.TypeSpec` for `tfp.util.DeferredTensor`."""
693
+
694
+ __slots__ = ('_input_spec' , '_transform_or_spec' , '_also_track_spec' ,
695
+ '_dtype' , '_shape' , '_name' , '_specs' , '_unique_id_params' ,
696
+ '_transform_is_composite' )
697
+
698
+ def __init__ (self , input_spec , transform_or_spec , dtype , shape , name ,
699
+ also_track_spec = None ):
700
+ """Initializes a new `_DeferredTensorSpec`.
701
+
702
+ Args:
703
+ input_spec: `tf.TypeSpec` instance describing the `DeferredTensor`s
704
+ `pretransformed_input` attribute.
705
+ transform_or_spec: The `transform_fn` passed to the `DeferredTensor`'s
706
+ constructor, or `transform_fn._type_spec` if `transform_fn` is a
707
+ `CompositeTensor`.
708
+ dtype: `tf.DType`, `dtype` property of the `DeferredTensor`.
709
+ shape: `tf.TensorShape`, `shape` property of the `DeferredTensor`.
710
+ name: `str`, name of the `DeferredTensor`.
711
+ also_track_spec: Python `list` of `VariableSpec` instances describing the
712
+ additional variables tracked by the `DeferredTensor`.
713
+ """
714
+ self ._input_spec = input_spec
715
+ self ._transform_or_spec = transform_or_spec
716
+ self ._also_track_spec = also_track_spec
717
+ self ._dtype = dtype
718
+ self ._shape = shape
719
+ self ._name = name
720
+
721
+ self ._transform_is_composite = isinstance (transform_or_spec , tf .TypeSpec )
722
+ self ._unique_id_params = {'dtype' : dtype , 'shape' : shape }
723
+
724
+ self ._specs = {'input_spec' : input_spec }
725
+ if self ._transform_is_composite :
726
+ self ._specs ['transform_or_spec' ] = transform_or_spec
727
+ if also_track_spec is not None :
728
+ self ._specs ['also_track_spec' ] = also_track_spec
729
+
730
+ @property
731
+ def value_type (self ):
732
+ return DeferredTensor
733
+
734
+ @property
735
+ def shape (self ):
736
+ return self ._shape
737
+
738
+ def _to_components (self , value ):
739
+ """Encodes `value` as a nested structure of Tensor/CompositeTensor."""
740
+ components = dict (pretransformed_input = value .pretransformed_input )
741
+ # pylint: disable=protected-access
742
+ if isinstance (value ._transform_fn , tf .__internal__ .CompositeTensor ):
743
+ components ['transform_fn' ] = value ._transform_fn
744
+ if value .also_track is not None :
745
+ components ['also_track' ] = tf .nest .flatten (
746
+ tf .nest .map_structure (
747
+ lambda x : x .variables if isinstance (x , tf .Module ) else x ,
748
+ value .also_track ))
749
+ return components
750
+
751
+ def _from_components (self , components ):
752
+ """Reconstructs a value from a structure of Tensor/CompositeTensor."""
753
+ transform_fn = components .pop ('transform_fn' , self .transform_or_spec )
754
+ return DeferredTensor (** components , transform_fn = transform_fn ,
755
+ dtype = self .dtype , shape = self .shape , name = self .name )
756
+
757
+ @property
758
+ def _component_specs (self ):
759
+ """A nested structure of TypeSpecs for the DeferredTensor's components."""
760
+ specs = dict (pretransformed_input = self ._input_spec )
761
+ if self ._transform_is_composite :
762
+ specs ['transform_fn' ] = self .transform_or_spec
763
+ if self ._also_track_spec is not None :
764
+ specs ['also_track' ] = self ._also_track_spec
765
+ return specs
766
+
767
+ def _batch (self , batch_size ):
768
+ """Returns a TypeSpec representing a batch of DeferredTensors."""
769
+ transform_or_spec = self ._specs .get (
770
+ 'transform_or_spec' , self .transform_or_spec )
771
+ return _DeferredTensorSpec (
772
+ self ._get_batched_input_spec (batch_size ),
773
+ transform_or_spec = transform_or_spec ,
774
+ dtype = self .dtype ,
775
+ shape = (None if self .shape is None
776
+ else tf .TensorShape ([batch_size ]).concatenate (self .shape )),
777
+ name = self .name ,
778
+ also_track_spec = self ._also_track_spec )
779
+
780
+ def _unbatch (self ):
781
+ """Returns a TypeSpec representing a single DeferredTensor."""
782
+ transform_or_spec = self ._specs .get (
783
+ 'transform_or_spec' , self .transform_or_spec )
784
+ return _DeferredTensorSpec (
785
+ self ._get_unbatched_input_spec (),
786
+ transform_or_spec = transform_or_spec ,
787
+ dtype = self .dtype ,
788
+ shape = (None if self .shape is None else self .shape [1 :]),
789
+ name = self .name ,
790
+ also_track_spec = self ._also_track_spec )
791
+
792
+
793
+ @type_spec .register ('tfp.util.TransformedVariableSpec' )
794
+ class _TransformedVariableSpec (
795
+ _DeferredTensorSpecBase , type_spec .BatchableTypeSpec ):
796
+ """`tf.TypeSpec` for `tfp.util.TransformedVariable`."""
797
+
798
+ __slots__ = ('_input_spec' , '_transform_or_spec' , '_dtype' , '_name' , '_specs' ,
799
+ '_unique_id_params' , '_transform_is_composite' )
800
+
801
+ def __init__ (self , input_spec , transform_or_spec , dtype , name ):
802
+ """Initializes a new `_TransformedVariableSpec`.
803
+
804
+ Args:
805
+ input_spec: `tf.TypeSpec` instance describing the `TransformedVariable`s
806
+ `pretransformed_input` attribute.
807
+ transform_or_spec: The `bijector` passed to the `TransformedVariable`'s
808
+ constructor, or `bijector._type_spec` if `bijector` is a
809
+ `CompositeTensor`.
810
+ dtype: `tf.DType`, `dtype` property of the `TransformedVariable`.
811
+ name: `str`, name of the `TransformedVariable`.
812
+ """
813
+ self ._input_spec = input_spec
814
+ self ._transform_or_spec = transform_or_spec
815
+ self ._dtype = dtype
816
+ self ._name = name
817
+
818
+ self ._unique_id_params = {'dtype' : dtype }
819
+ self ._transform_is_composite = isinstance (transform_or_spec , tf .TypeSpec )
820
+
821
+ self ._specs = {'input_spec' : input_spec }
822
+ if self ._transform_is_composite :
823
+ self ._specs ['transform_or_spec' ] = transform_or_spec
824
+
825
+ @property
826
+ def value_type (self ):
827
+ return TransformedVariable
828
+
829
+ def _to_components (self , value ):
830
+ """Encodes `value` as a nested structure of Tensor/CompositeTensor."""
831
+ components = dict (pretransformed_input = value .pretransformed_input )
832
+ if isinstance (value .bijector , tf .__internal__ .CompositeTensor ):
833
+ components ['bijector' ] = value .bijector
834
+ return components
835
+
836
+ def _from_components (self , components ):
837
+ """Reconstructs a value from a structure of Tensor/CompositeTensor."""
838
+ bijector = components .pop ('bijector' , self .transform_or_spec )
839
+ return TransformedVariable (
840
+ ** components , initial_value = None , bijector = bijector ,
841
+ dtype = self .dtype , name = self .name )
842
+
843
+ @property
844
+ def _component_specs (self ):
845
+ """A structure of TypeSpecs for the TransformedVariable's components."""
846
+ specs = dict (pretransformed_input = self ._input_spec )
847
+ if self ._transform_is_composite :
848
+ specs ['bijector' ] = self .transform_or_spec
849
+ return specs
850
+
851
+ def _batch (self , batch_size ):
852
+ """Returns a TypeSpec representing a batch of TransformedVariable."""
853
+ transform_or_spec = self ._specs .get (
854
+ 'transform_or_spec' , self .transform_or_spec )
855
+ return _TransformedVariableSpec (
856
+ self ._get_batched_input_spec (batch_size ),
857
+ transform_or_spec = transform_or_spec ,
858
+ dtype = self .dtype ,
859
+ name = self .name )
860
+
861
+ def _unbatch (self ):
862
+ """Returns a TypeSpec representing a single TransformedVariable."""
863
+ transform_or_spec = self ._specs .get (
864
+ 'transform_or_spec' , self .transform_or_spec )
865
+ return _TransformedVariableSpec (
866
+ self ._get_unbatched_input_spec (),
867
+ transform_or_spec = transform_or_spec ,
868
+ dtype = self .dtype ,
869
+ name = self .name )
0 commit comments