@@ -449,6 +449,7 @@ def _log_prob(self, value):
449
449
- `_default_event_space_bijector`.
450
450
- `_parameter_properties` (to support automatic batch shape derivation,
451
451
batch slicing and other features).
452
+ - `_sample_and_log_prob`.
452
453
453
454
Note that subclasses of existing Distributions that redefine `__init__` do
454
455
*not* automatically inherit
@@ -1166,22 +1167,21 @@ def _sample_n(self, n, seed=None, **kwargs):
1166
1167
raise NotImplementedError ('sample_n is not implemented: {}' .format (
1167
1168
type (self ).__name__ ))
1168
1169
1169
- def _call_sample_n (self , sample_shape , seed , name , ** kwargs ):
1170
+ def _call_sample_n (self , sample_shape , seed , ** kwargs ):
1170
1171
"""Wrapper around _sample_n."""
1171
- with self ._name_and_control_scope (name ):
1172
- if JAX_MODE and seed is None :
1173
- raise ValueError ('Must provide JAX PRNGKey as `dist.sample(seed=.)`' )
1174
- sample_shape = ps .convert_to_shape_tensor (
1175
- ps .cast (sample_shape , tf .int32 ), name = 'sample_shape' )
1176
- sample_shape , n = self ._expand_sample_shape_to_vector (
1177
- sample_shape , 'sample_shape' )
1178
- samples = self ._sample_n (
1179
- n , seed = seed () if callable (seed ) else seed , ** kwargs )
1180
- batch_event_shape = ps .shape (samples )[1 :]
1181
- final_shape = ps .concat ([sample_shape , batch_event_shape ], 0 )
1182
- samples = tf .reshape (samples , final_shape )
1183
- samples = self ._set_sample_static_shape (samples , sample_shape )
1184
- return samples
1172
+ if JAX_MODE and seed is None :
1173
+ raise ValueError ('Must provide JAX PRNGKey as `dist.sample(seed=.)`' )
1174
+ sample_shape = ps .convert_to_shape_tensor (
1175
+ ps .cast (sample_shape , tf .int32 ), name = 'sample_shape' )
1176
+ sample_shape , n = self ._expand_sample_shape_to_vector (
1177
+ sample_shape , 'sample_shape' )
1178
+ samples = self ._sample_n (
1179
+ n , seed = seed () if callable (seed ) else seed , ** kwargs )
1180
+ batch_event_shape = ps .shape (samples )[1 :]
1181
+ final_shape = ps .concat ([sample_shape , batch_event_shape ], 0 )
1182
+ samples = tf .reshape (samples , final_shape )
1183
+ samples = self ._set_sample_static_shape (samples , sample_shape )
1184
+ return samples
1185
1185
1186
1186
def sample (self , sample_shape = (), seed = None , name = 'sample' , ** kwargs ):
1187
1187
"""Generate samples of the specified shape.
@@ -1198,7 +1198,62 @@ def sample(self, sample_shape=(), seed=None, name='sample', **kwargs):
1198
1198
Returns:
1199
1199
samples: a `Tensor` with prepended dimensions `sample_shape`.
1200
1200
"""
1201
- return self ._call_sample_n (sample_shape , seed , name , ** kwargs )
1201
+ with self ._name_and_control_scope (name ):
1202
+ return self ._call_sample_n (sample_shape , seed , ** kwargs )
1203
+
1204
+ def _call_sample_and_log_prob (self , sample_shape , seed , ** kwargs ):
1205
+ """Wrapper around `_sample_and_log_prob`."""
1206
+ if hasattr (self , '_sample_and_log_prob' ):
1207
+ sample_shape = ps .convert_to_shape_tensor (
1208
+ ps .cast (sample_shape , tf .int32 ), name = 'sample_shape' )
1209
+ return self ._sample_and_log_prob (
1210
+ distribution_util .expand_to_vector (
1211
+ sample_shape , tensor_name = 'sample_shape' ),
1212
+ seed = seed , ** kwargs )
1213
+
1214
+ # Naive default implementation. This calls private, rather than public,
1215
+ # methods, to avoid duplicating the name_and_control_scope.
1216
+ value = self ._call_sample_n (sample_shape , seed = seed , ** kwargs )
1217
+ if hasattr (self , '_log_prob' ):
1218
+ log_prob = self ._log_prob (value , ** kwargs )
1219
+ elif hasattr (self , '_prob' ):
1220
+ log_prob = tf .math .log (self ._prob (value , ** kwargs ))
1221
+ else :
1222
+ raise NotImplementedError ('log_prob is not implemented: {}' .format (
1223
+ type (self ).__name__ ))
1224
+ return value , log_prob
1225
+
1226
+ def experimental_sample_and_log_prob (self , sample_shape = (), seed = None ,
1227
+ name = 'sample_and_log_prob' , ** kwargs ):
1228
+ """Samples from this distribution and returns the log density of the sample.
1229
+
1230
+ The default implementation simply calls `sample` and `log_prob`:
1231
+
1232
+ ```
1233
+ def _sample_and_log_prob(self, sample_shape, seed, **kwargs):
1234
+ x = self.sample(sample_shape=sample_shape, seed=seed, **kwargs)
1235
+ return x, self.log_prob(x, **kwargs)
1236
+ ```
1237
+
1238
+ However, some subclasses may provide more efficient and/or numerically
1239
+ stable implementations.
1240
+
1241
+ Args:
1242
+ sample_shape: integer `Tensor` desired shape of samples to draw.
1243
+ Default value: `()`.
1244
+ seed: Python integer or `tfp.util.SeedStream` instance, for seeding PRNG.
1245
+ Default value: `None`.
1246
+ name: name to give to the op.
1247
+ Default value: `'sample_and_log_prob'`.
1248
+ **kwargs: Named arguments forwarded to subclass implementation.
1249
+ Returns:
1250
+ samples: a `Tensor`, or structure of `Tensor`s, with prepended dimensions
1251
+ `sample_shape`.
1252
+ log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
1253
+ values of type `self.dtype`.
1254
+ """
1255
+ with self ._name_and_control_scope (name ):
1256
+ return self ._call_sample_and_log_prob (sample_shape , seed = seed , ** kwargs )
1202
1257
1203
1258
def _call_log_prob (self , value , name , ** kwargs ):
1204
1259
"""Wrapper around _log_prob."""
0 commit comments