
Commit 878f096

davmre authored and tensorflower-gardener committed
STS: update one_step_predictive and impute_missing_values to expose per-timestep log probabilities.
This is strictly more flexible than the existing behavior, and supports the use of numerical root search to find quantiles of the predictive distribution. It also matches the behavior of the predictive distributions constructed by the Gibbs sampling code at `tfp.experimental.sts_gibbs.gibbs_sampler.one_step_predictive`.

PiperOrigin-RevId: 384605127
1 parent 27155a5 commit 878f096
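For orientation, a minimal sketch of the behavioral change this commit introduces. The `LocalLinearTrend`/`Sum` model, series, and draw counts below are illustrative stand-ins, and prior samples stand in for real posterior samples:

```python
import tensorflow as tf
import tensorflow_probability as tfp

observed_time_series = tf.random.normal([100])  # illustrative series
trend = tfp.sts.LocalLinearTrend(observed_time_series=observed_time_series)
model = tfp.sts.Sum([trend], observed_time_series=observed_time_series)
# Stand-in for posterior samples; any draws with the right shapes work here.
parameter_samples = [param.prior.sample(10) for param in model.parameters]

# Deprecated default: timesteps are event shape, so log_prob returns a
# single number for the whole series.
whole_series_dist = tfp.sts.one_step_predictive(
    model, observed_time_series, parameter_samples=parameter_samples)
print(whole_series_dist.log_prob(observed_time_series).shape)  # ()

# New behavior: timesteps are batch shape, so log_prob is per-timestep.
per_step_dist = tfp.sts.one_step_predictive(
    model, observed_time_series, parameter_samples=parameter_samples,
    timesteps_are_event_shape=False)
per_step_lp = per_step_dist.log_prob(observed_time_series)  # shape [100]
# Summing matches the deprecated event-shape behavior, which (after this
# commit) is just an `Independent` wrapper over the same per-timestep mixture.
print(tf.reduce_sum(per_step_lp, axis=-1))
```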

File tree

3 files changed (+89 −56 lines)

tensorflow_probability/python/sts/forecast.py

Lines changed: 52 additions & 9 deletions
@@ -25,6 +25,8 @@
 from tensorflow_probability.python.internal import distribution_util as dist_util
 from tensorflow_probability.python.sts.internal import util as sts_util
 
+from tensorflow.python.util import deprecation  # pylint: disable=g-direct-tensorflow-import
+
 
 def _prefer_static_event_ndims(distribution):
   if distribution.event_shape.ndims is not None:
@@ -33,7 +35,18 @@ def _prefer_static_event_ndims(distribution):
   return tf.size(distribution.event_shape_tensor())
 
 
-def one_step_predictive(model, observed_time_series, parameter_samples):
+@deprecation.deprecated_arg_values(
+    '2021-12-31',
+    'Predictive distributions returned by `tfp.sts.one_step_predictive` will '
+    'soon compute per-timestep probabilities (treating timesteps as part of '
+    'the batch shape) instead of a single probability for an entire series '
+    '(the current approach, in which timesteps are treated as event shape). '
+    'Please update your code to pass `timesteps_are_event_shape=False` (this '
+    'will soon be the default) and to explicitly sum over the per-timestep log '
+    'probabilities if this is required.',
+    timesteps_are_event_shape=True)
+def one_step_predictive(model, observed_time_series, parameter_samples,
+                        timesteps_are_event_shape=True):
   """Compute one-step-ahead predictive distributions for all timesteps.
 
   Given samples from the posterior over parameters, return the predictive
@@ -55,11 +68,16 @@ def one_step_predictive(model, observed_time_series, parameter_samples):
      param.prior.batch_shape, param.prior.event_shape]) for param in
      model.parameters]`. This may optionally also be a map (Python `dict`) of
      parameter names to `Tensor` values.
+    timesteps_are_event_shape: Deprecated, for backwards compatibility only.
+      If `False`, the predictive distribution will return per-timestep
+      probabilities.
+      Default value: `True`.
 
   Returns:
-    forecast_dist: a `tfd.MixtureSameFamily` instance with event shape
-      [num_timesteps] and
-      batch shape `concat([sample_shape, model.batch_shape])`, with
+    predictive_dist: a `tfd.MixtureSameFamily` instance with event shape
+      `[num_timesteps] if timesteps_are_event_shape else []` and
+      batch shape `concat([sample_shape, model.batch_shape,
+      [] if timesteps_are_event_shape else [num_timesteps])`, with
       `num_posterior_draws` mixture components. The `t`th step represents the
       forecast distribution `p(observed_time_series[t] |
       observed_time_series[0:t-1], parameter_samples)`.
@@ -168,9 +186,13 @@ def plot_one_step_predictive(observed_time_series,
 
   # Squeeze dims to convert from LGSSM's event shape `[num_timesteps, 1]`
   # to a scalar time series.
-  return sts_util.mix_over_posterior_draws(
+  predictive_dist = sts_util.mix_over_posterior_draws(
       means=observation_means[..., 0],
       variances=observation_covs[..., 0, 0])
+  if timesteps_are_event_shape:
+    predictive_dist = tfd.Independent(
+        predictive_dist, reinterpreted_batch_ndims=1)
+  return predictive_dist
 
 
 def forecast(model,
@@ -383,10 +405,21 @@ def plot_forecast(observed_time_series,
       components_distribution=forecast_ssm)
 
 
+@deprecation.deprecated_arg_values(
+    '2021-12-31',
+    'Imputed distributions returned by `tfp.sts.impute_missing_values` will '
+    'soon compute per-timestep probabilities (treating timesteps as part of '
+    'the batch shape) instead of a single probability for an entire series '
+    '(the current approach, in which timesteps are treated as event shape). '
+    'Please update your code to pass `timesteps_are_event_shape=False` (this '
+    'will soon be the default) and to explicitly sum over the per-timestep log '
+    'probabilities if this is required.',
+    timesteps_are_event_shape=True)
 def impute_missing_values(model,
                           observed_time_series,
                           parameter_samples,
-                          include_observation_noise=False):
+                          include_observation_noise=False,
+                          timesteps_are_event_shape=True):
   """Runs posterior inference to impute the missing values in a time series.
 
   This method computes the posterior marginals `p(latent state | observations)`,
@@ -417,11 +450,17 @@ def impute_missing_values(model,
       values that could be *observed* at each timestep, including any i.i.d.
       observation noise.
       Default value: `False`.
+    timesteps_are_event_shape: Deprecated, for backwards compatibility only.
+      If `False`, the predictive distribution will return per-timestep
+      probabilities.
+      Default value: `True`.
 
   Returns:
     imputed_series_dist: a `tfd.MixtureSameFamily` instance with event shape
-      [num_timesteps] and batch shape `concat([sample_shape,
-      model.batch_shape])`, with `num_posterior_draws` mixture components.
+      `[num_timesteps] if timesteps_are_event_shape else []` and
+      batch shape `concat([sample_shape, model.batch_shape,
+      [] if timesteps_are_event_shape else [num_timesteps])`, with
+      `num_posterior_draws` mixture components.
 
   #### Example
 
@@ -497,6 +536,10 @@ def impute_missing_values(model,
 
   # Squeeze dims to convert from LGSSM's event shape `[num_timesteps, 1]`
   # to a scalar time series.
-  return sts_util.mix_over_posterior_draws(
+  imputed_values_dist = sts_util.mix_over_posterior_draws(
      means=observation_means[..., 0],
      variances=observation_covs[..., 0, 0])
+  if timesteps_are_event_shape:
+    imputed_values_dist = tfd.Independent(
+        imputed_values_dist, reinterpreted_batch_ndims=1)
+  return imputed_values_dist
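The per-timestep batch shape is what enables the quantile search mentioned in the commit message: `cdf` now evaluates elementwise across the series. A bisection sketch, assuming a distribution built with `timesteps_are_event_shape=False` whose components implement `cdf` (the Normal mixtures returned here do); the bracketing interval, iteration count, and function name are illustrative:

```python
import tensorflow as tf

def predictive_quantile(dist, q, low=-1e3, high=1e3, num_iters=64):
  """Finds per-timestep quantiles of `dist` by bisection on its CDF.

  Assumes `dist` was built with `timesteps_are_event_shape=False`, so
  `dist.cdf(x)` returns one value per timestep.
  """
  shape = dist.batch_shape_tensor()
  low = tf.fill(shape, tf.cast(low, dist.dtype))
  high = tf.fill(shape, tf.cast(high, dist.dtype))
  for _ in range(num_iters):
    mid = (low + high) / 2.
    too_low = dist.cdf(mid) < q  # elementwise over timesteps
    low = tf.where(too_low, mid, low)
    high = tf.where(too_low, high, mid)
  return (low + high) / 2.

# e.g., per-timestep 95th percentiles of the one-step-ahead predictions:
# upper_band = predictive_quantile(per_step_dist, 0.95)
```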

tensorflow_probability/python/sts/forecast_test.py

Lines changed: 32 additions & 39 deletions
@@ -56,9 +56,9 @@ def test_one_step_predictive_correctness(self):
         [observation_noise_scale])}
 
     onestep_dist = tfp.sts.one_step_predictive(model, observed_time_series,
+                                               timesteps_are_event_shape=False,
                                                parameter_samples=params)
-    onestep_mean_, onestep_scale_ = self.evaluate(
-        (onestep_dist.mean(), onestep_dist.stddev()))
+    onestep_mean, onestep_scale = onestep_dist.mean(), onestep_dist.stddev()
 
     # Since Seasonal is just a set of interleaved random walks, it's
     # straightforward to compute the forecast analytically.
@@ -80,8 +80,8 @@ def test_one_step_predictive_correctness(self):
     expected_onestep_scale = np.concatenate([
         [np.sqrt(1.**2 + observation_noise_scale**2)] * 4,
         [np.sqrt(observation_predictive_variance)] * 4])
-    self.assertAllClose(onestep_mean_, expected_onestep_mean)
-    self.assertAllClose(onestep_scale_, expected_onestep_scale)
+    self.assertAllClose(onestep_mean, expected_onestep_mean)
+    self.assertAllClose(onestep_scale, expected_onestep_scale)
 
   def test_one_step_predictive_with_batch_shape(self):
     num_param_samples = 5
@@ -95,16 +95,16 @@ def test_one_step_predictive_with_batch_shape(self):
                      for param in model.parameters]
 
     onestep_dist = tfp.sts.one_step_predictive(model, observed_time_series,
+                                               timesteps_are_event_shape=False,
                                                parameter_samples=prior_samples)
 
     self.evaluate(tf1.global_variables_initializer())
-    if self.use_static_shape:
-      self.assertAllEqual(onestep_dist.batch_shape.as_list(), batch_shape)
-    else:
-      self.assertAllEqual(self.evaluate(onestep_dist.batch_shape_tensor()),
-                          batch_shape)
-    onestep_mean_ = self.evaluate(onestep_dist.mean())
-    self.assertAllEqual(onestep_mean_.shape, batch_shape + [num_timesteps])
+    self.assertAllEqual(onestep_dist.batch_shape_tensor(),
+                        batch_shape + [num_timesteps])
+    onestep_mean = onestep_dist.mean()
+    self.assertAllEqual(tf.shape(onestep_mean), batch_shape + [num_timesteps])
+    self.assertAllEqual(tf.shape(onestep_dist.log_prob(onestep_mean)),
+                        batch_shape + [num_timesteps])
 
   def test_forecast_correctness(self):
     observed_time_series_ = np.array([1., -1., -3., 4.])
@@ -125,8 +125,6 @@ def test_forecast_correctness(self):
                                      include_observation_noise=True)
     forecast_mean = forecast_dist.mean()[..., 0]
     forecast_scale = forecast_dist.stddev()[..., 0]
-    forecast_mean_, forecast_scale_ = self.evaluate(
-        (forecast_mean, forecast_scale))
 
     # Since Seasonal is just a set of interleaved random walks, it's
     # straightforward to compute the forecast analytically.
@@ -143,8 +141,8 @@ def test_forecast_correctness(self):
     expected_forecast_scale = np.concatenate([
         [np.sqrt(observation_predictive_variance)] * 4,
         [np.sqrt(observation_predictive_variance + drift_scale**2)] * 4])
-    self.assertAllClose(forecast_mean_, expected_forecast_mean)
-    self.assertAllClose(forecast_scale_, expected_forecast_scale)
+    self.assertAllClose(forecast_mean, expected_forecast_mean)
+    self.assertAllClose(forecast_scale, expected_forecast_scale)
 
     # Also test forecasting the noise-free function.
     forecast_dist = tfp.sts.forecast(model, observed_time_series,
@@ -153,15 +151,13 @@ def test_forecast_correctness(self):
                                      include_observation_noise=False)
     forecast_mean = forecast_dist.mean()[..., 0]
     forecast_scale = forecast_dist.stddev()[..., 0]
-    forecast_mean_, forecast_scale_ = self.evaluate(
-        (forecast_mean, forecast_scale))
 
     noiseless_predictive_variance = (effect_posterior_variance + drift_scale**2)
     expected_forecast_scale = np.concatenate([
         [np.sqrt(noiseless_predictive_variance)] * 4,
         [np.sqrt(noiseless_predictive_variance + drift_scale**2)] * 4])
-    self.assertAllClose(forecast_mean_, expected_forecast_mean)
-    self.assertAllClose(forecast_scale_, expected_forecast_scale)
+    self.assertAllClose(forecast_mean, expected_forecast_mean)
+    self.assertAllClose(forecast_scale, expected_forecast_scale)
 
   def test_forecast_from_hmc(self):
     # test that we can directly plug in the output of an HMC chain as
@@ -220,15 +216,9 @@ def test_forecast_with_batch_shape(self):
                                      num_steps_forecast=num_steps_forecast)
 
     self.evaluate(tf1.global_variables_initializer())
-    if self.use_static_shape:
-      self.assertAllEqual(forecast_dist.batch_shape.as_list(), batch_shape)
-    else:
-      self.assertAllEqual(self.evaluate(forecast_dist.batch_shape_tensor()),
-                          batch_shape)
-    forecast_mean = forecast_dist.mean()[..., 0]
-    forecast_mean_ = self.evaluate(forecast_mean)
-    self.assertAllEqual(forecast_mean_.shape,
-                        batch_shape + [num_steps_forecast])
+    self.assertAllEqual(forecast_dist.batch_shape_tensor(), batch_shape)
+    self.assertAllEqual(tf.shape(forecast_dist.mean()),
+                        batch_shape + [num_steps_forecast, 1])
 
   def test_methods_handle_masked_inputs(self):
     num_param_samples = 5
@@ -268,6 +258,7 @@ def test_methods_handle_masked_inputs(self):
     self.assertTrue(np.all(np.isfinite(onestep_stddev_)))
 
   def test_impute_missing(self):
+    num_timesteps = 7
     time_series_with_nans = self._build_tensor(
         [-1., 1., np.nan, 2.4, np.nan, np.nan, 2.])
     observed_time_series = tfp.sts.MaskedTimeSeries(
@@ -288,19 +279,21 @@ def test_impute_missing(self):
     parameter_samples = {'observation_noise_scale': [noise_scale],
                          'seasonal/_drift_scale': [drift_scale]}
     imputed_series_dist = tfp.sts.impute_missing_values(
-        model, observed_time_series, parameter_samples)
+        model, observed_time_series, parameter_samples,
+        timesteps_are_event_shape=False)
     imputed_noisy_series_dist = tfp.sts.impute_missing_values(
         model, observed_time_series, parameter_samples,
+        timesteps_are_event_shape=False,
         include_observation_noise=True)
+    self.assertAllEqual(imputed_noisy_series_dist.batch_shape_tensor(),
+                        [num_timesteps])
 
     # Compare imputed mean to expected mean.
-    mean_, stddev_ = self.evaluate([imputed_series_dist.mean(),
-                                    imputed_series_dist.stddev()])
-    noisy_mean_, noisy_stddev_ = self.evaluate([
-        imputed_noisy_series_dist.mean(),
-        imputed_noisy_series_dist.stddev()])
-    self.assertAllClose(mean_, [-1., 1., 2., 2.4, -1., 1., 2.], atol=1e-2)
-    self.assertAllClose(mean_, noisy_mean_, atol=1e-2)
+    mean, stddev = imputed_series_dist.mean(), imputed_series_dist.stddev()
+    noisy_mean, noisy_stddev = [imputed_noisy_series_dist.mean(),
+                                imputed_noisy_series_dist.stddev()]
+    self.assertAllClose(mean, [-1., 1., 2., 2.4, -1., 1., 2.], atol=1e-2)
+    self.assertAllClose(mean, noisy_mean, atol=1e-2)
 
     # Compare imputed stddevs to expected stddevs.
     drift_plus_noise_scale = np.sqrt(noise_scale**2 + drift_scale**2)
@@ -311,9 +304,9 @@ def test_impute_missing(self):
                                     drift_plus_noise_scale,
                                     drift_plus_noise_scale,
                                     noise_scale])
-    self.assertAllClose(stddev_, expected_stddev, atol=1e-2)
-    self.assertAllClose(noisy_stddev_,
-                        np.sqrt(stddev_**2 + noise_scale**2), atol=1e-2)
+    self.assertAllClose(stddev, expected_stddev, atol=1e-2)
+    self.assertAllClose(noisy_stddev,
+                        tf.sqrt(stddev**2 + noise_scale**2), atol=1e-2)
 
   def _build_tensor(self, ndarray, dtype=None):
     """Convert a numpy array to a TF placeholder.

tensorflow_probability/python/sts/internal/util.py

Lines changed: 5 additions & 8 deletions
@@ -376,7 +376,7 @@ def mix_over_posterior_draws(means, variances):
   Returns:
     mixture_dist: `tfd.MixtureSameFamily(tfd.Independent(tfd.Normal))` instance
       representing a uniform mixture over the posterior samples, with
-      `batch_shape = ...` and `event_shape = [num_timesteps]`.
+      `batch_shape = [..., num_timesteps]` and `event_shape = []`.
 
   """
   # The inputs `means`, `variances` have shape
@@ -387,19 +387,16 @@ def mix_over_posterior_draws(means, variances):
   #      [num_timesteps]])`
   # Because MixtureSameFamily mixes over the rightmost batch dimension,
   # we need to move the `num_posterior_draws` dimension to be rightmost
-  # in the batch shape. This requires use of `Independent` (to preserve
-  # `num_timesteps` as part of the event shape) and `move_dimension`.
+  # in the batch shape.
   # TODO(b/120245392): enhance `MixtureSameFamily` to reduce along an
   # arbitrary axis, and eliminate `move_dimension` calls here.
 
   with tf.name_scope('mix_over_posterior_draws'):
     num_posterior_draws = ps.shape(means)[0]
 
-    component_observations = tfd.Independent(
-        distribution=tfd.Normal(
-            loc=dist_util.move_dimension(means, 0, -2),
-            scale=tf.sqrt(dist_util.move_dimension(variances, 0, -2))),
-        reinterpreted_batch_ndims=1)
+    component_observations = tfd.Normal(
+        loc=dist_util.move_dimension(means, 0, -1),
+        scale=tf.sqrt(dist_util.move_dimension(variances, 0, -1)))
 
     return tfd.MixtureSameFamily(
         mixture_distribution=tfd.Categorical(