Skip to content

Commit 7b214b9

Browse files
davmretensorflower-gardener
authored andcommitted
Add plot_predictions utility for STS anomaly detection.
PiperOrigin-RevId: 388953690
1 parent 2750050 commit 7b214b9

File tree

4 files changed

+102
-0
lines changed

4 files changed

+102
-0
lines changed

tensorflow_probability/python/sts/anomaly_detection/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ py_test(
6262
srcs_version = "PY3",
6363
deps = [
6464
# absl/testing:parameterized dep,
65+
# matplotlib dep,
6566
# numpy dep,
6667
# pandas dep,
6768
# tensorflow dep,

tensorflow_probability/python/sts/anomaly_detection/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
from tensorflow_probability.python.internal import all_util
1818
from tensorflow_probability.python.sts.anomaly_detection.anomaly_detection_lib import detect_anomalies
19+
from tensorflow_probability.python.sts.anomaly_detection.anomaly_detection_lib import plot_predictions
1920
from tensorflow_probability.python.sts.anomaly_detection.anomaly_detection_lib import PredictionOutput
2021

2122
_allowed_symbols = [
2223
'detect_anomalies',
24+
'plot_predictions',
2325
'PredictionOutput'
2426
]
2527

tensorflow_probability/python/sts/anomaly_detection/anomaly_detection_lib.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
"""Utilities for anomaly detection with STS models and Gibbs sampling."""
1616

1717
import collections
18+
import datetime
1819

20+
import numpy as np
1921
import tensorflow.compat.v2 as tf
2022

2123
from tensorflow_probability.python import distributions as tfd
@@ -31,6 +33,7 @@
3133
'PredictionOutput',
3234
'detect_anomalies',
3335
'compute_predictive_bounds',
36+
'plot_predictions'
3437
]
3538

3639

@@ -290,3 +293,82 @@ def compute_predictive_bounds(predictive_dist, anomaly_threshold=0.01):
290293
low=predictive_mean - 100 * predictive_stddev,
291294
high=predictive_mean + 100 * predictive_stddev)
292295
return limits[0], limits[1], predictive_mean
296+
297+
298+
def plot_predictions(predictions,
299+
view_date_begin=None,
300+
view_date_end=None,
301+
ax=None):
302+
"""Creates a plot of the observed series and model predictions.
303+
304+
Creates a `matplotlib` plot of the observed time series with intervals from
305+
the model, predictions for any unobserved points, and the locations of any
306+
anomalies.
307+
308+
Args:
309+
predictions: instance of `PredictionOutput` as returned by
310+
`detect_anomalies`. This should contain predictions for a single series
311+
with no batch dimensions.
312+
view_date_begin: Optional `datetime.datetime` instance.
313+
view_date_end: Optional `datetime.datetime` instance.
314+
ax: Optional `matplotlib` figure axis.
315+
"""
316+
# pylint: disable=g-import-not-at-top
317+
from matplotlib import pylab as plt
318+
from matplotlib import dates as mdates
319+
# pylint: enable=g-import-not-at-top
320+
321+
if len(predictions.observed_time_series.shape) > 1:
322+
raise ValueError('Time series must be one-dimensional; batches are not '
323+
'supported. Saw shape: {}.'.format(
324+
predictions.observed_time_series.shape))
325+
326+
num_steps = len(predictions.times)
327+
time_delta = predictions.times[1] - predictions.times[0]
328+
time_period_length = predictions.times[-1] - predictions.times[0]
329+
if view_date_begin is None:
330+
view_date_begin = predictions.times[0] - 0.04 * time_period_length
331+
if view_date_end is None:
332+
view_date_end = predictions.times[-1] + 0.04 * time_period_length
333+
if not (isinstance(view_date_begin, datetime.datetime)
334+
and isinstance(view_date_end, datetime.datetime)):
335+
raise ValueError('View date start and end must be `datetime.datetime` '
336+
'instances.')
337+
338+
if ax is None: # Create default axis.
339+
fig = plt.figure(figsize=(15, 5), constrained_layout=True)
340+
ax = fig.add_subplot(1, 1, 1)
341+
342+
# Plot series with upper and lower limits.
343+
ax.plot(predictions.times,
344+
predictions.observed_time_series,
345+
color='black', alpha=0.8)
346+
ax.fill_between(predictions.times,
347+
predictions.lower_limit,
348+
predictions.upper_limit,
349+
color='tab:blue', alpha=0.3)
350+
# At steps where no time series was observed, plot the predictive mean.
351+
ax.plot(predictions.times,
352+
np.where(np.isnan(predictions.observed_time_series),
353+
predictions.mean,
354+
np.nan),
355+
color='black', alpha=0.8, ls='--')
356+
357+
# Highlight anomalies.
358+
for anomaly_idx in np.flatnonzero(predictions.is_anomaly):
359+
x = predictions.times[anomaly_idx]
360+
y = predictions.observed_time_series[anomaly_idx]
361+
ax.scatter(x, y, s=100, alpha=0.4, c='red')
362+
ax.annotate(str(x), (x, y))
363+
ax.set_ylabel('Series')
364+
ax.label_outer()
365+
366+
# Use smart date formatting for the x axis.
367+
locator = mdates.AutoDateLocator(minticks=3, maxticks=7)
368+
ax.xaxis.set_major_locator(locator)
369+
ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
370+
ax.set_xlim([view_date_begin, view_date_end])
371+
ax.grid(True, color='whitesmoke')
372+
# Display the grid *underneath* the rest of the plot
373+
# (see https://github.com/matplotlib/matplotlib/issues/5045/).
374+
ax.set_axisbelow(True)

tensorflow_probability/python/sts/anomaly_detection/anomaly_detection_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,23 @@ def test_runs_with_xla(self):
8181
num_samples=5,
8282
jit_compile=True)
8383

84+
def test_plot_predictions_runs(self):
85+
series = self._build_test_series(shape=[28], freq=pd.DateOffset(days=1))
86+
predictions = anomaly_detection.detect_anomalies(
87+
series, anomaly_threshold=0.01, use_gibbs_predictive_dist=False,
88+
seed=test_util.test_seed(sampler_type='stateless'),
89+
num_warmup_steps=5,
90+
num_samples=5)
91+
predictions = tf.nest.map_structure(
92+
lambda x: self.evaluate(x) if tf.is_tensor(x) else x, predictions)
93+
anomaly_detection.plot_predictions(predictions)
94+
95+
batch_predictions = tf.nest.map_structure(
96+
lambda x: np.stack([x, x], axis=0) if isinstance(x, np.ndarray) else x,
97+
predictions)
98+
with self.assertRaisesRegex(ValueError, 'must be one-dimensional'):
99+
anomaly_detection.plot_predictions(batch_predictions)
100+
84101
def test_adapts_to_series_scale(self):
85102
# Create a batch of two series with very different means and stddevs.
86103
freq = pd.DateOffset(hours=1)

0 commit comments

Comments
 (0)