@@ -81,27 +81,36 @@ def trace_distributions_and_values(dist, sample_shape, seed, value=None):
81
81
"""Draws a sample, and traces both the distribution and sampled value."""
82
82
if value is None :
83
83
value = dist .sample (sample_shape , seed = seed )
84
+ elif tf .nest .is_nested (dist .dtype ) and any (
85
+ v is None for v in tf .nest .flatten (value )):
86
+ # TODO(siege): This is making an assumption that nested dtype => partial
87
+ # value support, which is not necessarily reasonable.
88
+ value = dist .sample (sample_shape , seed = seed , value = value )
84
89
return ValueWithTrace (value = value , traced = (dist , value ))
85
90
86
91
87
92
def trace_distributions_only(dist, sample_shape, seed, value=None):
  """Draws a sample, and traces the distribution it was drawn from.

  Sampling (including the handling of a supplied `value`) is delegated to
  `trace_distributions_and_values`; only the distribution half of the
  `(dist, value)` pair traced by that function is kept.

  Args:
    dist: Distribution-like object to sample from.
    sample_shape: Sample shape, forwarded unchanged to the delegate.
    seed: Seed, forwarded unchanged to the delegate.
    value: Optional value forwarded to the delegate. `None` means a fresh
      sample is drawn.

  Returns:
    A `ValueWithTrace` whose `value` is the sampled (or supplied) value and
    whose `traced` field is the distribution.
  """
  ret = trace_distributions_and_values(dist, sample_shape, seed, value)
  # `ret.traced` is the `(dist, value)` pair; index 0 selects the distribution.
  return ret._replace(traced=ret.traced[0])
92
96
93
97
94
98
def trace_values_only(dist, sample_shape, seed, value=None):
  """Samples from `dist` and records only the resulting value in the trace.

  The sampling itself is performed by `trace_distributions_and_values`; this
  wrapper discards the distribution from the `(dist, value)` pair it traces.

  Args:
    dist: Distribution-like object to sample from.
    sample_shape: Sample shape, forwarded unchanged to the delegate.
    seed: Seed, forwarded unchanged to the delegate.
    value: Optional value forwarded to the delegate. `None` means a fresh
      sample is drawn.

  Returns:
    A `ValueWithTrace` whose `value` and `traced` fields both hold the
    sampled (or supplied) value.
  """
  full_trace = trace_distributions_and_values(
      dist, sample_shape, seed, value)
  # Index 1 of the `(dist, value)` pair is the value.
  value_only = full_trace.traced[1]
  return full_trace._replace(traced=value_only)
99
102
100
103
101
104
def trace_values_and_log_probs(dist, sample_shape, seed, value=None):
  """Samples from `dist` and traces the value together with its log density.

  Three cases are distinguished:
    * `value` is `None`: a fresh sample and its log-prob are drawn jointly.
    * `dist.dtype` is nested and `value` contains `None` leaves: the missing
      parts are sampled conditioned on the supplied ones.
    * otherwise: `value` is used as-is and scored with `dist.log_prob`.

  Args:
    dist: Distribution-like object to sample from / score under.
    sample_shape: Sample shape passed to `experimental_sample_and_log_prob`.
    seed: Seed passed to `experimental_sample_and_log_prob`.
    value: Optional, possibly partially-specified, value.

  Returns:
    A `ValueWithTrace` whose `traced` field is the `(value, log_prob)` pair.
  """
  if value is None:
    sampled, log_prob = dist.experimental_sample_and_log_prob(
        sample_shape, seed=seed)
    return ValueWithTrace(value=sampled, traced=(sampled, log_prob))

  has_unset_parts = tf.nest.is_nested(dist.dtype) and any(
      part is None for part in tf.nest.flatten(value))
  if not has_unset_parts:
    # Fully specified value: nothing to sample, only score it.
    log_prob = dist.log_prob(value)
    return ValueWithTrace(value=value, traced=(value, log_prob))

  # TODO(siege): This is making an assumption that nested dtype => partial
  # value support, which is not necessarily reasonable.
  completed, log_prob = dist.experimental_sample_and_log_prob(
      sample_shape, seed=seed, value=value)
  return ValueWithTrace(value=completed, traced=(completed, log_prob))
