18
18
from __future__ import division
19
19
from __future__ import print_function
20
20
21
- import copy
22
21
import six
23
22
24
23
import tensorflow .compat .v2 as tf
25
24
26
25
from tensorflow_probability .python .distributions import distribution as tfd
26
+ from tensorflow_probability .python .distributions import kullback_leibler
27
27
from tensorflow_probability .python .internal import nest_util
28
28
from tensorflow_probability .python .internal import parameter_properties
29
29
from tensorflow_probability .python .util .deferred_tensor import TensorMetaClass
30
30
from tensorflow .python .framework import composite_tensor # pylint: disable=g-direct-tensorflow-import
31
+ from tensorflow .python .training .tracking import data_structures # pylint: disable=g-direct-tensorflow-import
31
32
32
33
33
34
__all__ = [] # We intend nothing public.
34
35
36
+ _NOT_FOUND = object ()
37
+
35
38
36
39
# Define mixin type because Distribution already has its own metaclass.
37
40
class _DistributionAndTensorCoercibleMeta (type (tfd .Distribution ),
@@ -43,43 +46,123 @@ class _DistributionAndTensorCoercibleMeta(type(tfd.Distribution),
43
46
class _TensorCoercible (tfd .Distribution ):
44
47
"""Docstring."""
45
48
46
- registered_class_list = {}
47
-
48
- def __new__ (cls , distribution , convert_to_tensor_fn = tfd .Distribution .sample ):
49
- if isinstance (distribution , cls ):
50
- return distribution
51
- if not isinstance (distribution , tfd .Distribution ):
52
- raise TypeError ('`distribution` argument must be a '
53
- '`tfd.Distribution` instance; '
54
- 'saw "{}" of type "{}".' .format (
55
- distribution , type (distribution )))
56
- self = copy .copy (distribution )
57
- distcls = distribution .__class__
58
- self_class = _TensorCoercible .registered_class_list .get (distcls )
59
- if not self_class :
60
- self_class = type (distcls .__name__ , (cls , distcls ), {})
61
- _TensorCoercible .registered_class_list [distcls ] = self_class
62
- self .__class__ = self_class
63
- return self
64
-
65
49
def __init__ (self ,
66
50
distribution ,
67
51
convert_to_tensor_fn = tfd .Distribution .sample ):
68
52
self ._concrete_value = None # pylint: disable=protected-access
69
53
self ._convert_to_tensor_fn = convert_to_tensor_fn # pylint: disable=protected-access
54
+ self .tensor_distribution = distribution
55
+ super (_TensorCoercible , self ).__init__ (
56
+ dtype = distribution .dtype ,
57
+ reparameterization_type = distribution .reparameterization_type ,
58
+ validate_args = distribution .validate_args ,
59
+ allow_nan_stats = distribution .allow_nan_stats ,
60
+ parameters = distribution .parameters )
61
+
62
+ def __setattr__ (self , name , value ):
63
+ """Support self.foo = trackable syntax.
64
+
65
+ Redefined from `tensorflow/python/training/tracking/tracking.py` to avoid
66
+ calling `getattr`, which causes an infinite loop.
67
+
68
+ Args:
69
+ name: str, name of the attribute to be set.
70
+ value: value to be set.
71
+ """
72
+ if vars (self ).get (name , _NOT_FOUND ) is value :
73
+ return
74
+
75
+ if vars (self ).get ('_self_setattr_tracking' , True ):
76
+ value = data_structures .sticky_attribute_assignment (
77
+ trackable = self , value = value , name = name )
78
+ object .__setattr__ (self , name , value )
79
+
80
+ def __getattr__ (self , name ):
81
+ # If the attribute is set in the _TensorCoercible object, return it. This
82
+ # ensures that direct calls to `getattr` behave as expected.
83
+ if name in vars (self ):
84
+ return vars (self )[name ]
85
+ # Look for the attribute in `tensor_distribution`, unless it's a `_tracking`
86
+ # attribute accessed directly by `getattr` in the `Trackable` base class, in
87
+ # which case the default passed to `getattr` should be returned.
88
+ if 'tensor_distribution' in vars (self ) and '_tracking' not in name :
89
+ return getattr (vars (self )['tensor_distribution' ], name )
90
+ # Otherwise invoke `__getattribute__`, which will return the default passed
91
+ # to `getattr` if the attribute was not found.
92
+ return self .__getattribute__ (name )
70
93
71
94
@classmethod
72
95
def _parameter_properties (cls , dtype , num_classes = None ):
73
96
return dict (distribution = parameter_properties .BatchedComponentProperties ())
74
97
98
+ # pylint: disable=protected-access
75
99
def _batch_shape_tensor (self , ** parameter_kwargs ):
76
- # Any parameter kwargs are for the inner distribution, so pass them
77
- # to its `_batch_shape_tensor` method instead of handling them directly.
78
- return self .parameters ['distribution' ]._batch_shape_tensor ( # pylint: disable=protected-access
79
- ** parameter_kwargs )
100
+ return self .tensor_distribution ._batch_shape_tensor (** parameter_kwargs )
101
+
102
+ def _batch_shape (self ):
103
+ return self .tensor_distribution ._batch_shape ()
104
+
105
+ def _event_shape_tensor (self ):
106
+ return self .tensor_distribution ._event_shape_tensor ()
107
+
108
+ def _event_shape (self ):
109
+ return self .tensor_distribution ._event_shape ()
110
+
111
+ def sample (self , sample_shape = (), seed = None , name = 'sample' , ** kwargs ):
112
+ return self .tensor_distribution .sample (
113
+ sample_shape = sample_shape , seed = seed , name = name , ** kwargs )
114
+
115
+ def _log_prob (self , value , ** kwargs ):
116
+ return self .tensor_distribution ._log_prob (value , ** kwargs )
117
+
118
+ def _prob (self , value , ** kwargs ):
119
+ return self .tensor_distribution ._prob (value , ** kwargs )
120
+
121
+ def _log_cdf (self , value , ** kwargs ):
122
+ return self .tensor_distribution ._log_cdf (value , ** kwargs )
123
+
124
+ def _cdf (self , value , ** kwargs ):
125
+ return self .tensor_distribution ._cdf (value , ** kwargs )
126
+
127
+ def _log_survival_function (self , value , ** kwargs ):
128
+ return self .tensor_distribution ._log_survival_function (value , ** kwargs )
129
+
130
+ def _survival_function (self , value , ** kwargs ):
131
+ return self .tensor_distribution ._survival_function (value , ** kwargs )
132
+
133
+ def _entropy (self , ** kwargs ):
134
+ return self .tensor_distribution ._entropy (** kwargs )
135
+
136
+ def _mean (self , ** kwargs ):
137
+ return self .tensor_distribution ._mean (** kwargs )
138
+
139
+ def _quantile (self , value , ** kwargs ):
140
+ return self .tensor_distribution ._quantile (value , ** kwargs )
141
+
142
+ def _variance (self , ** kwargs ):
143
+ return self .tensor_distribution ._variance (** kwargs )
144
+
145
+ def _stddev (self , ** kwargs ):
146
+ return self .tensor_distribution ._stddev (** kwargs )
147
+
148
+ def _covariance (self , ** kwargs ):
149
+ return self .tensor_distribution ._covariance (** kwargs )
150
+
151
+ def _mode (self , ** kwargs ):
152
+ return self .tensor_distribution ._mode (** kwargs )
153
+
154
+ def _default_event_space_bijector (self , * args , ** kwargs ):
155
+ return self .tensor_distribution ._default_event_space_bijector (
156
+ * args , ** kwargs )
157
+
158
+ def _parameter_control_dependencies (self , is_init ):
159
+ return self .tensor_distribution ._parameter_control_dependencies (is_init )
80
160
81
161
@property
82
162
def shape (self ):
163
+ return self ._shape
164
+
165
+ def _shape (self ):
83
166
return (tf .TensorShape (None ) if self ._concrete_value is None
84
167
else self ._concrete_value .shape )
85
168
@@ -130,15 +213,26 @@ def _value(self, dtype=None, name=None, as_ref=False):
130
213
' results in `tf.convert_to_tensor(x)` being identical to '
131
214
'`x.mean()`.' .format (type (self ), self ))
132
215
with self ._name_and_control_scope ('value' ):
133
- self ._concrete_value = (self ._convert_to_tensor_fn (self )
134
- if callable (self ._convert_to_tensor_fn )
135
- else self ._convert_to_tensor_fn )
216
+ self ._concrete_value = (
217
+ self ._convert_to_tensor_fn (self .tensor_distribution )
218
+ if callable (self ._convert_to_tensor_fn )
219
+ else self ._convert_to_tensor_fn )
136
220
if (not tf .is_tensor (self ._concrete_value ) and
137
221
not isinstance (self ._concrete_value ,
138
222
composite_tensor .CompositeTensor )):
139
223
self ._concrete_value = nest_util .convert_to_nested_tensor ( # pylint: disable=protected-access
140
224
self ._concrete_value ,
141
225
name = name or 'concrete_value' ,
142
226
dtype = dtype ,
143
- dtype_hint = self .dtype )
227
+ dtype_hint = self .tensor_distribution . dtype )
144
228
return self ._concrete_value
229
+
230
+
231
+ @kullback_leibler .RegisterKL (_TensorCoercible , tfd .Distribution )
232
+ def _kl_tensor_coercible_distribution (a , b , name = None ):
233
+ return kullback_leibler .kl_divergence (a .tensor_distribution , b , name = name )
234
+
235
+
236
+ @kullback_leibler .RegisterKL (tfd .Distribution , _TensorCoercible )
237
+ def _kl_distribution_tensor_coercible (a , b , name = None ):
238
+ return kullback_leibler .kl_divergence (a , b .tensor_distribution , name = name )
0 commit comments