Skip to content

Commit 0fc11b3

Browse files
emilyfertig authored and tensorflower-gardener committed
BREAKING CHANGE: Make _TensorCoercible meta-distribution-like, such that instances no longer share the __class__ of their instantiating distribution, in preparation to convert most TFP distributions to CompositeTensor.
This means that distributions output from `tfp.layers` are instances of `_TensorCoercible` and not instances of the TFP library distribution with which they were constructed. For example:

```
d = 4
p = tfpl.MultivariateNormalTriL.params_size(d)
layer = tfpl.MultivariateNormalTriL(d, tfd.Distribution.mean)
t = tfd.Normal(0, 1).sample([2, 3, p], seed=42)
x = layer(t)

# Newly fails; x is a `_TensorCoercible` instance.
assert isinstance(x, tfd.MultivariateNormalTriL)

# Still works: attributes of the inner `MultivariateNormalTriL` are accessible.
x.loc
```

PiperOrigin-RevId: 378224934
1 parent 4132455 commit 0fc11b3

File tree

6 files changed

+186
-58
lines changed

6 files changed

+186
-58
lines changed

tensorflow_probability/python/layers/distribution_layer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,10 @@ def _fn(*fargs, **fkwargs):
193193
value.shape = value[-1].shape
194194
value.get_shape = value[-1].get_shape
195195
value.dtype = value[-1].dtype
196-
distribution.shape = value[-1].shape
196+
distribution._shape = value[-1].shape # pylint: disable=protected-access
197197
distribution.get_shape = value[-1].get_shape
198198
else:
199-
distribution.shape = value.shape
199+
distribution._shape = value.shape # pylint: disable=protected-access
200200
distribution.get_shape = value.get_shape
201201
return distribution, value
202202

tensorflow_probability/python/layers/distribution_layer_test.py

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ def _vec_pad(x, value=0):
5050
return tf.pad(x, paddings=paddings, constant_values=value)
5151

5252

53+
def _unwrap_tensor_coercible(dist):
54+
inner_dist = getattr(dist, 'tensor_distribution', dist)
55+
if inner_dist is dist:
56+
return inner_dist
57+
return _unwrap_tensor_coercible(inner_dist)
58+
59+
5360
# TODO(b/143642032): Figure out how to solve issues with save/load, so that we
5461
# can decorate all of these tests with @test_util.test_all_tf_execution_regimes
5562
@test_util.test_graph_and_eager_modes
@@ -142,8 +149,8 @@ def accuracy(x, rv_x):
142149
validation_data=(self.x_test, self.x_test),
143150
shuffle=True)
144151
yhat = vae_model(tf.convert_to_tensor(self.x_test))
145-
self.assertIsInstance(yhat, tfd.Independent)
146-
self.assertIsInstance(yhat.distribution, tfd.Bernoulli)
152+
self.assertIsInstance(yhat.tensor_distribution, tfd.Independent)
153+
self.assertIsInstance(yhat.tensor_distribution.distribution, tfd.Bernoulli)
147154

148155
def test_keras_functional_api(self):
149156
"""Test `DistributionLambda`s are composable via Keras functional API."""
@@ -193,8 +200,8 @@ def test_keras_functional_api(self):
193200
validation_data=(self.x_test, self.x_test),
194201
shuffle=True)
195202
yhat = vae_model(tf.convert_to_tensor(self.x_test))
196-
self.assertIsInstance(yhat, tfd.Independent)
197-
self.assertIsInstance(yhat.distribution, tfd.Bernoulli)
203+
self.assertIsInstance(yhat.tensor_distribution, tfd.Independent)
204+
self.assertIsInstance(yhat.tensor_distribution.distribution, tfd.Bernoulli)
198205

199206
def test_keras_model_api(self):
200207
"""Test `DistributionLambda`s are composable via Keras `Model` API."""
@@ -249,8 +256,8 @@ def call(self, inputs):
249256
epochs=1,
250257
validation_data=(self.x_test, self.x_test))
251258
yhat = vae_model(tf.convert_to_tensor(self.x_test))
252-
self.assertIsInstance(yhat, tfd.Independent)
253-
self.assertIsInstance(yhat.distribution, tfd.Bernoulli)
259+
self.assertIsInstance(yhat.tensor_distribution, tfd.Independent)
260+
self.assertIsInstance(yhat.tensor_distribution.distribution, tfd.Bernoulli)
254261

