Skip to content

Commit f7eaefe

Browse files
davmretensorflower-gardener
authored andcommitted
Fix an embarassing off-by-one error in STS anomaly detection predictive distributions.
I'd understood `tfp.sts.one_step_predictive` to return a distribution over timesteps `1:T + 1`, so the previous code went to a lot of effort to send it a series padded with an extra initial step in order to get results for steps `0:T` (matching the Gibbs predictive dist). But it turns out that I was wrong, and this padding is not only unnecessary but actually screws things up. This change also fixes a dtype issue with the Gibbs predictive distribution that came up in testing the above fix. PiperOrigin-RevId: 388789203
1 parent 34934d3 commit f7eaefe

File tree

4 files changed

+34
-23
lines changed

4 files changed

+34
-23
lines changed

tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,9 @@ def one_step_predictive(model,
340340
distribution of each timestep given previous timesteps.
341341
"""
342342
dtype = dtype_util.common_dtype([
343-
posterior_samples.level_scale.dtype,
344-
posterior_samples.observation_noise_scale.dtype,
345-
posterior_samples.level.dtype,
343+
posterior_samples.level_scale,
344+
posterior_samples.observation_noise_scale,
345+
posterior_samples.level,
346346
original_mean,
347347
original_scale], dtype_hint=tf.float32)
348348
num_observed_steps = prefer_static.shape(posterior_samples.level)[-1]

tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def reshape_chain_and_sample(x):
202202

203203
@parameterized.named_parameters(
204204
{'testcase_name': 'float32_xla', 'dtype': tf.float32, 'use_xla': True},
205-
{'testcase_name': 'float16', 'dtype': tf.float16, 'use_xla': False})
205+
{'testcase_name': 'float64', 'dtype': tf.float64, 'use_xla': False})
206206
def test_end_to_end_prediction_works_and_is_deterministic(
207207
self, dtype, use_xla):
208208
if not tf.executing_eagerly():
@@ -211,7 +211,8 @@ def test_end_to_end_prediction_works_and_is_deterministic(
211211
model, observed_time_series, is_missing = self._build_test_model(
212212
num_timesteps=5,
213213
batch_shape=[3],
214-
prior_class=gibbs_sampler.XLACompilableInverseGamma)
214+
prior_class=gibbs_sampler.XLACompilableInverseGamma,
215+
dtype=dtype)
215216

216217
@tf.function(jit_compile=use_xla)
217218
def do_sampling(observed_time_series, is_missing):

tensorflow_probability/python/sts/anomaly_detection/anomaly_detection_lib.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from tensorflow_probability.python.internal import prefer_static as ps
2525
from tensorflow_probability.python.sts import regularization
2626
from tensorflow_probability.python.sts.forecast import one_step_predictive
27-
from tensorflow_probability.python.sts.internal import missing_values_util
2827
from tensorflow_probability.python.sts.internal import seasonality_util
2928
from tensorflow_probability.python.sts.internal import util as sts_util
3029

@@ -177,24 +176,8 @@ def _detect_anomalies_inner(observed_time_series,
177176
predictive_dist = gibbs_sampler.one_step_predictive(model,
178177
posterior_samples)
179178
else:
180-
# Build the filtering predictive distribution.
181-
# First, pad the time series with an unobserved point at time -1, so that
182-
# we get a 'prediction' (which will just be the prior) at time 0, matching
183-
# the shape of the input (and of the Gibbs predictive distribution).
184-
is_missing = observed_time_series.is_missing
185-
if is_missing is None:
186-
is_missing = tf.zeros(tf.shape(observed_time_series.time_series)[:-1],
187-
dtype=tf.bool)
188-
initial_value = observed_time_series.time_series[..., 0:1, :]
189-
padded_series = missing_values_util.MaskedTimeSeries(
190-
time_series=tf.concat(
191-
[initial_value,
192-
observed_time_series.time_series[..., :-1, :]], axis=-2),
193-
is_missing=tf.concat(
194-
[tf.ones(tf.shape(initial_value)[:-1], dtype=tf.bool),
195-
is_missing[..., :-1]], axis=-1))
196179
predictive_dist = one_step_predictive(model,
197-
padded_series,
180+
observed_time_series,
198181
timesteps_are_event_shape=False,
199182
parameter_samples=parameter_samples)
200183

tensorflow_probability/python/sts/anomaly_detection/anomaly_detection_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,33 @@ def test_constant_series(self):
128128
tf.ones_like(predictions.mean),
129129
atol=0.1)
130130

131+
@parameterized.named_parameters(('', False),
132+
('_gibbs_predictive', True))
133+
def test_predictions_align_with_series(self, use_gibbs_predictive_dist):
134+
np.random.seed(0)
135+
# Simulate data with very clear daily and hourly effects, so that an
136+
# off-by-one error will almost certainly lead to out-of-bounds predictions.
137+
daily_effects = [100., 0., 20., -50., -100., -20., 70.]
138+
hourly_effects = [
139+
20., 0., 10., -10., 0., -20., -10., -30., -15., -5., -10., 0.] * 2
140+
effects = [daily_effects[(t // 24) % 7] + hourly_effects[t % 24]
141+
for t in range(24 * 7 * 2)]
142+
series = pd.Series(effects + np.random.randn(len(effects)),
143+
index=pd.date_range('2020-01-01',
144+
periods=len(effects),
145+
freq=pd.DateOffset(hours=1)))
146+
predictions = anomaly_detection.detect_anomalies(
147+
series,
148+
seed=test_util.test_seed(sampler_type='stateless'),
149+
num_samples=100,
150+
num_warmup_steps=50,
151+
use_gibbs_predictive_dist=use_gibbs_predictive_dist,
152+
jit_compile=False)
153+
# An off-by-one error in the predictive distribution would generate
154+
# anomalies at most steps.
155+
num_anomalies = tf.reduce_sum(tf.cast(predictions.is_anomaly, tf.int32))
156+
self.assertLessEqual(self.evaluate(num_anomalies), 5)
157+
131158

132159
if __name__ == '__main__':
133160
tf.test.main()

0 commit comments

Comments
 (0)