
Commit d3398a7

davmre authored and tensorflower-gardener committed
Extract static JD attributes via tracing.

The attributes are cached, so that a model is only traced once. This is intended to limit the need to actually run the model in `self._get_single_sample_distributions()`. This change also includes tweaks to enable numpy-mode tests for joint distributions.

PiperOrigin-RevId: 380928836
1 parent 47407b8 commit d3398a7

11 files changed (+180 −60 lines)
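
The core idea: a `tf.function` trace records symbolic shapes and dtypes without executing any ops, so even a model too large to run can report its metadata. The commit routes this through the internal helper `callable_util.get_output_spec`; the sketch below shows the same idea using only public TensorFlow APIs, with an illustrative model that is not from the commit.

import tensorflow as tf

def sample_model():
  # Tensors this large could never be materialized, but tracing only
  # records symbolic shapes and dtypes.
  x = tf.random.stateless_normal([10**9], seed=[0, 0])
  y = x + tf.random.stateless_normal([10**9], seed=[0, 1])
  return {'x': x, 'y': y}

# Tracing builds a graph without running it.
concrete_fn = tf.function(sample_model).get_concrete_function()
outputs = concrete_fn.structured_outputs  # symbolic tensors
print(tf.nest.map_structure(lambda t: t.dtype, outputs))
# {'x': tf.float32, 'y': tf.float32}
print(tf.nest.map_structure(lambda t: t.shape, outputs))
# {'x': TensorShape([1000000000]), 'y': TensorShape([1000000000])}
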

tensorflow_probability/python/distributions/BUILD

Lines changed: 0 additions & 1 deletion

@@ -3100,7 +3100,6 @@ multi_substrate_py_test(
     name = "joint_distribution_sequential_test",
     size = "medium",
     srcs = ["joint_distribution_sequential_test.py"],
-    numpy_tags = ["notap"],
     shard_count = 2,
     deps = [
         # absl/testing:parameterized dep,

tensorflow_probability/python/distributions/joint_distribution.py

Lines changed: 89 additions & 19 deletions

@@ -33,6 +33,8 @@
 from tensorflow_probability.python.distributions import distribution as distribution_lib
 from tensorflow_probability.python.distributions import log_prob_ratio
 from tensorflow_probability.python.internal import assert_util
+from tensorflow_probability.python.internal import auto_composite_tensor
+from tensorflow_probability.python.internal import callable_util
 from tensorflow_probability.python.internal import distribution_util
 from tensorflow_probability.python.internal import docstring_util
 from tensorflow_probability.python.internal import nest_util

@@ -53,6 +55,38 @@
 JAX_MODE = False


+@auto_composite_tensor.auto_composite_tensor
+class StaticDistributionAttributes(auto_composite_tensor.AutoCompositeTensor):
+  """Container to smuggle static attributes out of a tf.function trace."""
+
+  def __init__(self,
+               batch_shape,
+               dtype,
+               event_shape,
+               experimental_shard_axis_names,
+               name,
+               reparameterization_type):
+    self.batch_shape = batch_shape
+    self.dtype = dtype
+    self.event_shape = event_shape
+    self.experimental_shard_axis_names = experimental_shard_axis_names
+    self.name = name
+    self.reparameterization_type = reparameterization_type
+
+  def __iter__(self):
+    """Yields parameters in order matching __init__ signature."""
+    return iter((self.batch_shape, self.dtype, self.event_shape,
+                 self.experimental_shard_axis_names, self.name,
+                 self.reparameterization_type))
+
+if JAX_MODE:
+  from jax import tree_util  # pylint: disable=g-import-not-at-top
+  tree_util.register_pytree_node(
+      StaticDistributionAttributes,
+      flatten_func=lambda sda: ([], list(sda)),
+      unflatten_func=lambda attrs, _: StaticDistributionAttributes(*attrs))
+
+
 class ValueWithTrace(collections.namedtuple(
     'ValueWithTrace',
     ['value', 'traced'])):
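
For the JAX substrate, the registration above marks every field as auxiliary (static) data with no array children, so tracing passes the container through untouched. A self-contained sketch of that pattern, with hypothetical names:

from jax import tree_util

class Box:
  """A container whose fields are all static metadata, not arrays."""

  def __init__(self, label, size):
    self.label = label
    self.size = size

tree_util.register_pytree_node(
    Box,
    # No children to transform; everything rides along as aux data.
    flatten_func=lambda box: ([], (box.label, box.size)),
    unflatten_func=lambda aux, children: Box(*aux))

# tree_map finds no leaves inside, so the Box passes through unchanged.
box = tree_util.tree_map(lambda leaf: leaf * 2, Box('weights', 3))
print(box.label, box.size)  # weights 3
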
@@ -119,6 +153,22 @@ def trace_values_and_log_probs(dist, sample_shape, seed, value=None):
   return ValueWithTrace(value=value, traced=(value, lp))


+def trace_static_attributes(dist, sample_shape, seed, value):
+  """Extracts the current distribution's static attributes as Tensor specs."""
+  del sample_shape
+  if value is None:
+    value = dist.sample(seed=seed)
+  return ValueWithTrace(
+      value=value,
+      traced=StaticDistributionAttributes(
+          batch_shape=dist.batch_shape,
+          dtype=dist.dtype,
+          experimental_shard_axis_names=dist.experimental_shard_axis_names,
+          event_shape=dist.event_shape,
+          name=get_explicit_name_for_component(dist),
+          reparameterization_type=dist.reparameterization_type))
+
+
 CALLING_CONVENTION_DESCRIPTION = """
 The measure methods of `JointDistribution` (`log_prob`, `prob`, etc.)
 can be called either by passing a single structure of tensors or by using

@@ -269,6 +319,17 @@ def _get_single_sample_distributions(self, candidate_dists=None):
       self._single_sample_distributions[graph_id] = ds
     return ds

+  def _get_static_distribution_attributes(self, seed=None):
+    if not hasattr(self, '_cached_static_attributes'):
+      flat_list_of_static_attributes = callable_util.get_output_spec(
+          lambda: self._execute_model(  # pylint: disable=g-long-lambda
+              sample_and_trace_fn=trace_static_attributes,
+              seed=seed if seed is not None else samplers.zeros_seed()))
+      self._cached_static_attributes = StaticDistributionAttributes(
+          *zip(*flat_list_of_static_attributes))
+
+    return self._cached_static_attributes
+
   # Override `tf.Module`'s `_flatten` method to ensure that distributions are
   # instantiated, so that accessing `.variables` or `.trainable_variables` gives
   # consistent results.
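
The `StaticDistributionAttributes(*zip(*...))` construction above transposes a per-distribution list of attribute bundles (each iterable via `__iter__`) into per-attribute tuples, one entry per component distribution. In isolation, illustrated with two fields instead of six:

# Each traced component yields a tuple of attributes; zip(*...) regroups
# them attribute-wise across components.
per_component = [('shape_x', 'float32'), ('shape_y', 'float32')]
per_attribute = list(zip(*per_component))
print(per_attribute)  # [('shape_x', 'shape_y'), ('float32', 'float32')]
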
@@ -287,8 +348,8 @@ def _model_flatten(self, xs):
   @property
   def dtype(self):
     """The `DType` of `Tensor`s handled by this `Distribution`."""
-    return self._model_unflatten([
-        d.dtype for d in self._get_single_sample_distributions()])
+    return self._model_unflatten(
+        self._get_static_distribution_attributes().dtype)

   @property
   def reparameterization_type(self):

@@ -301,37 +362,31 @@ def reparameterization_type(self):
       reparameterization_type: `ReparameterizationType` of each distribution in
         `model`.
     """
-    return self._model_unflatten([
-        d.reparameterization_type
-        for d in self._get_single_sample_distributions()])
+    return self._model_unflatten(
+        self._get_static_distribution_attributes().reparameterization_type)

   @property
   def experimental_shard_axis_names(self):
     """Indicates whether part distributions have active shard axis names."""
-    return self._model_unflatten([
-        d.experimental_shard_axis_names
-        for d in self._get_single_sample_distributions()])
+    return self._model_unflatten(
+        self._get_static_distribution_attributes().
+        experimental_shard_axis_names)

   @property
   def use_vectorized_map(self):
     return False

   def _batch_shape(self):
-    return self._model_unflatten([
-        d.batch_shape for d in self._get_single_sample_distributions()])
+    return self._model_unflatten(
+        self._get_static_distribution_attributes().batch_shape)

   def _batch_shape_tensor(self):
     return self._model_unflatten(
         self._map_attr_over_dists('batch_shape_tensor'))

   def _event_shape(self):
-    if not hasattr(self, '_cached_event_shape'):
-      self._cached_event_shape = [
-          d.event_shape
-          for d in self._get_single_sample_distributions()]
-    # Unflattening *after* retrieving from cache prevents tf.Module from
-    # wrapping the returned value.
-    return self._model_unflatten(self._cached_event_shape)
+    return self._model_unflatten(
+        self._get_static_distribution_attributes().event_shape)

   def _event_shape_tensor(self):
     return self._model_unflatten(

@@ -363,6 +418,11 @@ def sample_distributions(self, sample_shape=(), seed=None, value=None,
       samples: a `tuple` of `Tensor`s with prepended dimensions `sample_shape`
         for each of `distribution_fn`.
     """
+    # Use the user-provided seed to trace static distribution attributes, if
+    # they're not already cached. This ensures we don't try to pass a stateless
+    # seed to a stateful sampler, or vice versa.
+    self._get_static_distribution_attributes(seed=seed)
+
     with self._name_and_control_scope(name):
       value = self._resolve_value(value=value,
                                   allow_partially_specified=True,
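
The comment about seed kinds refers to TF's two sampling regimes: stateless ops take a seed tensor of two integers, while stateful ops take a Python int (or no seed at all). A quick illustration of why the two are not interchangeable, using public APIs only:

import tensorflow as tf

# Stateless sampling: the seed must be a shape-[2] integer tensor, and the
# same seed always reproduces the same draw.
a = tf.random.stateless_normal([2], seed=[7, 42])
b = tf.random.stateless_normal([2], seed=[7, 42])
assert (a.numpy() == b.numpy()).all()

# Stateful sampling: the seed is a Python int feeding a global RNG state.
c = tf.random.normal([2], seed=42)

# Passing one kind of seed where the other is expected raises, which is why
# the trace reuses whichever kind of seed the caller supplied.
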
@@ -516,6 +576,11 @@ def _unnormalized_log_prob(self, value):
           'corresponding distribution. Default value: `None` '
           '(i.e., draw a sample from each distribution).')})
   def _sample_n(self, sample_shape, seed, value=None):
+    # Use the user-provided seed to trace static distribution attributes, if
+    # they're not already cached. This ensures we don't try to pass a stateless
+    # seed to a stateful sampler, or vice versa.
+    self._get_static_distribution_attributes(seed=seed)
+
     might_have_batch_dims = (
         distribution_util.shape_may_be_nontrivial(sample_shape)
         or value is not None)

@@ -539,6 +604,11 @@ def _sample_n(self, sample_shape, seed, value=None):

   # TODO(b/189122177): Implement _sample_and_log_prob for distributed JDs.
   def _sample_and_log_prob(self, sample_shape, seed, value=None, **kwargs):
+    # Use the user-provided seed to trace static distribution attributes, if
+    # they're not already cached. This ensures we don't try to pass a stateless
+    # seed to a stateful sampler, or vice versa.
+    self._get_static_distribution_attributes(seed=seed)
+
     xs, lps = zip(
         *self._call_execute_model(
             sample_shape,

@@ -673,8 +743,8 @@ def _flat_resolve_names(self, dummy_name='var'):
     """Resolves a name for each random variable in the model."""
     names = []
     names_used = set()
-    for dummy_idx, d in enumerate(self._get_single_sample_distributions()):
-      name = get_explicit_name_for_component(d)
+    for dummy_idx, name in enumerate(
+        self._get_static_distribution_attributes().name):
       if name is None:
         name = '{}{}'.format(dummy_name, dummy_idx)
       if name in names_used:

tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py

Lines changed: 9 additions & 11 deletions

@@ -523,19 +523,13 @@ def dist():
         [value_partial_batch_dim, num_rows, num_columns])

   def test_unit_sample_shape_avoids_vectorization(self):
-    if not tf.executing_eagerly():
-      self.skipTest('Test relies on eager execution.')
-
+    xs = []  # Collect (possibly symbolic) Tensors sampled inside the model.
     @tfd.JointDistributionCoroutineAutoBatched
     def dist():
-      # Because `pfor` operates by tracing its loop body, to ensure we're
-      # not inside of a `pfor` loop body it's sufficient to check that we're
-      # not inside of a tf.function.
-      if not tf.executing_eagerly():
-        raise ValueError('Model is running inside tf.function. This may '
-                         'indicate that auto-vectorization is being '
-                         'triggered unnecessarily.')
-      yield tfd.Normal(0., 1., name='x')
+      x = yield tfd.Normal(0., 1., name='x')
+      xs.append(x)
+
+    # Try sampling with a variety of unit sample shapes.
     self.assertEqual(
         [1],
         dist.sample(

@@ -549,6 +543,10 @@ def dist():
         dist.sample([1, 1],
                     seed=test_util.test_seed(sampler_type='seedless')).x.shape)

+    # Check that the model only ever saw the trivial sample shape.
+    for x in xs:
+      self.assertEqual(x.shape, [])
+
   def test_unit_sample_shape(self):
     @tfd.JointDistributionCoroutineAutoBatched
     def dist():

tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py

Lines changed: 31 additions & 0 deletions

@@ -1395,6 +1395,37 @@ def desired_unnorm_lp(cprior, c1, c0):
         tfp.math.value_and_gradient(lp_fn, (cprior, c1, c0))[1],
         tfp.math.value_and_gradient(ulp_fn, (cprior, c1, c0))[1])

+  @test_util.numpy_disable_test_missing_functionality('symbolic tracing')
+  @test_util.jax_disable_test_missing_functionality(
+      'https://github.com/google/jax/issues/7011')
+  def test_symbolic_trace_dtype(self):
+    # A model that will definitely OOM. (1 billion squared floats).
+    @tfd.JointDistributionCoroutine
+    def model():
+      x = yield Root(tfd.MultivariateNormalDiag(
+          tf.zeros(int(1e9)), tf.ones(int(1e9)), name='x'))
+      loc = tf.einsum('i,j->ij', x, x)
+      yield tfd.Independent(
+          tfd.MultivariateNormalDiag(loc, tf.ones(int(1e9))),
+          reinterpreted_batch_ndims=1,
+          name='y')
+    self.assertEqual((tf.float32, tf.float32), model.dtype)
+
+  @test_util.numpy_disable_test_missing_functionality('symbolic tracing')
+  def test_symbolic_trace_is_cached(self):
+    model_executions = []
+
+    @tfd.JointDistributionCoroutine
+    def model():
+      x = yield Root(tfd.Normal(0., 1., name='x'))
+      y = yield tfd.Normal(x, 1., name='y')
+      model_executions.append(y)
+
+    self.assertAllEqual(((), ()), model.event_shape)
+    self.assertAllEqual(((), ()), model.batch_shape)
+    self.assertAllEqual((tf.float32, tf.float32), model.dtype)
+    self.assertAllEqual(('x', 'y'), model._flat_resolve_names())
+    self.assertLen(model_executions, 1)

 if __name__ == '__main__':
   tf.test.main()

tensorflow_probability/python/distributions/joint_distribution_sample_path_mixin.py

Lines changed: 3 additions & 2 deletions

@@ -78,8 +78,9 @@ def batch_ndims(self):
     return self._batch_ndims

   def _batch_shape_parts(self):
-    return [d.batch_shape[:self.batch_ndims]
-            for d in self._get_single_sample_distributions()]
+    return [batch_shape[:self.batch_ndims]
+            for batch_shape in self._get_static_distribution_attributes().
+            batch_shape]

   def _batch_shape(self):
     # Caching will not leak graph Tensors since this is a static attribute.

tensorflow_probability/python/distributions/joint_distribution_sequential_test.py

Lines changed: 3 additions & 2 deletions

@@ -29,6 +29,7 @@
 import tensorflow_probability as tfp

 from tensorflow_probability.python.distributions import joint_distribution_sequential
+from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import test_util

 from tensorflow.python.util import tf_inspect  # pylint: disable=g-direct-tensorflow-import

@@ -501,6 +502,7 @@ def test_matrix_factorization(self):
     self.assertEqual(lp.shape, [7, 9])

   @test_util.jax_disable_variable_test
+  @test_util.numpy_disable_variable_test
   def test_latent_dirichlet_allocation(self):
     """Tests Latent Dirichlet Allocation joint model.

@@ -587,8 +589,7 @@ def test_poisson_switchover_graphical_model(self):
             indices=tf.cast(
                 tau[..., tf.newaxis] < tf.linspace(0., 1., n),
                 dtype=tf.int32),
-            # TODO(b/139204153): Remove static value hack after bug closed.
-            batch_dims=int(tf.get_static_value(tf.rank(tau))))
+            batch_dims=ps.rank(tau))

     alpha = tf.math.reciprocal(tf.reduce_mean(count_data))

tensorflow_probability/python/internal/backend/numpy/numpy_array.py

Lines changed: 2 additions & 4 deletions

@@ -197,11 +197,9 @@ def _linspace(start, stop, num, name=None, axis=0):  # pylint: disable=unused-argument
   if np.issubdtype(start.dtype, np.integer):
     start = start.astype(np.float64)
   stop = ops.convert_to_tensor(stop, dtype=start.dtype)
-  num = ops.convert_to_tensor(num, dtype_hint=np.int32)
-  if not np.issubdtype(num.dtype, np.integer):
+  if not np.issubdtype(np.array(num).dtype, np.integer):
     raise TypeError('`num` must be an integer but got {}'.format(num.dtype))
-  num = num.astype(np.int32)
-  return np.linspace(start, stop, num, axis=axis).astype(start.dtype)
+  return np.linspace(start, stop, int(num), axis=axis).astype(start.dtype)


 def _one_hot(  # pylint: disable=unused-argument

tensorflow_probability/python/internal/backend/numpy/random_generators.py

Lines changed: 3 additions & 8 deletions

@@ -19,7 +19,6 @@
 from __future__ import print_function

 import numpy as np
-import numpy as onp  # Avoids JAX rewrite.  # pylint: disable=reimported

 from tensorflow_probability.python.internal.backend.numpy import _utils as utils
 from tensorflow_probability.python.internal.backend.numpy import ops

@@ -64,14 +63,10 @@ def _ensure_shape_tuple(t):


 def _bcast_shape(base_shape, args):
-  base_shape = _ensure_shape_tuple(base_shape)
-  if not args:
-    return base_shape
-  bc_arr = onp.zeros(base_shape + (0,))
+  bcast_shape = _ensure_shape_tuple(base_shape)
   for arg in args:
-    if arg is not None:
-      bc_arr = bc_arr + onp.zeros(np.asarray(arg).shape + (0,))
-  return bc_arr.shape[:-1]
+    bcast_shape = ops.broadcast_shape(bcast_shape, np.asarray(arg).shape)
+  return bcast_shape


 def _binomial(shape, seed, counts, probs, output_dtype=np.int32, name=None):  # pylint: disable=unused-argument
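
The rewrite replaces the zeros-array trick (materializing empty arrays just to let NumPy compute a broadcast shape) with a direct shape-level reduction via the backend's `ops.broadcast_shape`. Plain NumPy has an equivalent, shown here as a point of comparison:

import numpy as np

# Broadcast several shapes without creating any arrays.
print(np.broadcast_shapes((5, 1), (4,), (1, 1)))  # (5, 4)

# np.asarray(None).shape is (), which broadcasts as a no-op -- this is why
# the new loop no longer needs the explicit `arg is not None` guard.
print(np.asarray(None).shape)  # ()
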

tensorflow_probability/python/internal/backend/numpy/tensor_spec.py

Lines changed: 6 additions & 3 deletions

@@ -21,6 +21,9 @@

 class TensorSpec(object):

-  def __init__(self, *args, **kwargs):
-    del args, kwargs
-    self.dtype = None
+  def __init__(self, shape, dtype):
+    self.shape = shape
+    self.dtype = dtype
+
+  def __repr__(self):
+    return f'TensorSpec(shape={self.shape}, dtype={self.dtype})'
