Skip to content

Commit fa02938

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Add AggregatedTimer implementation (#1013)
Summary: Pull Request resolved: #1013 Reviewed By: JKSenthil Differential Revision: D77345641 fbshipit-source-id: 72c017887e0a5c0cf8627442c3d2d6e2a38a1413
1 parent 62388fd commit fa02938

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed

tests/utils/test_timer.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torchtnt.utils.distributed import spawn_multi_process
2020
from torchtnt.utils.test_utils import skip_if_not_distributed
2121
from torchtnt.utils.timer import (
22+
AggregatedTimer,
2223
BoundedTimer,
2324
FullSyncPeriodicTimer,
2425
get_durations_histogram,
@@ -235,6 +236,69 @@ def test_get_recorded_durations_table(self) -> None:
235236
== "\n| Name | p50 | p90 |\n|:-------|------:|------:|\n| op | 1 | 2 |"
236237
)
237238

239+
def test_aggregated_timer_statistics(self) -> None:
240+
"""Test that AggregatedTimer maintains correct statistics"""
241+
timer = AggregatedTimer()
242+
243+
# Time an action multiple times with known durations
244+
sleep_times = [1.0, 1.5, 2, 2.5, 1.0]
245+
for sleep_time in sleep_times:
246+
with timer.time("test_action"):
247+
time.sleep(sleep_time)
248+
249+
with timer.time("fast_action"):
250+
time.sleep(0.05)
251+
252+
# Generate report to calculate mean_duration
253+
report = timer._make_report()
254+
# test_action should be first
255+
test_action_report = report.timed_action_stats[0]
256+
fast_action_report = report.timed_action_stats[1]
257+
258+
# Verify for test_action
259+
self.assertEqual(test_action_report.num_calls, 5)
260+
expected_total = sum(sleep_times)
261+
self.assert_within_tolerance(
262+
test_action_report.total_duration, expected_total, 20
263+
)
264+
self.assert_within_tolerance(
265+
test_action_report.mean_duration, expected_total / 5
266+
)
267+
268+
# Verify for fast_action
269+
self.assertEqual(fast_action_report.num_calls, 1)
270+
self.assert_within_tolerance(fast_action_report.total_duration, 0.05)
271+
self.assertGreater(
272+
test_action_report.mean_duration, fast_action_report.mean_duration
273+
)
274+
275+
def test_aggregated_timer_multiple_actions(self) -> None:
276+
timer = AggregatedTimer()
277+
278+
actions = ["dataloader", "forward", "backward", "optimizer"]
279+
call_counts = [10, 10, 10, 5]
280+
281+
for action, count in zip(actions, call_counts):
282+
for _ in range(count):
283+
with timer.time(action):
284+
time.sleep(0.001)
285+
286+
# Check that all actions are recorded
287+
self.assertEqual(len(timer._aggregate_stats), 4)
288+
289+
for action, expected_count in zip(actions, call_counts):
290+
self.assertEqual(timer._aggregate_stats[action].num_calls, expected_count)
291+
292+
# Check report
293+
report = timer._make_report()
294+
self.assertEqual(len(report.timed_action_stats), 4)
295+
self.assertEqual(report.total_calls, sum(call_counts))
296+
297+
total_percentage = sum(
298+
stats.percentage_of_total_time for stats in report.timed_action_stats
299+
)
300+
self.assert_within_tolerance(total_percentage, 100.0, 1)
301+
238302

239303
class FullSyncPeriodicTimerTest(unittest.TestCase):
240304
@classmethod

torchtnt/utils/timer.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,75 @@ def _make_report(self: TimerProtocol) -> TimerReport:
208208
)
209209

210210

211+
class AggregatedTimer(Timer):
212+
"""
213+
A Timer class which implements TimerProtocol and stores aggregated timing stats. Instead of
214+
storing the recorded durations for each action, it accumulates running metrics and computes
215+
final stats at the end when generating the report. This is useful for cases where the number
216+
of samples is too large to store in memory.
217+
"""
218+
219+
def __init__(
220+
self,
221+
cuda_sync: Optional[bool] = None,
222+
verbose: bool = False,
223+
) -> None:
224+
super().__init__(cuda_sync=cuda_sync, verbose=verbose)
225+
self._aggregate_stats: Dict[str, TimedActionStats] = defaultdict(
226+
lambda: TimedActionStats(action_name="") # Filled in on time() method
227+
)
228+
229+
@contextmanager
230+
def time(
231+
self,
232+
action_name: str,
233+
) -> Generator[None, None, None]:
234+
# Run base class context manager first
235+
with super().time(action_name):
236+
yield
237+
238+
# Update aggregate stats
239+
latest_duration: float = self.recorded_durations[action_name][-1]
240+
self._aggregate_stats[action_name].action_name = action_name
241+
self._aggregate_stats[action_name].num_calls += 1
242+
self._aggregate_stats[action_name].total_duration += latest_duration
243+
244+
# Reset recorded durations to avoid storing data
245+
self.recorded_durations.clear()
246+
247+
def _make_report(self) -> TimerReport:
248+
"""
249+
Creates the report but considering that the data is aggregated in the correct structure.
250+
"""
251+
total_time = 0.0
252+
total_calls = 0
253+
254+
# Calculate total time and calls across all actions
255+
for stats in self._aggregate_stats.values():
256+
total_time += stats.total_duration
257+
total_calls += stats.num_calls
258+
259+
# Build report data
260+
action_stats: List[TimedActionStats] = []
261+
for stats in self._aggregate_stats.values():
262+
stats.mean_duration = (
263+
stats.total_duration / stats.num_calls if stats.num_calls > 0 else 0.0
264+
)
265+
stats.percentage_of_total_time = (
266+
100.0 * stats.total_duration / total_time if total_time > 0 else 0.0
267+
)
268+
action_stats.append(stats)
269+
270+
# Sort by percentage (descending)
271+
action_stats.sort(key=lambda x: x.percentage_of_total_time, reverse=True)
272+
273+
return TimerReport(
274+
timed_action_stats=action_stats,
275+
total_calls=total_calls,
276+
total_duration=total_time,
277+
)
278+
279+
211280
class BoundedTimer(Timer):
212281
"""
213282
A Timer class which implements TimerProtocol and stores timings in a dictionary `recorded_durations`.

0 commit comments

Comments
 (0)