Skip to content

Commit 9c08360

Browse files
diego-urgell authored and facebook-github-bot committed
Implement starter anomaly evaluators (#853)
Summary: Pull Request resolved: #853 ### This Stack Based on [this RFC](https://docs.google.com/document/d/1K1KQ886dynMRejR0ySH1fctOjS7gxaCS8AB1L_PHxU4/edit?usp=sharing), we are adding a new logger that warns about anomalous values in metrics, and optionally executes a callback function with potential side effects. This could help users realize sooner that something has gone wrong during training. ### This Diff To get started with anomaly detection, let's first define two evaluators: - Threshold is the most intuitive one, and checks that a metric value is within a predefined range. - IsNaN would be useful to quickly catch cases where the loss is NaN because of bad inputs. Later on we can implement more interesting evaluators like outliers, changepoint detection, etc. if needed. Reviewed By: JKSenthil Differential Revision: D58564199 fbshipit-source-id: 767c3bf17f8aae5189a545a862d6098402ea34a9
1 parent ddf8cd5 commit 9c08360

File tree

3 files changed

+98
-0
lines changed

3 files changed

+98
-0
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import math
11+
import unittest
12+
13+
from torchtnt.utils.anomaly_evaluation import IsNaNEvaluator, ThresholdEvaluator
14+
15+
16+
class TestAnomalyLogger(unittest.TestCase):
    """Unit tests for ThresholdEvaluator and IsNaNEvaluator."""

    def test_threshold(self) -> None:
        """Exercise a two-sided range, then a range bounded only from above."""
        evaluator = ThresholdEvaluator(min_val=0.5, max_val=0.9)
        # No value has been fed yet, so nothing should be flagged.
        self.assertFalse(evaluator.is_anomaly())

        # (value fed to the evaluator, whether it must be flagged afterwards)
        for value, expect_anomaly in ((0.4, True), (0.6, False), (0.95, True)):
            evaluator.update(value)
            check = self.assertTrue if expect_anomaly else self.assertFalse
            check(evaluator.is_anomaly())

        # Only max_val is given; min_val keeps its default.
        evaluator = ThresholdEvaluator(max_val=1)
        for value, expect_anomaly in ((100.0, True), (-500.0, False)):
            evaluator.update(value)
            check = self.assertTrue if expect_anomaly else self.assertFalse
            check(evaluator.is_anomaly())

    def test_isnan(self) -> None:
        """Only a NaN value should be reported as an anomaly."""
        evaluator = IsNaNEvaluator()
        # Freshly constructed evaluator reports no anomaly.
        self.assertFalse(evaluator.is_anomaly())

        evaluator.update(0.4)
        self.assertFalse(evaluator.is_anomaly())

        evaluator.update(math.nan)
        self.assertTrue(evaluator.is_anomaly())

torchtnt/utils/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
from .anomaly_evaluation import IsNaNEvaluator, ThresholdEvaluator
910
from .checkpoint import (
1011
BestCheckpointConfig,
1112
CheckpointManager,
@@ -88,6 +89,8 @@
8889
)
8990

9091
__all__ = [
92+
"IsNaNEvaluator",
93+
"ThresholdEvaluator",
9194
"CheckpointPath",
9295
"MetricData",
9396
"get_best_checkpoint_path",

torchtnt/utils/anomaly_evaluation.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010

1111
import logging
12+
import math
1213
from abc import ABC, abstractmethod
14+
from math import inf
1315

1416
_logger: logging.Logger = logging.getLogger(__name__)
1517

@@ -49,3 +51,49 @@ def is_anomaly(self) -> bool:
4951
an anomaly detection algorithm.
5052
"""
5153
pass
54+
55+
56+
class ThresholdEvaluator(MetricAnomalyEvaluator):
    """
    Evaluates whether a metric value is anomalous based on a predefined threshold.

    The most recently supplied value is considered normal when it lies in the
    closed interval ``[min_val, max_val]``; any other value is an anomaly.
    """

    def __init__(
        self,
        *,
        min_val: float = -inf,
        max_val: float = inf,
    ) -> None:
        """
        Args:
            min_val: Minimum allowed value (inclusive). Default value is -inf.
            max_val: Maximum allowed value (inclusive). Default value is inf.
        """
        self.min_val = min_val
        self.max_val = max_val
        # Start at the lower bound so that, for a well-formed range
        # (min_val <= max_val), no anomaly is reported before the first update.
        self.curr_val: float = min_val

    def update(self, value: float) -> None:
        """Record ``value`` as the current metric value to be evaluated."""
        self.curr_val = value

    def is_anomaly(self) -> bool:
        """Return True if the current value falls outside ``[min_val, max_val]``."""
        # Any comparison against NaN is false, so a NaN current value fails the
        # range check and is reported as an anomaly as well.
        return not self.min_val <= self.curr_val <= self.max_val
83+
84+
85+
class IsNaNEvaluator(MetricAnomalyEvaluator):
    """
    Flags a metric as anomalous when its most recent value is NaN.
    """

    def __init__(self) -> None:
        """Initialise the tracked value to 0, which is not NaN."""
        self.curr_val: float = 0

    def update(self, value: float) -> None:
        """Store ``value`` as the latest observation of the metric."""
        self.curr_val = value

    def is_anomaly(self) -> bool:
        """Return True only when the latest stored value is NaN."""
        latest = self.curr_val
        return math.isnan(latest)

0 commit comments

Comments
 (0)