|
15 | 15 | """Utilities for anomaly detection with STS models and Gibbs sampling."""
|
16 | 16 |
|
17 | 17 | import collections
|
| 18 | +import datetime |
18 | 19 |
|
| 20 | +import numpy as np |
19 | 21 | import tensorflow.compat.v2 as tf
|
20 | 22 |
|
21 | 23 | from tensorflow_probability.python import distributions as tfd
|
|
31 | 33 | 'PredictionOutput',
|
32 | 34 | 'detect_anomalies',
|
33 | 35 | 'compute_predictive_bounds',
|
| 36 | + 'plot_predictions' |
34 | 37 | ]
|
35 | 38 |
|
36 | 39 |
|
@@ -290,3 +293,82 @@ def compute_predictive_bounds(predictive_dist, anomaly_threshold=0.01):
|
290 | 293 | low=predictive_mean - 100 * predictive_stddev,
|
291 | 294 | high=predictive_mean + 100 * predictive_stddev)
|
292 | 295 | return limits[0], limits[1], predictive_mean
|
| 296 | + |
| 297 | + |
| 298 | +def plot_predictions(predictions, |
| 299 | + view_date_begin=None, |
| 300 | + view_date_end=None, |
| 301 | + ax=None): |
| 302 | + """Creates a plot of the observed series and model predictions. |
| 303 | +
|
| 304 | + Creates a `matplotlib` plot of the observed time series with intervals from |
| 305 | + the model, predictions for any unobserved points, and the locations of any |
| 306 | + anomalies. |
| 307 | +
|
| 308 | + Args: |
| 309 | + predictions: instance of `PredictionOutput` as returned by |
| 310 | + `detect_anomalies`. This should contain predictions for a single series |
| 311 | + with no batch dimensions. |
| 312 | + view_date_begin: Optional `datetime.datetime` instance. |
| 313 | + view_date_end: Optional `datetime.datetime` instance. |
| 314 | + ax: Optional `matplotlib` figure axis. |
| 315 | + """ |
| 316 | + # pylint: disable=g-import-not-at-top |
| 317 | + from matplotlib import pylab as plt |
| 318 | + from matplotlib import dates as mdates |
| 319 | + # pylint: enable=g-import-not-at-top |
| 320 | + |
| 321 | + if len(predictions.observed_time_series.shape) > 1: |
| 322 | + raise ValueError('Time series must be one-dimensional; batches are not ' |
| 323 | + 'supported. Saw shape: {}.'.format( |
| 324 | + predictions.observed_time_series.shape)) |
| 325 | + |
| 326 | + num_steps = len(predictions.times) |
| 327 | + time_delta = predictions.times[1] - predictions.times[0] |
| 328 | + time_period_length = predictions.times[-1] - predictions.times[0] |
| 329 | + if view_date_begin is None: |
| 330 | + view_date_begin = predictions.times[0] - 0.04 * time_period_length |
| 331 | + if view_date_end is None: |
| 332 | + view_date_end = predictions.times[-1] + 0.04 * time_period_length |
| 333 | + if not (isinstance(view_date_begin, datetime.datetime) |
| 334 | + and isinstance(view_date_end, datetime.datetime)): |
| 335 | + raise ValueError('View date start and end must be `datetime.datetime` ' |
| 336 | + 'instances.') |
| 337 | + |
| 338 | + if ax is None: # Create default axis. |
| 339 | + fig = plt.figure(figsize=(15, 5), constrained_layout=True) |
| 340 | + ax = fig.add_subplot(1, 1, 1) |
| 341 | + |
| 342 | + # Plot series with upper and lower limits. |
| 343 | + ax.plot(predictions.times, |
| 344 | + predictions.observed_time_series, |
| 345 | + color='black', alpha=0.8) |
| 346 | + ax.fill_between(predictions.times, |
| 347 | + predictions.lower_limit, |
| 348 | + predictions.upper_limit, |
| 349 | + color='tab:blue', alpha=0.3) |
| 350 | + # At steps where no time series was observed, plot the predictive mean. |
| 351 | + ax.plot(predictions.times, |
| 352 | + np.where(np.isnan(predictions.observed_time_series), |
| 353 | + predictions.mean, |
| 354 | + np.nan), |
| 355 | + color='black', alpha=0.8, ls='--') |
| 356 | + |
| 357 | + # Highlight anomalies. |
| 358 | + for anomaly_idx in np.flatnonzero(predictions.is_anomaly): |
| 359 | + x = predictions.times[anomaly_idx] |
| 360 | + y = predictions.observed_time_series[anomaly_idx] |
| 361 | + ax.scatter(x, y, s=100, alpha=0.4, c='red') |
| 362 | + ax.annotate(str(x), (x, y)) |
| 363 | + ax.set_ylabel('Series') |
| 364 | + ax.label_outer() |
| 365 | + |
| 366 | + # Use smart date formatting for the x axis. |
| 367 | + locator = mdates.AutoDateLocator(minticks=3, maxticks=7) |
| 368 | + ax.xaxis.set_major_locator(locator) |
| 369 | + ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator)) |
| 370 | + ax.set_xlim([view_date_begin, view_date_end]) |
| 371 | + ax.grid(True, color='whitesmoke') |
| 372 | + # Display the grid *underneath* the rest of the plot |
| 373 | + # (see https://github.com/matplotlib/matplotlib/issues/5045/). |
| 374 | + ax.set_axisbelow(True) |
0 commit comments