255262
def test_keras_sequential_api_multiple_draws(self):
256263
num_draws = 2
@@ -293,8 +300,8 @@ def test_keras_sequential_api_multiple_draws(self):
293300
steps_per_epoch=1, # Usually `n // batch_size`.
294301
validation_data=(self.x_test, self.x_test))
295302
yhat = vae_model(tf.convert_to_tensor(self.x_test))
296-
self.assertIsInstance(yhat, tfd.Independent)
297-
self.assertIsInstance(yhat.distribution, tfd.Bernoulli)
303+
self.assertIsInstance(yhat.tensor_distribution, tfd.Independent)
304+
self.assertIsInstance(yhat.tensor_distribution.distribution, tfd.Bernoulli)
298305

299306
def test_side_variable_is_auto_tracked(self):
300307
# `s` is the "side variable".
@@ -587,7 +594,7 @@ def test_layer(self):
587594
layer = tfpl.MultivariateNormalTriL(d, tfd.Distribution.mean)
588595
t = tfd.Normal(0, 1).sample([2, 3, p], seed=42)
589596
x = layer(t)
590-
self._check_distribution(t, x)
597+
self._check_distribution(t, x.tensor_distribution)
591598

592599
def test_doc_string(self):
593600
# Load data.
@@ -654,7 +661,7 @@ def test_layer(self):
654661
layer = tfpl.OneHotCategorical(d, validate_args=True)
655662
t = tfd.Normal(0, 1).sample([2, 3, p], seed=42)
656663
x = layer(t)
657-
self._check_distribution(t, x)
664+
self._check_distribution(t, x.tensor_distribution)
658665

659666
def test_doc_string(self):
660667
# Load data.
@@ -692,9 +699,11 @@ def test_doc_string(self):
692699
class CategoricalMixtureOfOneHotCategoricalTest(test_util.TestCase):
693700

694701
def _check_distribution(self, t, x):
695-
self.assertIsInstance(x, tfd.MixtureSameFamily)
696-
self.assertIsInstance(x.mixture_distribution, tfd.Categorical)
697-
self.assertIsInstance(x.components_distribution, tfd.OneHotCategorical)
702+
self.assertIsInstance(_unwrap_tensor_coercible(x), tfd.MixtureSameFamily)
703+
self.assertIsInstance(_unwrap_tensor_coercible(x.mixture_distribution),
704+
tfd.Categorical)
705+
self.assertIsInstance(_unwrap_tensor_coercible(x.components_distribution),
706+
tfd.OneHotCategorical)
698707
t_back = tf.concat([
699708
x.mixture_distribution.logits,
700709
tf.reshape(x.components_distribution.logits, shape=[2, 3, -1]),
@@ -768,9 +777,12 @@ def test_doc_string(self):
768777
shuffle=True)
769778

770779
yhat = model(x)
771-
self.assertIsInstance(yhat, tfd.MixtureSameFamily)
772-
self.assertIsInstance(yhat.mixture_distribution, tfd.Categorical)
773-
self.assertIsInstance(yhat.components_distribution, tfd.OneHotCategorical)
780+
self.assertIsInstance(_unwrap_tensor_coercible(yhat), tfd.MixtureSameFamily)
781+
self.assertIsInstance(
782+
_unwrap_tensor_coercible(yhat.mixture_distribution), tfd.Categorical)
783+
self.assertIsInstance(
784+
_unwrap_tensor_coercible(yhat.components_distribution),
785+
tfd.OneHotCategorical)
774786
# TODO(b/120221303): For now we just check that the code executes and we get
775787
# back a distribution instance. Better would be to change the data
776788
# generation so the model becomes well-specified (and we can check correctly
@@ -834,7 +846,7 @@ def test_layer(self):
834846

835847
layer = self.layer_class(validate_args=True, dtype=self.dtype)
836848
x = layer(t)
837-
self._check_distribution(t, x, batch_shape)
849+
self._check_distribution(t, x.tensor_distribution, batch_shape)
838850

839851
def test_serialization(self):
840852
event_shape = []
@@ -1163,11 +1175,14 @@ def _build_tensor(self, ndarray, dtype=None):
11631175
ndarray, shape=ndarray.shape if self.use_static_shape else None)
11641176

11651177
def _check_distribution(self, t, x, batch_shape):
1166-
self.assertIsInstance(x, tfd.MixtureSameFamily)
1167-
self.assertIsInstance(x.mixture_distribution, tfd.Categorical)
1168-
self.assertIsInstance(x.components_distribution, tfd.Independent)
1169-
self.assertIsInstance(x.components_distribution.distribution,
1170-
self.dist_class)
1178+
self.assertIsInstance(_unwrap_tensor_coercible(x), tfd.MixtureSameFamily)
1179+
self.assertIsInstance(
1180+
_unwrap_tensor_coercible(x.mixture_distribution), tfd.Categorical)
1181+
self.assertIsInstance(
1182+
_unwrap_tensor_coercible(x.components_distribution), tfd.Independent)
1183+
self.assertIsInstance(
1184+
_unwrap_tensor_coercible(x.components_distribution.distribution),
1185+
self.dist_class)
11711186
self.assertEqual(self.dtype, x.dtype)
11721187

