33
33
from tensorflow_probability .python .distributions import distribution as distribution_lib
34
34
from tensorflow_probability .python .distributions import log_prob_ratio
35
35
from tensorflow_probability .python .internal import assert_util
36
+ from tensorflow_probability .python .internal import auto_composite_tensor
37
+ from tensorflow_probability .python .internal import callable_util
36
38
from tensorflow_probability .python .internal import distribution_util
37
39
from tensorflow_probability .python .internal import docstring_util
38
40
from tensorflow_probability .python .internal import nest_util
53
55
JAX_MODE = False
54
56
55
57
58
@auto_composite_tensor.auto_composite_tensor
class StaticDistributionAttributes(auto_composite_tensor.AutoCompositeTensor):
  """Container to smuggle static attributes out of a tf.function trace."""

  def __init__(self,
               batch_shape,
               dtype,
               event_shape,
               experimental_shard_axis_names,
               name,
               reparameterization_type):
    # Each attribute mirrors the like-named `__init__` parameter. The
    # parameter order is significant: `__iter__` below (and the JAX pytree
    # registration built on it) reproduces the attributes in this order.
    self.batch_shape = batch_shape
    self.dtype = dtype
    self.event_shape = event_shape
    self.experimental_shard_axis_names = experimental_shard_axis_names
    self.name = name
    self.reparameterization_type = reparameterization_type

  def __iter__(self):
    """Yields parameters in order matching __init__ signature."""
    return iter((self.batch_shape, self.dtype, self.event_shape,
                 self.experimental_shard_axis_names, self.name,
                 self.reparameterization_type))
81
+
82
if JAX_MODE:
  from jax import tree_util  # pylint: disable=g-import-not-at-top

  def _flatten_static_attributes(sda):
    # Every field is static, so all content rides in the aux data and there
    # are no dynamic (traced) children.
    return [], list(sda)

  def _unflatten_static_attributes(static_attributes, children):
    del children  # Always empty; see `_flatten_static_attributes`.
    return StaticDistributionAttributes(*static_attributes)

  # Register as a pytree node so JAX transformations can carry these
  # containers through flatten/unflatten round trips.
  tree_util.register_pytree_node(
      StaticDistributionAttributes,
      flatten_func=_flatten_static_attributes,
      unflatten_func=_unflatten_static_attributes)
88
+
89
+
56
90
class ValueWithTrace (collections .namedtuple (
57
91
'ValueWithTrace' ,
58
92
['value' , 'traced' ])):
@@ -119,6 +153,22 @@ def trace_values_and_log_probs(dist, sample_shape, seed, value=None):
119
153
return ValueWithTrace (value = value , traced = (value , lp ))
120
154
121
155
156
def trace_static_attributes(dist, sample_shape, seed, value):
  """Extracts the current distribution's static attributes as Tensor specs."""
  del sample_shape  # Unused: static attributes don't depend on sample shape.
  if value is None:
    value = dist.sample(seed=seed)
  static_attributes = StaticDistributionAttributes(
      batch_shape=dist.batch_shape,
      dtype=dist.dtype,
      event_shape=dist.event_shape,
      experimental_shard_axis_names=dist.experimental_shard_axis_names,
      name=get_explicit_name_for_component(dist),
      reparameterization_type=dist.reparameterization_type)
  return ValueWithTrace(value=value, traced=static_attributes)
170
+
171
+
122
172
CALLING_CONVENTION_DESCRIPTION = """
123
173
The measure methods of `JointDistribution` (`log_prob`, `prob`, etc.)
124
174
can be called either by passing a single structure of tensors or by using
@@ -269,6 +319,17 @@ def _get_single_sample_distributions(self, candidate_dists=None):
269
319
self ._single_sample_distributions [graph_id ] = ds
270
320
return ds
271
321
322
def _get_static_distribution_attributes(self, seed=None):
  """Returns (and caches) static attributes of all component distributions."""
  if not hasattr(self, '_cached_static_attributes'):
    # Trace (rather than execute) the model to pull out each component
    # distribution's static attributes. The seed expression stays inside the
    # lambda so it is only evaluated under the trace.
    per_distribution_attributes = callable_util.get_output_spec(
        lambda: self._execute_model(  # pylint: disable=g-long-lambda
            sample_and_trace_fn=trace_static_attributes,
            seed=samplers.zeros_seed() if seed is None else seed))
    # Transpose the list of per-distribution bundles into a single bundle
    # whose fields are per-distribution tuples.
    self._cached_static_attributes = StaticDistributionAttributes(
        *zip(*per_distribution_attributes))

  return self._cached_static_attributes
332
+
272
333
# Override `tf.Module`'s `_flatten` method to ensure that distributions are
273
334
# instantiated, so that accessing `.variables` or `.trainable_variables` gives
274
335
# consistent results.
@@ -287,8 +348,8 @@ def _model_flatten(self, xs):
287
348
@property
def dtype(self):
  """The `DType` of `Tensor`s handled by this `Distribution`."""
  static_attributes = self._get_static_distribution_attributes()
  return self._model_unflatten(static_attributes.dtype)
292
353
293
354
@property
294
355
def reparameterization_type (self ):
@@ -301,37 +362,31 @@ def reparameterization_type(self):
301
362
reparameterization_type: `ReparameterizationType` of each distribution in
302
363
`model`.
303
364
"""
304
- return self ._model_unflatten ([
305
- d .reparameterization_type
306
- for d in self ._get_single_sample_distributions ()])
365
+ return self ._model_unflatten (
366
+ self ._get_static_distribution_attributes ().reparameterization_type )
307
367
308
368
@property
def experimental_shard_axis_names(self):
  """Indicates whether part distributions have active shard axis names."""
  static_attributes = self._get_static_distribution_attributes()
  return self._model_unflatten(
      static_attributes.experimental_shard_axis_names)
