Skip to content

Commit 6f0f527

Browse files
davmretensorflower-gardener
authored andcommitted
Add STS module for end-to-end anomaly detection with Gibbs sampling.
PiperOrigin-RevId: 387237124
1 parent ebb151f commit 6f0f527

File tree

5 files changed

+542
-0
lines changed

5 files changed

+542
-0
lines changed

tensorflow_probability/python/sts/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ py_library(
3535
":regularization",
3636
":structural_time_series",
3737
"//tensorflow_probability/python/internal:all_util",
38+
"//tensorflow_probability/python/sts/anomaly_detection",
3839
"//tensorflow_probability/python/sts/components",
3940
],
4041
)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2020 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
# Description:
16+
# Gibbs sampling for Bayesian structural time series models
17+
18+
licenses(["notice"])
19+
20+
package(
21+
default_visibility = [
22+
"//tensorflow_probability:__subpackages__",
23+
],
24+
)
25+
26+
exports_files(["LICENSE"])
27+
28+
py_library(
29+
name = "anomaly_detection",
30+
srcs = ["__init__.py"],
31+
srcs_version = "PY3",
32+
deps = [
33+
":anomaly_detection_lib",
34+
],
35+
)
36+
37+
py_library(
38+
name = "anomaly_detection_lib",
39+
srcs = ["anomaly_detection_lib.py"],
40+
srcs_version = "PY3",
41+
deps = [
42+
# numpy dep,
43+
# tensorflow dep,
44+
"//tensorflow_probability/python/distributions",
45+
"//tensorflow_probability/python/experimental/sts_gibbs:gibbs_sampler",
46+
"//tensorflow_probability/python/experimental/util",
47+
"//tensorflow_probability/python/internal:distribution_util",
48+
"//tensorflow_probability/python/internal:prefer_static",
49+
"//tensorflow_probability/python/math",
50+
"//tensorflow_probability/python/sts:fitting",
51+
"//tensorflow_probability/python/sts:forecast",
52+
"//tensorflow_probability/python/sts:regularization",
53+
"//tensorflow_probability/python/sts/components",
54+
"//tensorflow_probability/python/sts/internal",
55+
],
56+
)
57+
58+
py_test(
59+
name = "anomaly_detection_test",
60+
size = "medium",
61+
srcs = ["anomaly_detection_test.py"],
62+
srcs_version = "PY3",
63+
deps = [
64+
# absl/testing:parameterized dep,
65+
# numpy dep,
66+
# pandas dep,
67+
# tensorflow dep,
68+
"//tensorflow_probability",
69+
"//tensorflow_probability/python/experimental/sts_gibbs",
70+
"//tensorflow_probability/python/internal:test_util",
71+
],
72+
)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright 2020 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""Anomaly detection with structural time series models."""
16+
17+
from tensorflow_probability.python.internal import all_util
18+
from tensorflow_probability.python.sts.anomaly_detection.anomaly_detection_lib import detect_anomalies
19+
from tensorflow_probability.python.sts.anomaly_detection.anomaly_detection_lib import PredictionOutput
20+
21+
_allowed_symbols = [
22+
'detect_anomalies',
23+
'PredictionOutput'
24+
]
25+
26+
all_util.remove_undocumented(__name__, _allowed_symbols)
27+

0 commit comments

Comments
 (0)