11731188
t_back = self._distribution_to_params(x, batch_shape)
@@ -1413,9 +1428,12 @@ def _build_tensor(self, ndarray, dtype=None):
14131428
ndarray, shape=ndarray.shape if self.use_static_shape else None)
14141429

14151430
def _check_distribution(self, t, x, batch_shape):
1416-
self.assertIsInstance(x, tfd.MixtureSameFamily)
1417-
self.assertIsInstance(x.mixture_distribution, tfd.Categorical)
1418-
self.assertIsInstance(x.components_distribution, tfd.MultivariateNormalTriL)
1431+
self.assertIsInstance(_unwrap_tensor_coercible(x), tfd.MixtureSameFamily)
1432+
self.assertIsInstance(
1433+
_unwrap_tensor_coercible(x.mixture_distribution), tfd.Categorical)
1434+
self.assertIsInstance(
1435+
_unwrap_tensor_coercible(x.components_distribution),
1436+
tfd.MultivariateNormalTriL)
14191437

14201438
shape = tf.concat([batch_shape, [-1]], axis=0)
14211439
batch_and_n_shape = tf.concat(

tensorflow_probability/python/layers/internal/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ py_library(
4040
deps = [
4141
# tensorflow dep,
4242
"//tensorflow_probability/python/distributions:distribution",
43+
"//tensorflow_probability/python/distributions:kullback_leibler",
4344
"//tensorflow_probability/python/internal:nest_util",
4445
"//tensorflow_probability/python/internal:parameter_properties",
4546
"//tensorflow_probability/python/util",

tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py

Lines changed: 122 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,23 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
import copy
2221
import six
2322

2423
import tensorflow.compat.v2 as tf
2524

2625
from tensorflow_probability.python.distributions import distribution as tfd
26+
from tensorflow_probability.python.distributions import kullback_leibler
2727
from tensorflow_probability.python.internal import nest_util
2828
from tensorflow_probability.python.internal import parameter_properties
2929
from tensorflow_probability.python.util.deferred_tensor import TensorMetaClass
3030
from tensorflow.python.framework import composite_tensor # pylint: disable=g-direct-tensorflow-import
31+
from tensorflow.python.training.tracking import data_structures # pylint: disable=g-direct-tensorflow-import
3132

3233

3334
__all__ = [] # We intend nothing public.
3435

36+
_NOT_FOUND = object()
37+
3538

3639
# Define mixin type because Distribution already has its own metaclass.
3740
class _DistributionAndTensorCoercibleMeta(type(tfd.Distribution),
@@ -43,43 +46,123 @@ class _DistributionAndTensorCoercibleMeta(type(tfd.Distribution),
4346
class _TensorCoercible(tfd.Distribution):
4447
"""Docstring."""
4548

46-
registered_class_list = {}
47-
48-
def __new__(cls, distribution, convert_to_tensor_fn=tfd.Distribution.sample):
49-
if isinstance(distribution, cls):
50-
return distribution
51-
if not isinstance(distribution, tfd.Distribution):
52-
raise TypeError('`distribution` argument must be a '
53-
'`tfd.Distribution` instance; '
54-
'saw "{}" of type "{}".'.format(
55-
distribution, type(distribution)))
56-
self = copy.copy(distribution)
57-
distcls = distribution.__class__
58-
self_class = _TensorCoercible.registered_class_list.get(distcls)
59-
if not self_class:
60-
self_class = type(distcls.__name__, (cls, distcls), {})
61-
_TensorCoercible.registered_class_list[distcls] = self_class
62-
self.__class__ = self_class
63-
return self
64-
6549
def __init__(self,
             distribution,
             convert_to_tensor_fn=tfd.Distribution.sample):
  """Wraps `distribution` and records how to convert it to a tensor.

  Args:
    distribution: The distribution to wrap. It is stored as
      `tensor_distribution`, and its `dtype`, `reparameterization_type`,
      `validate_args`, `allow_nan_stats` and `parameters` are passed to the
      base-class constructor.
    convert_to_tensor_fn: Stored as `_convert_to_tensor_fn`. Defaults to
      `tfd.Distribution.sample`.
  """
  self._concrete_value = None  # pylint: disable=protected-access
  self._convert_to_tensor_fn = convert_to_tensor_fn  # pylint: disable=protected-access
  # Assigned before the base constructor runs, preserving the original
  # statement order.
  self.tensor_distribution = distribution
  super(_TensorCoercible, self).__init__(
      dtype=distribution.dtype,
      reparameterization_type=distribution.reparameterization_type,
      validate_args=distribution.validate_args,
      allow_nan_stats=distribution.allow_nan_stats,
      parameters=distribution.parameters)
61+
62+
def __setattr__(self, name, value):
  """Support self.foo = trackable syntax.

  Redefined from `tensorflow/python/training/tracking/tracking.py` to avoid
  calling `getattr`, which causes an infinite loop.

  Args:
    name: str, name of the attribute to be set.
    value: value to be set.
  """
  # Read the instance dict directly instead of using `getattr`, so that this
  # class's `__getattr__` is never triggered from here.
  if vars(self).get(name, _NOT_FOUND) is value:
    return

  # Tracking is treated as enabled unless `_self_setattr_tracking` is present
  # in the instance dict and falsy.
  if vars(self).get('_self_setattr_tracking', True):
    value = data_structures.sticky_attribute_assignment(
        trackable=self, value=value, name=name)
  object.__setattr__(self, name, value)
79+
80+
def __getattr__(self, name):
  """Resolves `name` on this object, falling back to `tensor_distribution`.

  Args:
    name: str, name of the attribute being looked up.

  Returns:
    The value from this instance's dict if present; otherwise the attribute of
    the wrapped `tensor_distribution` (unless `name` contains `'_tracking'`);
    otherwise whatever `__getattribute__` yields.
  """
  # If the attribute is set in the _TensorCoercible object, return it. This
  # ensures that direct calls to `getattr` behave as expected.
  if name in vars(self):
    return vars(self)[name]
  # Look for the attribute in `tensor_distribution`, unless it's a `_tracking`
  # attribute accessed directly by `getattr` in the `Trackable` base class, in
  # which case the default passed to `getattr` should be returned.
  if 'tensor_distribution' in vars(self) and '_tracking' not in name:
    return getattr(vars(self)['tensor_distribution'], name)
  # Otherwise invoke `__getattribute__`, which will return the default passed
  # to `getattr` if the attribute was not found.
  return self.__getattribute__(name)
7093

7194
@classmethod
def _parameter_properties(cls, dtype, num_classes=None):
  """Returns the parameter-properties dict for this class.

  Args:
    dtype: Unused by this implementation.
    num_classes: Unused by this implementation.

  Returns:
    A dict with the single key `distribution`, mapped to a
    `parameter_properties.BatchedComponentProperties` instance.
  """
  return dict(distribution=parameter_properties.BatchedComponentProperties())
7497

98+
# pylint: disable=protected-access
7599
def _batch_shape_tensor(self, **parameter_kwargs):
76-
# Any parameter kwargs are for the inner distribution, so pass them
77-
# to its `_batch_shape_tensor` method instead of handling them directly.
78-
return self.parameters['distribution']._batch_shape_tensor( # pylint: disable=protected-access
79-
**parameter_kwargs)
100+
return self.tensor_distribution._batch_shape_tensor(**parameter_kwargs)
101+
102+
def _batch_shape(self):
103+
return self.tensor_distribution._batch_shape()
104+
105+
def _event_shape_tensor(self):
106+
return self.tensor_distribution._event_shape_tensor()
107+
108+
def _event_shape(self):
109+
return self.tensor_distribution._event_shape()
110+
111+
def sample(self, sample_shape=(), seed=None, name='sample', **kwargs):
  """Forwards to the wrapped distribution's public `sample`.

  Args:
    sample_shape: Passed through as `sample_shape`. Defaults to `()`.
    seed: Passed through as `seed`. Defaults to `None`.
    name: Passed through as `name`. Defaults to `'sample'`.
    **kwargs: Additional keyword arguments passed through unchanged.

  Returns:
    The result of `self.tensor_distribution.sample`.
  """
  return self.tensor_distribution.sample(
      sample_shape=sample_shape, seed=seed, name=name, **kwargs)
114+
115+
def _log_prob(self, value, **kwargs):
116+
return self.tensor_distribution._log_prob(value, **kwargs)
117+
118+
def _prob(self, value, **kwargs):
119+
return self.tensor_distribution._prob(value, **kwargs)
120+
121+
def _log_cdf(self, value, **kwargs):
122+
return self.tensor_distribution._log_cdf(value, **kwargs)
123+
124+
def _cdf(self, value, **kwargs):
125+
return self.tensor_distribution._cdf(value, **kwargs)
126+
127+
def _log_survival_function(self, value, **kwargs):
128+
return self.tensor_distribution._log_survival_function(value, **kwargs)
129+
130+
def _survival_function(self, value, **kwargs):
131+
return self.tensor_distribution._survival_function(value, **kwargs)
132+
133+
def _entropy(self, **kwargs):
134+
return self.tensor_distribution._entropy(**kwargs)
135+
136+
def _mean(self, **kwargs):
137+
return self.tensor_distribution._mean(**kwargs)
138+
139+
def _quantile(self, value, **kwargs):
140+
return self.tensor_distribution._quantile(value, **kwargs)
141+
142+
def _variance(self, **kwargs):
143+
return self.tensor_distribution._variance(**kwargs)
144+
145+
def _stddev(self, **kwargs):
146+
return self.tensor_distribution._stddev(**kwargs)
147+
148+
def _covariance(self, **kwargs):
149+
return self.tensor_distribution._covariance(**kwargs)
150+
151+
def _mode(self, **kwargs):
152+
return self.tensor_distribution._mode(**kwargs)
153+
154+
def _default_event_space_bijector(self, *args, **kwargs):
155+
return self.tensor_distribution._default_event_space_bijector(
156+
*args, **kwargs)
157+
158+
def _parameter_control_dependencies(self, is_init):
159+
return self.tensor_distribution._parameter_control_dependencies(is_init)
80160

81161
@property
def shape(self):
  """Static shape associated with this object.

  `_shape` is the method defined below unless it has been shadowed by an
  instance attribute holding a shape value (for example via
  `distribution._shape = value.shape`). Both cases are handled here.

  Returns:
    The shadowing value when `_shape` is not callable; otherwise the result of
    calling `_shape()`.
  """
  shape_or_fn = self._shape
  # Without this check the property would hand back the bound method itself
  # whenever no instance attribute shadows `_shape`.
  if callable(shape_or_fn):
    return shape_or_fn()
  return shape_or_fn

def _shape(self):
  """Returns the concrete value's shape, or an unknown shape if unset."""
  return (tf.TensorShape(None) if self._concrete_value is None
          else self._concrete_value.shape)
85168

@@ -130,15 +213,26 @@ def _value(self, dtype=None, name=None, as_ref=False):
130213
' results in `tf.convert_to_tensor(x)` being identical to '
131214
'`x.mean()`.'.format(type(self), self))
132215
with self._name_and_control_scope('value'):
133-
self._concrete_value = (self._convert_to_tensor_fn(self)
134-
if callable(self._convert_to_tensor_fn)
135-
else self._convert_to_tensor_fn)
216+
self._concrete_value = (
217+
self._convert_to_tensor_fn(self.tensor_distribution)
218+
if callable(self._convert_to_tensor_fn)
219+
else self._convert_to_tensor_fn)
136220
if (not tf.is_tensor(self._concrete_value) and
137221
not isinstance(self._concrete_value,
138222
composite_tensor.CompositeTensor)):
139223
self._concrete_value = nest_util.convert_to_nested_tensor( # pylint: disable=protected-access
140224
self._concrete_value,
141225
name=name or 'concrete_value',
142226
dtype=dtype,
143-
dtype_hint=self.dtype)
227+
dtype_hint=self.tensor_distribution.dtype)
144228
return self._concrete_value
229+
230+
231+
@kullback_leibler.RegisterKL(_TensorCoercible, tfd.Distribution)
def _kl_tensor_coercible_distribution(a, b, name=None):
  """KL divergence where the first argument is a `_TensorCoercible`.

  Args:
    a: Object whose `tensor_distribution` attribute is used in its place.
    b: The second distribution, passed through unchanged.
    name: Passed through as `name`.

  Returns:
    `kullback_leibler.kl_divergence(a.tensor_distribution, b, name=name)`.
  """
  return kullback_leibler.kl_divergence(a.tensor_distribution, b, name=name)
234+
235+
236+
@kullback_leibler.RegisterKL(tfd.Distribution, _TensorCoercible)
def _kl_distribution_tensor_coercible(a, b, name=None):
  """KL divergence where the second argument is a `_TensorCoercible`.

  Args:
    a: The first distribution, passed through unchanged.
    b: Object whose `tensor_distribution` attribute is used in its place.
    name: Passed through as `name`.

  Returns:
    `kullback_leibler.kl_divergence(a, b.tensor_distribution, name=name)`.
  """
  return kullback_leibler.kl_divergence(a, b.tensor_distribution, name=name)

0 commit comments

Comments
 (0)