Commit 12c5637

diego-urgell authored and facebook-github-bot committed
Add anomaly detection support to TensorboardLogger (#854)
Summary:
Pull Request resolved: #854

### This Stack

Based on [this RFC](https://docs.google.com/document/d/1K1KQ886dynMRejR0ySH1fctOjS7gxaCS8AB1L_PHxU4/edit?usp=sharing), we are adding a new logger that warns about anomalous values in metrics and can optionally execute a callback function with side effects. This could help users notice sooner that something has gone wrong during training.

### This Diff

To make the AnomalyLogger as easy as possible to adopt, make it the base class of the TensorBoard logger in place of MetricLogger. This has no effect unless users specify the optional `tracked_metrics` argument; users who do want it need only minimal changes. The next diff will do the same for the AIXLogger.

Reviewed By: JKSenthil

Differential Revision: D58593222

fbshipit-source-id: bff900f9e3ce15640628f22a976166f6db59293c
1 parent 6fb7c5a commit 12c5637
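
A minimal usage sketch of the API added in this diff, for orientation only: `TensorBoardLogger`, `TrackedMetric`, and `ThresholdEvaluator` come from the changed files below, while the metric name, log path, threshold, and logged values are illustrative assumptions.

```python
# Minimal sketch, assuming the API shown in this diff. Metric name, path,
# threshold, and values are illustrative only.
from torchtnt.utils.anomaly_evaluation import ThresholdEvaluator
from torchtnt.utils.loggers.anomaly_logger import TrackedMetric
from torchtnt.utils.loggers.tensorboard import TensorBoardLogger

logger = TensorBoardLogger(
    path="/tmp/tb_logs",
    tracked_metrics=[
        TrackedMetric(
            name="accuracy",
            # In the test below, ThresholdEvaluator(min_val=25) flags 16.0, so
            # min_val appears to flag values that fall below the threshold.
            anomaly_evaluators=[ThresholdEvaluator(min_val=0.5)],
            evaluate_every_n_steps=2,
            warmup_steps=2,
        )
    ],
)

for step in range(10):
    # Existing call sites are unchanged; anomaly checks piggyback on log().
    logger.log("accuracy", 0.3, step)
logger.close()
```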

File tree

2 files changed: +58, −10 lines


tests/utils/loggers/test_tensorboard.py

Lines changed: 39 additions & 6 deletions
@@ -11,20 +11,53 @@
 
 import tempfile
 import unittest
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+from torchtnt.utils.anomaly_evaluation import ThresholdEvaluator
+from torchtnt.utils.loggers.anomaly_logger import TrackedMetric
 
 from torchtnt.utils.loggers.tensorboard import TensorBoardLogger
 
 
 class TensorBoardLoggerTest(unittest.TestCase):
-    def test_log(self: TensorBoardLoggerTest) -> None:
+
+    @patch(
+        "torchtnt.utils.loggers.anomaly_logger.AnomalyLogger.on_anomaly_detected",
+    )
+    def test_log(
+        self: TensorBoardLoggerTest, mock_on_anomaly_detected: MagicMock
+    ) -> None:
         with tempfile.TemporaryDirectory() as log_dir:
-            logger = TensorBoardLogger(path=log_dir)
-            for i in range(5):
-                logger.log("test_log", float(i) ** 2, i)
-            logger.close()
+            logger = TensorBoardLogger(
+                path=log_dir,
+                tracked_metrics=[
+                    TrackedMetric(
+                        name="test_log",
+                        anomaly_evaluators=[
+                            ThresholdEvaluator(min_val=25),
+                        ],
+                        evaluate_every_n_steps=2,
+                        warmup_steps=2,
+                    )
+                ],
+            )
+            warning_container = []
+            with patch(
+                "torchtnt.utils.loggers.anomaly_logger.logging.Logger.warning",
+                side_effect=warning_container.append,
+            ):
+                for i in range(5):
+                    logger.log("test_log", float(i) ** 2, i)
+                logger.close()
+
+            self.assertEqual(
+                warning_container,
+                [
+                    "Found anomaly in metric: test_log, with value: 16.0, using evaluator: ThresholdEvaluator"
+                ],
+            )
+            mock_on_anomaly_detected.assert_called_with("test_log", 16.0, 4)
 
             acc = EventAccumulator(log_dir)
             acc.Reload()

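The test patches `AnomalyLogger.on_anomaly_detected` and asserts it is called with `("test_log", 16.0, 4)`, i.e. the hook receives the metric name, the anomalous value, and the step. A hedged sketch of how a user might hook into that callback; the parameter names are assumed from that positional call, and the side effect is purely illustrative.

```python
# Sketch of overriding the anomaly callback. Parameter names are assumed from the
# positional call ("test_log", 16.0, 4) in the test; the side effect is illustrative.
from torchtnt.utils.loggers.logger import Scalar
from torchtnt.utils.loggers.tensorboard import TensorBoardLogger


class AlertingTensorBoardLogger(TensorBoardLogger):
    def on_anomaly_detected(self, name: str, data: Scalar, step: int) -> None:
        # Replace with a real side effect (paging, extra logging, early stopping, ...).
        print(f"Anomaly detected in '{name}' at step {step}: value={data}")
```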
torchtnt/utils/loggers/tensorboard.py

Lines changed: 19 additions & 4 deletions
@@ -11,23 +11,28 @@
 
 import atexit
 import logging
-from typing import Any, Dict, Mapping, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 from torch.utils.tensorboard import SummaryWriter
 from torchtnt.utils.distributed import get_global_rank
-from torchtnt.utils.loggers.logger import MetricLogger, Scalar
+from torchtnt.utils.loggers.anomaly_logger import AnomalyLogger, TrackedMetric
+from torchtnt.utils.loggers.logger import Scalar
 
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-class TensorBoardLogger(MetricLogger):
+class TensorBoardLogger(AnomalyLogger):
     """
     Simple logger for TensorBoard.
 
     On construction, the logger creates a new events file that logs
     will be written to. If the environment variable `RANK` is defined,
     logger will only log if RANK = 0.
 
+    Metrics may be tracked for anomaly detection if they are configured in the
+    optional `tracked_metrics` argument. See :class:`torchtnt.utils.loggers.AnomalyLogger`
+    for more details.
+
     Note:
         If using this logger with distributed training:
 
@@ -38,6 +43,7 @@ class TensorBoardLogger(MetricLogger):
 
     Args:
         path (str): path to write logs to
+        tracked_metrics: Optional list of TrackedMetric objects to track for anomaly detection.
        *args: Extra positional arguments to pass to SummaryWriter
        **kwargs: Extra keyword arguments to pass to SummaryWriter
 
@@ -49,7 +55,14 @@ class TensorBoardLogger(MetricLogger):
         logger.close()
     """
 
-    def __init__(self: TensorBoardLogger, path: str, *args: Any, **kwargs: Any) -> None:
+    def __init__(
+        self: TensorBoardLogger,
+        path: str,
+        tracked_metrics: Optional[List[TrackedMetric]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(tracked_metrics)
         self._writer: Optional[SummaryWriter] = None
         self._path: str = path
         self._rank: int = get_global_rank()
@@ -100,6 +113,8 @@ def log(self: TensorBoardLogger, name: str, data: Scalar, step: int) -> None:
         if self._writer:
             self._writer.add_scalar(name, data, global_step=step, new_style=True)
 
+        super().log(name, data, step)
+
     def log_text(self: TensorBoardLogger, name: str, data: str, step: int) -> None:
         """Add text data to summary.
 

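For completeness, a sketch of the unchanged path described in the summary: constructing the logger without `tracked_metrics` keeps the previous behavior, since anomaly tracking only activates when the optional argument is supplied. The path and metric below are illustrative.

```python
# Backward-compatible sketch: no tracked_metrics, so no anomaly checks are configured.
from torchtnt.utils.loggers.tensorboard import TensorBoardLogger

logger = TensorBoardLogger(path="/tmp/tb_logs")  # path is illustrative
logger.log("loss", 0.5, 0)  # behaves as before this diff
logger.close()
```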