Skip to content

Commit 6fb7c5a

Browse files
diego-urgell and facebook-github-bot
authored and committed
Add logger for anomaly detection (#852)
Summary: Pull Request resolved: #852 ### This Stack Based on [this RFC](https://docs.google.com/document/d/1K1KQ886dynMRejR0ySH1fctOjS7gxaCS8AB1L_PHxU4/edit?usp=sharing), we are adding a new logger that warns about anomalous values in metrics, and optionally executes a callback function with potential side effects. This could be useful for users to realize sooner that something has gone wrong during training. ### This Diff After implementing the evaluators, let's add the `AnomalyLogger` class that receives some configuration of metrics to check for. If an anomaly is detected, then it will call an optional `on_anomaly_detected` method that can be overridden by the user. Next diffs will add this to our `AIXLogger` and `TensorboardLogger` as a base class. Reviewed By: JKSenthil Differential Revision: D58564200 fbshipit-source-id: 157ab34b993195b220c5f93941a2427306416d11
1 parent 9c08360 commit 6fb7c5a

File tree

3 files changed

+397
-0
lines changed

3 files changed

+397
-0
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import math
11+
import unittest
12+
from unittest.mock import call, MagicMock, patch
13+
14+
import torch
15+
16+
from torchtnt.utils.anomaly_evaluation import (
17+
IsNaNEvaluator,
18+
MetricAnomalyEvaluator,
19+
ThresholdEvaluator,
20+
)
21+
22+
from torchtnt.utils.loggers.anomaly_logger import AnomalyLogger, TrackedMetric
23+
24+
25+
class DummyEvaluator(MetricAnomalyEvaluator):
    """Test stub evaluator that flags every observed value as anomalous."""

    def _evaluate_anomaly(self, value: float) -> bool:
        # Unconditionally report an anomaly, regardless of the input value.
        return True
29+
30+
class TestAnomalyLogger(unittest.TestCase):
    """Unit tests for ``AnomalyLogger`` and its ``TrackedMetric`` configuration."""

    @staticmethod
    def _capture_warnings(sink: list) -> "patch":
        """Patch the anomaly logger module's ``Logger.warning`` so that every
        warning message is appended to *sink* instead of being emitted."""
        return patch(
            "torchtnt.utils.loggers.anomaly_logger.logging.Logger.warning",
            side_effect=sink.append,
        )

    def test_init(self) -> None:
        """Duplicate configs for the same metric are dropped with a warning."""
        metric_configs = [
            TrackedMetric(
                name="accuracy",
                anomaly_evaluators=[ThresholdEvaluator(min_val=0.5, max_val=0.9)],
            ),
            TrackedMetric(
                name="accuracy",
                anomaly_evaluators=[IsNaNEvaluator()],
            ),
            TrackedMetric(name="loss", anomaly_evaluators=[IsNaNEvaluator()]),
        ]

        captured = []
        with self._capture_warnings(captured):
            logger = AnomalyLogger(
                tracked_metrics=metric_configs,
            )

        self.assertEqual(
            captured,
            ["Found multiple configs for metric 'accuracy'. Skipping."],
        )
        # Only the unambiguously-configured metric is retained.
        self.assertEqual(set(logger._tracked_metrics.keys()), {"loss"})

    @patch(
        "torchtnt.utils.loggers.anomaly_logger.AnomalyLogger.on_anomaly_detected",
    )
    def test_log(self, mock_on_anomaly_detected: MagicMock) -> None:
        """``log`` honors warmup, evaluation cadence, and the threshold check."""
        logger = AnomalyLogger(
            tracked_metrics=[
                TrackedMetric(
                    name="accuracy",
                    anomaly_evaluators=[ThresholdEvaluator(min_val=0.5, max_val=0.9)],
                    warmup_steps=4,
                    evaluate_every_n_steps=2,
                )
            ]
        )

        # A multi-element tensor cannot be reduced to one number: expect a
        # warning plus the anomaly callback.
        captured = []
        with self._capture_warnings(captured):
            logger.log(step=1, name="accuracy", data=torch.Tensor([0.5, 0.9]))

        self.assertEqual(
            captured,
            [
                "Error when extracting a single numerical value from the provided metric: Scalar tensor must contain a single item, 2 given."
            ],
        )
        mock_on_anomaly_detected.assert_called_once()

        # Anomalous value inside the warmup window: no-op.
        mock_on_anomaly_detected.reset_mock()
        logger.log(step=4, name="accuracy", data=0.2)
        mock_on_anomaly_detected.assert_not_called()

        # Anomalous value on a non-evaluation step: no-op.
        logger.log(step=5, name="accuracy", data=0.1)
        mock_on_anomaly_detected.assert_not_called()

        # Metric that is not tracked: no-op.
        mock_on_anomaly_detected.reset_mock()
        logger.log(step=6, name="loss", data=math.nan)
        mock_on_anomaly_detected.assert_not_called()

        # Value within the configured threshold: no-op.
        logger.log(step=6, name="accuracy", data=0.6)
        mock_on_anomaly_detected.assert_not_called()

        # Value above the threshold: warning plus callback with (name, value, step).
        captured = []
        with self._capture_warnings(captured):
            logger.log(step=8, name="accuracy", data=0.95)

        self.assertEqual(
            captured,
            [
                "Found anomaly in metric: accuracy, with value: 0.95, using evaluator: ThresholdEvaluator"
            ],
        )
        mock_on_anomaly_detected.assert_called_with("accuracy", 0.95, 8)

    @patch(
        "torchtnt.utils.loggers.anomaly_logger.AnomalyLogger.on_anomaly_detected",
    )
    def test_log_dict(self, mock_on_anomaly_detected: MagicMock) -> None:
        """``log_dict`` evaluates every tracked metric in the payload."""
        logger = AnomalyLogger(
            tracked_metrics=[
                TrackedMetric(
                    name="accuracy",
                    anomaly_evaluators=[ThresholdEvaluator(min_val=0.5, max_val=0.9)],
                ),
                TrackedMetric(
                    name="loss",
                    anomaly_evaluators=[IsNaNEvaluator()],
                ),
                TrackedMetric(
                    name="f1_score",
                    anomaly_evaluators=[
                        IsNaNEvaluator(),
                        ThresholdEvaluator(min_val=0.2),
                    ],
                ),
            ]
        )

        captured = []
        with self._capture_warnings(captured):
            logger.log_dict(
                step=1,
                payload={
                    "loss": math.nan,
                    "accuracy": 0.63,
                    "precision": 0.7,
                    "f1_score": 0.05,
                },
            )

        # Ordering across metrics is not guaranteed, so compare as a set.
        self.assertEqual(
            set(captured),
            {
                "Found anomaly in metric: f1_score, with value: 0.05, using evaluator: ThresholdEvaluator",
                "Found anomaly in metric: loss, with value: nan, using evaluator: IsNaNEvaluator",
            },
        )

        expected_callback_calls = [
            call("f1_score", 0.05, 1),
            call("loss", math.nan, 1),
        ]
        mock_on_anomaly_detected.assert_has_calls(
            expected_callback_calls, any_order=True
        )

    @patch(
        "torchtnt.utils.loggers.anomaly_logger.AnomalyLogger.on_anomaly_detected",
        side_effect=Exception("test exception"),
    )
    def test_on_anomaly_callback_exception(self, _) -> None:
        """An exception raised by the user callback is logged, not propagated."""
        logger = AnomalyLogger(
            tracked_metrics=[
                TrackedMetric(
                    name="accuracy",
                    anomaly_evaluators=[ThresholdEvaluator(min_val=0.5, max_val=0.9)],
                ),
            ]
        )

        captured = []
        with self._capture_warnings(captured):
            logger.log(step=1, name="accuracy", data=0.95)

        self.assertEqual(
            captured,
            [
                "Found anomaly in metric: accuracy, with value: 0.95, using evaluator: ThresholdEvaluator",
                "Exception when calling on_anomaly_hook: test exception",
            ],
        )

torchtnt/utils/loggers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
from .anomaly_logger import AnomalyLogger, TrackedMetric
910
from .csv import CSVLogger
1011
from .file import FileLogger
1112
from .in_memory import InMemoryLogger
@@ -17,6 +18,8 @@
1718

1819

1920
__all__ = [
21+
"AnomalyLogger",
22+
"TrackedMetric",
2023
"CSVLogger",
2124
"FileLogger",
2225
"InMemoryLogger",

0 commit comments

Comments
 (0)