314
374
315
375
@property
def use_vectorized_map(self):
  """Returns `False`: this distribution does not auto-vectorize its model."""
  # NOTE(review): presumably auto-batched variants override this to return
  # True — confirm against the autobatched JointDistribution implementation.
  return False
318
378
319
379
def _batch_shape (self ):
320
- return self ._model_unflatten ([
321
- d . batch_shape for d in self ._get_single_sample_distributions ()] )
380
+ return self ._model_unflatten (
381
+ self ._get_static_distribution_attributes (). batch_shape )
322
382
323
383
def _batch_shape_tensor (self ):
324
384
return self ._model_unflatten (
325
385
self ._map_attr_over_dists ('batch_shape_tensor' ))
326
386
327
387
def _event_shape (self ):
328
- if not hasattr (self , '_cached_event_shape' ):
329
- self ._cached_event_shape = [
330
- d .event_shape
331
- for d in self ._get_single_sample_distributions ()]
332
- # Unflattening *after* retrieving from cache prevents tf.Module from
333
- # wrapping the returned value.
334
- return self ._model_unflatten (self ._cached_event_shape )
388
+ return self ._model_unflatten (
389
+ self ._get_static_distribution_attributes ().event_shape )
335
390
336
391
def _event_shape_tensor (self ):
337
392
return self ._model_unflatten (
@@ -363,6 +418,11 @@ def sample_distributions(self, sample_shape=(), seed=None, value=None,
363
418
samples: a `tuple` of `Tensor`s with prepended dimensions `sample_shape`
364
419
for each of `distribution_fn`.
365
420
"""
421
+ # Use the user-provided seed to trace static distribution attributes, if
422
+ # they're not already cached. This ensures we don't try to pass a stateless
423
+ # seed to a stateful sampler, or vice versa.
424
+ self ._get_static_distribution_attributes (seed = seed )
425
+
366
426
with self ._name_and_control_scope (name ):
367
427
value = self ._resolve_value (value = value ,
368
428
allow_partially_specified = True ,
@@ -516,6 +576,11 @@ def _unnormalized_log_prob(self, value):
516
576
'corresponding distribution. Default value: `None` '
517
577
'(i.e., draw a sample from each distribution).' )})
518
578
def _sample_n (self , sample_shape , seed , value = None ):
579
+ # Use the user-provided seed to trace static distribution attributes, if
580
+ # they're not already cached. This ensures we don't try to pass a stateless
581
+ # seed to a stateful sampler, or vice versa.
582
+ self ._get_static_distribution_attributes (seed = seed )
583
+
519
584
might_have_batch_dims = (
520
585
distribution_util .shape_may_be_nontrivial (sample_shape )
521
586
or value is not None )
@@ -539,6 +604,11 @@ def _sample_n(self, sample_shape, seed, value=None):
539
604
540
605
# TODO(b/189122177): Implement _sample_and_log_prob for distributed JDs.
541
606
def _sample_and_log_prob (self , sample_shape , seed , value = None , ** kwargs ):
607
+ # Use the user-provided seed to trace static distribution attributes, if
608
+ # they're not already cached. This ensures we don't try to pass a stateless
609
+ # seed to a stateful sampler, or vice versa.
610
+ self ._get_static_distribution_attributes (seed = seed )
611
+
542
612
xs , lps = zip (
543
613
* self ._call_execute_model (
544
614
sample_shape ,
@@ -673,8 +743,8 @@ def _flat_resolve_names(self, dummy_name='var'):
673
743
"""Resolves a name for each random variable in the model."""
674
744
names = []
675
745
names_used = set ()
676
- for dummy_idx , d in enumerate (self . _get_single_sample_distributions ()):
677
- name = get_explicit_name_for_component ( d )
746
+ for dummy_idx , name in enumerate (
747
+ self . _get_static_distribution_attributes (). name ):
678
748
if name is None :
679
749
name = '{}{}' .format (dummy_name , dummy_idx )
680
750
if name in names_used :
0 commit comments