@@ -210,7 +219,9 @@ class JointDistribution(distribution_lib.Distribution):
210
219
- `_model_coroutine`: A generator that yields a sequence of
211
220
`tfd.Distribution`-like instances.
212
221
213
- - `_model_flatten`: takes a structured input and returns a sequence.
222
+ - `_model_flatten`: takes a structured input and returns a sequence. The
223
+ sequence order must match the order distributions are yielded from
224
+ `_model_coroutine`.
214
225
215
226
- `_model_unflatten`: takes a sequence and returns a structure matching the
216
227
semantics of the `JointDistribution` subclass.
@@ -613,33 +624,14 @@ def _map_attr_over_dists(self, attr, dists=None):
613
624
if dists is None else dists )
614
625
return (getattr (d , attr )() for d in dists )
615
626
616
- def _sanitize_value (self , value ):
617
- """Ensures `value` matches `self.dtype` with `Tensor` or `None` elements."""
618
- if value is None :
619
- return value
620
-
621
- if len (value ) < len (self .dtype ):
622
- # Fill in missing entries with `None`.
623
- if hasattr (self .dtype , 'keys' ):
624
- value = {k : value .get (k , None ) for k in self .dtype .keys ()}
625
- else : # dtype is a sequence.
626
- value = [value [i ] if i < len (value ) else None
627
- for i in range (len (self .dtype ))]
628
-
629
- value = nest_util .cast_structure (value , self .dtype )
630
- return nest .map_structure_up_to (
631
- self .dtype ,
632
- lambda x , d : x if x is None else tf .convert_to_tensor (x , dtype_hint = d ),
633
- value , self .dtype )
634
-
635
627
def _resolve_value (self , * args , allow_partially_specified = False , ** kwargs ):
636
628
"""Resolves a `value` structure from user-passed arguments."""
637
629
value = kwargs .pop ('value' , None )
638
630
if not (args or kwargs ):
639
- # Fast path when `value` is the only kwarg. The case where `value` is
640
- # passed as a positional arg is handled by `_resolve_value_from_args`
641
- # below.
642
- return self . _sanitize_value (value )
631
+ # Fast path when `value` is the only kwarg. The case where `value` is
632
+ # passed as a positional arg is handled by `_resolve_value_from_args`
633
+ # below.
634
+ return _sanitize_value (self , value )
643
635
elif value is not None :
644
636
raise ValueError ('Supplied both `value` and keyword '
645
637
'arguments to parameterize sampling. Supplied keyword '
@@ -665,7 +657,7 @@ def _resolve_value(self, *args, allow_partially_specified=False, **kwargs):
665
657
'Found unexpected keyword arguments. Distribution names '
666
658
'are\n {}\n but received\n {}\n These names were '
667
659
'invalid:\n {}' .format (dist_name_str , kwarg_names , unmatched_str ))
668
- return self . _sanitize_value (value )
660
+ return _sanitize_value (self , value )
669
661
670
662
def _call_execute_model (self ,
671
663
sample_shape = (),
@@ -793,17 +785,7 @@ def _execute_model(self,
793
785
value_at_index = None
794
786
if (value is not None and len (value ) > index and
795
787
value [index ] is not None ):
796
-
797
- def convert_tree_to_tensor (x , dtype_hint ):
798
- return tf .convert_to_tensor (x , dtype_hint = dtype_hint )
799
-
800
- # This signature does not allow kwarg names. Applies
801
- # `convert_to_tensor` on the next value.
802
- value_at_index = nest .map_structure_up_to (
803
- actual_distribution .dtype , # shallow_tree
804
- convert_tree_to_tensor , # func
805
- value [index ], # x
806
- actual_distribution .dtype ) # dtype_hint
788
+ value_at_index = _sanitize_value (actual_distribution , value [index ])
807
789
try :
808
790
next_value , traced_values = sample_and_trace_fn (
809
791
actual_distribution ,
@@ -1175,6 +1157,46 @@ def _inverse_log_det_jacobian(self, y, event_ndims, **kwargs):
1175
1157
y , event_ndims , _jd_conditioning = y , ** kwargs )
1176
1158
1177
1159
1160
+ def _sanitize_value (distribution , value ):
1161
+ """Ensures `value` matches `distribution.dtype`, adding `None`s as needed."""
1162
+ if value is None :
1163
+ return value
1164
+
1165
+ if not tf .nest .is_nested (distribution .dtype ):
1166
+ return tf .convert_to_tensor (value , dtype_hint = distribution .dtype )
1167
+
1168
+ if len (value ) < len (distribution .dtype ):
1169
+ # Fill in missing entries with `None`.
1170
+ if hasattr (distribution .dtype , 'keys' ):
1171
+ value = {k : value .get (k , None ) for k in distribution .dtype .keys ()}
1172
+ else : # dtype is a sequence.
1173
+ value = [value [i ] if i < len (value ) else None
1174
+ for i in range (len (distribution .dtype ))]
1175
+
1176
+ value = nest_util .cast_structure (value , distribution .dtype )
1177
+ jdlike_attrs = [
1178
+ '_get_single_sample_distributions' ,
1179
+ '_model_flatten' ,
1180
+ '_model_unflatten' ,
1181
+ ]
1182
+ if all (hasattr (distribution , attr ) for attr in jdlike_attrs ):
1183
+ flat_dists = distribution ._get_single_sample_distributions ()
1184
+ flat_value = distribution ._model_flatten (value )
1185
+ flat_value = map (_sanitize_value , flat_dists , flat_value )
1186
+ return distribution ._model_unflatten (flat_value )
1187
+ else :
1188
+ # A joint distribution that isn't tfd.JointDistribution-like; assume it has
1189
+ # some reasonable dtype semantics. We can't use this for
1190
+ # tfd.JointDistribution because we might have a None standing in for a
1191
+ # sub-tree (e.g. consider omitting a nested JD).
1192
+ return nest .map_structure_up_to (
1193
+ distribution .dtype ,
1194
+ lambda x , d : x if x is None else tf .convert_to_tensor (x , dtype_hint = d ),
1195
+ value ,
1196
+ distribution .dtype ,
1197
+ )
1198
+
1199
+
1178
1200
@log_prob_ratio .RegisterLogProbRatio (JointDistribution )
1179
1201
def _jd_log_prob_ratio (p , x , q , y , name = None ):
1180
1202
"""Implements `log_prob_ratio` for tfd.JointDistribution*."""
0 commit comments