35
35
from tensorflow_probability .python .internal import distribution_util
36
36
from tensorflow_probability .python .internal import docstring_util
37
37
from tensorflow_probability .python .internal import nest_util
38
- from tensorflow_probability .python .internal import prefer_static
38
+ from tensorflow_probability .python .internal import prefer_static as ps
39
39
from tensorflow_probability .python .internal import samplers
40
40
from tensorflow_probability .python .util .seed_stream import SeedStream
41
41
from tensorflow_probability .python .util .seed_stream import TENSOR_SEED_MSG_PREFIX
@@ -315,56 +315,15 @@ def experimental_shard_axis_names(self):
315
315
def use_vectorized_map (self ):
316
316
return False
317
317
318
- @property
319
- def batch_shape (self ):
320
- """Shape of a single sample from a single event index as a `TensorShape`.
321
-
322
- May be partially defined or unknown.
323
-
324
- The batch dimensions are indexes into independent, non-identical
325
- parameterizations of this distribution.
326
-
327
- Returns:
328
- batch_shape: `tuple` of `TensorShape`s representing the `batch_shape` for
329
- each distribution in `model`.
330
- """
318
+ def _batch_shape (self ):
331
319
return self ._model_unflatten ([
332
320
d .batch_shape for d in self ._get_single_sample_distributions ()])
333
321
334
- def batch_shape_tensor (self , sample_shape = (), name = 'batch_shape_tensor' ):
335
- """Shape of a single sample from a single event index as a 1-D `Tensor`.
336
-
337
- The batch dimensions are indexes into independent, non-identical
338
- parameterizations of this distribution.
339
-
340
- Args:
341
- sample_shape: The sample shape under which to evaluate the joint
342
- distribution. Sample shape at root (toplevel) nodes may affect the batch
343
- or event shapes of child nodes.
344
- name: name to give to the op
345
-
346
- Returns:
347
- batch_shape: `Tensor` representing batch shape of each distribution in
348
- `model`.
349
- """
350
- with self ._name_and_control_scope (name ):
351
- return self ._model_unflatten (
352
- self ._map_attr_over_dists (
353
- 'batch_shape_tensor' ,
354
- dists = (self .sample_distributions (sample_shape )
355
- if sample_shape else None )))
356
-
357
- @property
358
- def event_shape (self ):
359
- """Shape of a single sample from a single batch as a `TensorShape`.
322
+ def _batch_shape_tensor (self ):
323
+ return self ._model_unflatten (
324
+ self ._map_attr_over_dists ('batch_shape_tensor' ))
360
325
361
- May be partially defined or unknown.
362
-
363
- Returns:
364
- event_shape: `tuple` of `TensorShape`s representing the `event_shape` for
365
- each distribution in `model`.
366
- """
367
- # Caching will not leak graph Tensors since this is a static attribute.
326
+ def _event_shape (self ):
368
327
if not hasattr (self , '_cached_event_shape' ):
369
328
self ._cached_event_shape = [
370
329
d .event_shape
@@ -373,24 +332,9 @@ def event_shape(self):
373
332
# wrapping the returned value.
374
333
return self ._model_unflatten (self ._cached_event_shape )
375
334
376
- def event_shape_tensor (self , sample_shape = (), name = 'event_shape_tensor' ):
377
- """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.
378
-
379
- Args:
380
- sample_shape: The sample shape under which to evaluate the joint
381
- distribution. Sample shape at root (toplevel) nodes may affect the batch
382
- or event shapes of child nodes.
383
- name: name to give to the op
384
- Returns:
385
- event_shape: `tuple` of `Tensor`s representing the `event_shape` for each
386
- distribution in `model`.
387
- """
388
- with self ._name_and_control_scope (name ):
389
- return self ._model_unflatten (
390
- self ._map_attr_over_dists (
391
- 'event_shape_tensor' ,
392
- dists = (self .sample_distributions (sample_shape )
393
- if sample_shape else None )))
335
+ def _event_shape_tensor (self ):
336
+ return self ._model_unflatten (
337
+ self ._map_attr_over_dists ('event_shape_tensor' ))
394
338
395
339
def sample_distributions (self , sample_shape = (), seed = None , value = None ,
396
340
name = 'sample_distributions' , ** kwargs ):
@@ -847,9 +791,9 @@ def _assert_compatible_shape(self, index, sample_shape, samples):
847
791
requested_shape , _ = self ._expand_sample_shape_to_vector (
848
792
tf .convert_to_tensor (sample_shape , dtype = tf .int32 ),
849
793
name = 'requested_shape' )
850
- actual_shape = prefer_static .shape (samples )
851
- actual_rank = prefer_static .rank_from_shape (actual_shape )
852
- requested_rank = prefer_static .rank_from_shape (requested_shape )
794
+ actual_shape = ps .shape (samples )
795
+ actual_rank = ps .rank_from_shape (actual_shape )
796
+ requested_rank = ps .rank_from_shape (requested_shape )
853
797
854
798
# We test for two properties we expect of yielded distributions:
855
799
# (1) The rank of the tensor of generated samples must be at least
@@ -1068,8 +1012,8 @@ def maybe_check_wont_broadcast(flat_xs, validate_args):
1068
1012
# Only when `validate_args` is `True` do we enforce the validation.
1069
1013
return flat_xs
1070
1014
msg = 'Broadcasting probably indicates an error in model specification.'
1071
- s = tuple (prefer_static .shape (x ) for x in flat_xs )
1072
- if all (prefer_static .is_numpy (s_ ) for s_ in s ):
1015
+ s = tuple (ps .shape (x ) for x in flat_xs )
1016
+ if all (ps .is_numpy (s_ ) for s_ in s ):
1073
1017
if not all (np .all (a == b ) for a , b in zip (s [1 :], s [:- 1 ])):
1074
1018
raise ValueError (msg )
1075
1019
return flat_xs
@@ -1092,7 +1036,7 @@ def __init__(self, jd, parameters=None, bijector_fn=None):
1092
1036
bijectors = tuple (bijector_fn (d )
1093
1037
for d in jd ._get_single_sample_distributions ())
1094
1038
i_min_event_ndims = tf .nest .map_structure (
1095
- prefer_static .size , jd .event_shape )
1039
+ ps .size , jd .event_shape )
1096
1040
f_min_event_ndims = jd ._model_unflatten ([
1097
1041
b .inverse_event_ndims (nd ) for b , nd in
1098
1042
zip (bijectors , jd ._model_flatten (i_min_event_ndims ))])
@@ -1207,9 +1151,9 @@ def _jd_log_prob_ratio(p, x, q, y, name=None):
1207
1151
"""Implements `log_prob_ratio` for tfd.JointDistribution*."""
1208
1152
with tf .name_scope (name or 'jd_log_prob_ratio' ):
1209
1153
tf .nest .assert_same_structure (x , y )
1210
- ps , _ = p .sample_distributions (value = x , seed = samplers .zeros_seed ())
1211
- qs , _ = q .sample_distributions (value = y , seed = samplers .zeros_seed ())
1212
- tf .nest .assert_same_structure (ps , qs )
1154
+ p_dists , _ = p .sample_distributions (value = x , seed = samplers .zeros_seed ())
1155
+ q_dists , _ = q .sample_distributions (value = y , seed = samplers .zeros_seed ())
1156
+ tf .nest .assert_same_structure (p_dists , q_dists )
1213
1157
log_prob_ratio_parts = nest .map_structure_up_to (
1214
- ps , log_prob_ratio .log_prob_ratio , ps , x , qs , y )
1158
+ p_dists , log_prob_ratio .log_prob_ratio , p_dists , x , q_dists , y )
1215
1159
return tf .add_n (tf .nest .flatten (log_prob_ratio_parts ))
0 commit comments