@@ -56,9 +56,9 @@ def test_one_step_predictive_correctness(self):
56
56
[observation_noise_scale ])}
57
57
58
58
onestep_dist = tfp .sts .one_step_predictive (model , observed_time_series ,
59
+ timesteps_are_event_shape = False ,
59
60
parameter_samples = params )
60
- onestep_mean_ , onestep_scale_ = self .evaluate (
61
- (onestep_dist .mean (), onestep_dist .stddev ()))
61
+ onestep_mean , onestep_scale = onestep_dist .mean (), onestep_dist .stddev ()
62
62
63
63
# Since Seasonal is just a set of interleaved random walks, it's
64
64
# straightforward to compute the forecast analytically.
@@ -80,8 +80,8 @@ def test_one_step_predictive_correctness(self):
80
80
expected_onestep_scale = np .concatenate ([
81
81
[np .sqrt (1. ** 2 + observation_noise_scale ** 2 )] * 4 ,
82
82
[np .sqrt (observation_predictive_variance )] * 4 ])
83
- self .assertAllClose (onestep_mean_ , expected_onestep_mean )
84
- self .assertAllClose (onestep_scale_ , expected_onestep_scale )
83
+ self .assertAllClose (onestep_mean , expected_onestep_mean )
84
+ self .assertAllClose (onestep_scale , expected_onestep_scale )
85
85
86
86
def test_one_step_predictive_with_batch_shape (self ):
87
87
num_param_samples = 5
@@ -95,16 +95,16 @@ def test_one_step_predictive_with_batch_shape(self):
95
95
for param in model .parameters ]
96
96
97
97
onestep_dist = tfp .sts .one_step_predictive (model , observed_time_series ,
98
+ timesteps_are_event_shape = False ,
98
99
parameter_samples = prior_samples )
99
100
100
101
self .evaluate (tf1 .global_variables_initializer ())
101
- if self .use_static_shape :
102
- self .assertAllEqual (onestep_dist .batch_shape .as_list (), batch_shape )
103
- else :
104
- self .assertAllEqual (self .evaluate (onestep_dist .batch_shape_tensor ()),
105
- batch_shape )
106
- onestep_mean_ = self .evaluate (onestep_dist .mean ())
107
- self .assertAllEqual (onestep_mean_ .shape , batch_shape + [num_timesteps ])
102
+ self .assertAllEqual (onestep_dist .batch_shape_tensor (),
103
+ batch_shape + [num_timesteps ])
104
+ onestep_mean = onestep_dist .mean ()
105
+ self .assertAllEqual (tf .shape (onestep_mean ), batch_shape + [num_timesteps ])
106
+ self .assertAllEqual (tf .shape (onestep_dist .log_prob (onestep_mean )),
107
+ batch_shape + [num_timesteps ])
108
108
109
109
def test_forecast_correctness (self ):
110
110
observed_time_series_ = np .array ([1. , - 1. , - 3. , 4. ])
@@ -125,8 +125,6 @@ def test_forecast_correctness(self):
125
125
include_observation_noise = True )
126
126
forecast_mean = forecast_dist .mean ()[..., 0 ]
127
127
forecast_scale = forecast_dist .stddev ()[..., 0 ]
128
- forecast_mean_ , forecast_scale_ = self .evaluate (
129
- (forecast_mean , forecast_scale ))
130
128
131
129
# Since Seasonal is just a set of interleaved random walks, it's
132
130
# straightforward to compute the forecast analytically.
@@ -143,8 +141,8 @@ def test_forecast_correctness(self):
143
141
expected_forecast_scale = np .concatenate ([
144
142
[np .sqrt (observation_predictive_variance )] * 4 ,
145
143
[np .sqrt (observation_predictive_variance + drift_scale ** 2 )] * 4 ])
146
- self .assertAllClose (forecast_mean_ , expected_forecast_mean )
147
- self .assertAllClose (forecast_scale_ , expected_forecast_scale )
144
+ self .assertAllClose (forecast_mean , expected_forecast_mean )
145
+ self .assertAllClose (forecast_scale , expected_forecast_scale )
148
146
149
147
# Also test forecasting the noise-free function.
150
148
forecast_dist = tfp .sts .forecast (model , observed_time_series ,
@@ -153,15 +151,13 @@ def test_forecast_correctness(self):
153
151
include_observation_noise = False )
154
152
forecast_mean = forecast_dist .mean ()[..., 0 ]
155
153
forecast_scale = forecast_dist .stddev ()[..., 0 ]
156
- forecast_mean_ , forecast_scale_ = self .evaluate (
157
- (forecast_mean , forecast_scale ))
158
154
159
155
noiseless_predictive_variance = (effect_posterior_variance + drift_scale ** 2 )
160
156
expected_forecast_scale = np .concatenate ([
161
157
[np .sqrt (noiseless_predictive_variance )] * 4 ,
162
158
[np .sqrt (noiseless_predictive_variance + drift_scale ** 2 )] * 4 ])
163
- self .assertAllClose (forecast_mean_ , expected_forecast_mean )
164
- self .assertAllClose (forecast_scale_ , expected_forecast_scale )
159
+ self .assertAllClose (forecast_mean , expected_forecast_mean )
160
+ self .assertAllClose (forecast_scale , expected_forecast_scale )
165
161
166
162
def test_forecast_from_hmc (self ):
167
163
# test that we can directly plug in the output of an HMC chain as
@@ -220,15 +216,9 @@ def test_forecast_with_batch_shape(self):
220
216
num_steps_forecast = num_steps_forecast )
221
217
222
218
self .evaluate (tf1 .global_variables_initializer ())
223
- if self .use_static_shape :
224
- self .assertAllEqual (forecast_dist .batch_shape .as_list (), batch_shape )
225
- else :
226
- self .assertAllEqual (self .evaluate (forecast_dist .batch_shape_tensor ()),
227
- batch_shape )
228
- forecast_mean = forecast_dist .mean ()[..., 0 ]
229
- forecast_mean_ = self .evaluate (forecast_mean )
230
- self .assertAllEqual (forecast_mean_ .shape ,
231
- batch_shape + [num_steps_forecast ])
219
+ self .assertAllEqual (forecast_dist .batch_shape_tensor (), batch_shape )
220
+ self .assertAllEqual (tf .shape (forecast_dist .mean ()),
221
+ batch_shape + [num_steps_forecast , 1 ])
232
222
233
223
def test_methods_handle_masked_inputs (self ):
234
224
num_param_samples = 5
@@ -268,6 +258,7 @@ def test_methods_handle_masked_inputs(self):
268
258
self .assertTrue (np .all (np .isfinite (onestep_stddev_ )))
269
259
270
260
def test_impute_missing (self ):
261
+ num_timesteps = 7
271
262
time_series_with_nans = self ._build_tensor (
272
263
[- 1. , 1. , np .nan , 2.4 , np .nan , np .nan , 2. ])
273
264
observed_time_series = tfp .sts .MaskedTimeSeries (
@@ -288,19 +279,21 @@ def test_impute_missing(self):
288
279
parameter_samples = {'observation_noise_scale' : [noise_scale ],
289
280
'seasonal/_drift_scale' : [drift_scale ]}
290
281
imputed_series_dist = tfp .sts .impute_missing_values (
291
- model , observed_time_series , parameter_samples )
282
+ model , observed_time_series , parameter_samples ,
283
+ timesteps_are_event_shape = False )
292
284
imputed_noisy_series_dist = tfp .sts .impute_missing_values (
293
285
model , observed_time_series , parameter_samples ,
286
+ timesteps_are_event_shape = False ,
294
287
include_observation_noise = True )
288
+ self .assertAllEqual (imputed_noisy_series_dist .batch_shape_tensor (),
289
+ [num_timesteps ])
295
290
296
291
# Compare imputed mean to expected mean.
297
- mean_ , stddev_ = self .evaluate ([imputed_series_dist .mean (),
298
- imputed_series_dist .stddev ()])
299
- noisy_mean_ , noisy_stddev_ = self .evaluate ([
300
- imputed_noisy_series_dist .mean (),
301
- imputed_noisy_series_dist .stddev ()])
302
- self .assertAllClose (mean_ , [- 1. , 1. , 2. , 2.4 , - 1. , 1. , 2. ], atol = 1e-2 )
303
- self .assertAllClose (mean_ , noisy_mean_ , atol = 1e-2 )
292
+ mean , stddev = imputed_series_dist .mean (), imputed_series_dist .stddev ()
293
+ noisy_mean , noisy_stddev = [imputed_noisy_series_dist .mean (),
294
+ imputed_noisy_series_dist .stddev ()]
295
+ self .assertAllClose (mean , [- 1. , 1. , 2. , 2.4 , - 1. , 1. , 2. ], atol = 1e-2 )
296
+ self .assertAllClose (mean , noisy_mean , atol = 1e-2 )
304
297
305
298
# Compare imputed stddevs to expected stddevs.
306
299
drift_plus_noise_scale = np .sqrt (noise_scale ** 2 + drift_scale ** 2 )
@@ -311,9 +304,9 @@ def test_impute_missing(self):
311
304
drift_plus_noise_scale ,
312
305
drift_plus_noise_scale ,
313
306
noise_scale ])
314
- self .assertAllClose (stddev_ , expected_stddev , atol = 1e-2 )
315
- self .assertAllClose (noisy_stddev_ ,
316
- np .sqrt (stddev_ ** 2 + noise_scale ** 2 ), atol = 1e-2 )
307
+ self .assertAllClose (stddev , expected_stddev , atol = 1e-2 )
308
+ self .assertAllClose (noisy_stddev ,
309
+ tf .sqrt (stddev ** 2 + noise_scale ** 2 ), atol = 1e-2 )
317
310
318
311
def _build_tensor (self , ndarray , dtype = None ):
319
312
"""Convert a numpy array to a TF placeholder.
0 commit comments