Skip to content

Commit b442b1e

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
add event + event handlers for custom logging (#942)
Summary: Pull Request resolved: #942 Reviewed By: diego-urgell Differential Revision: D65491587 fbshipit-source-id: ca29d91f1aae2acca17f286ab4226590857b8c28
1 parent e45d4c4 commit b442b1e

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

torchtnt/utils/event.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from dataclasses import dataclass, field
8+
from typing import Dict, Union
9+
10+
EventMetadataValue = Union[str, int, float, bool, None]
11+
12+
13+
@dataclass
14+
class Event:
15+
"""
16+
The class represents the generic event that occurs during a TorchTNT
17+
loop. The event can be any kind of meaningful action.
18+
19+
Args:
20+
name: event name.
21+
metadata: additional data that is associated with the event.
22+
"""
23+
24+
name: str
25+
metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)

torchtnt/utils/event_handlers.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import logging
9+
import random
10+
from contextlib import contextmanager
11+
from functools import lru_cache
12+
from typing import Dict, Generator, List, Optional
13+
14+
import importlib_metadata
15+
from typing_extensions import Protocol, runtime_checkable
16+
17+
from .event import Event
18+
19+
logger: logging.Logger = logging.getLogger(__name__)
20+
21+
22+
@runtime_checkable
23+
class EventHandler(Protocol):
24+
def handle_event(self, event: Event) -> None: ...
25+
26+
27+
_log_handlers: List[EventHandler] = []
28+
29+
30+
@lru_cache(maxsize=None)
31+
def get_event_handlers() -> List[EventHandler]:
32+
global _log_handlers
33+
34+
# Registered event handlers through entry points
35+
eps = importlib_metadata.entry_points(group="tnt_event_handlers")
36+
for entry in eps:
37+
logger.debug(
38+
f"Attempting to register event handler {entry.name}: {entry.value}"
39+
)
40+
factory = entry.load()
41+
handler = factory()
42+
43+
if not isinstance(handler, EventHandler):
44+
raise RuntimeError(
45+
f"The factory function for {({entry.value})} "
46+
"did not return a EventHandler object."
47+
)
48+
_log_handlers.append(handler)
49+
return _log_handlers
50+
51+
52+
def log_event(event: Event) -> None:
53+
"""
54+
Handle an event.
55+
56+
Args:
57+
event: The event to handle.
58+
"""
59+
60+
for handler in get_event_handlers():
61+
handler.handle_event(event)
62+
63+
64+
@contextmanager
65+
def log_interval(
66+
name: str, metadata: Optional[Dict[str, str]] = None
67+
) -> Generator[None, None, None]:
68+
unique_id = _generate_random_int64()
69+
if metadata is None:
70+
metadata = {}
71+
metadata.update({"action": "start", "unique_id": unique_id})
72+
start_event = Event(name=name, metadata=metadata)
73+
log_event(start_event)
74+
75+
yield
76+
77+
metadata["action"] = "end"
78+
end_event = Event(name=name, metadata=metadata)
79+
log_event(end_event)
80+
81+
82+
def _generate_random_int64() -> int:
83+
# avoid being influenced by externally set seed
84+
local_random = random.Random()
85+
return local_random.randint(0, 2**63 - 1)

0 commit comments

Comments
 (0)