Skip to content

Commit e63ec2b

Browse files
diego-urgell authored and facebook-github-bot committed
Refactor _make_report methods with better interface (#1011)
Summary: Pull Request resolved: #1011 Reviewed By: JKSenthil Differential Revision: D77345640 fbshipit-source-id: ede1b6331c68b62906c7339d84e207d4257e1491
1 parent f1ebb63 commit e63ec2b

File tree

1 file changed

+57
-43
lines changed

1 file changed

+57
-43
lines changed

torchtnt/utils/timer.py

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import os
1313
from collections import defaultdict
1414
from contextlib import contextmanager
15+
from dataclasses import dataclass
16+
from functools import total_ordering
1517
from time import perf_counter
1618
from typing import (
1719
Any,
@@ -22,7 +24,6 @@
2224
Protocol,
2325
runtime_checkable,
2426
Sequence,
25-
Tuple,
2627
)
2728

2829
import numpy as np
@@ -33,10 +34,8 @@
3334
from torch.distributed.distributed_c10d import Work
3435
from torchtnt.utils.distributed import PGWrapper
3536

36-
logger: logging.Logger = logging.getLogger(__name__)
3737

38-
_TABLE_ROW = Tuple[str, float, int, float, float]
39-
_TABLE_DATA = List[_TABLE_ROW]
38+
logger: logging.Logger = logging.getLogger(__name__)
4039

4140

4241
@contextmanager
@@ -69,6 +68,28 @@ def log_elapsed_time(
6968
logger.info(f"{action_name} took {interval_time} seconds")
7069

7170

71+
@total_ordering
@dataclass
class TimedActionStats:
    """Aggregated statistics for a single timed action.

    These can be consumed by report generation methods, so every metric is
    expected to be already aggregated over all recorded calls of the action.

    Ordering compares ``percentage_of_total_time`` only, while equality stays
    the dataclass-generated field-by-field comparison. Because the two use
    different criteria, all four ordering operators are defined explicitly:
    letting ``total_ordering`` derive ``__lt__`` from ``__le__`` and ``__eq__``
    makes two entries with the same percentage but different names each
    compare "less than" the other, which breaks the stability of
    ``list.sort``. With every operator present, ``total_ordering`` has nothing
    left to fill in.
    """

    action_name: str
    mean_duration: float = 0.0
    num_calls: int = 0
    total_duration: float = 0.0
    percentage_of_total_time: float = 0.0

    def __lt__(self, other: "TimedActionStats") -> bool:
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time < other.percentage_of_total_time

    def __le__(self, other: "TimedActionStats") -> bool:
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time <= other.percentage_of_total_time

    def __gt__(self, other: "TimedActionStats") -> bool:
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time > other.percentage_of_total_time

    def __ge__(self, other: "TimedActionStats") -> bool:
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time >= other.percentage_of_total_time
84+
85+
86+
@dataclass
class TimerReport:
    """Container bundling per-action statistics with overall totals."""

    # One aggregated entry per timed action.
    timed_action_stats: List[TimedActionStats]
    # Overall number of calls covered by the report.
    total_calls: int
    # Overall duration covered by the report.
    total_duration: float
91+
92+
7293
@runtime_checkable
7394
class TimerProtocol(Protocol):
7495
"""
@@ -194,31 +215,30 @@ def _apply_bounds(self, action_name: str) -> None:
194215
)
195216

196217

197-
def _get_total_time(timer: TimerProtocol) -> float:
218+
def _make_report(self: TimerProtocol) -> TimerReport:
198219
total_time = 0.0
199-
for _, durations in timer.recorded_durations.items():
220+
for _, durations in self.recorded_durations.items():
200221
array_value = np.array(durations)
201222
array_sum = np.sum(array_value)
202223
total_time += array_sum
203224

204-
return total_time
205-
206-
207-
def _make_report(timer: TimerProtocol) -> Tuple[_TABLE_DATA, float, float]:
208-
total_time = _get_total_time(timer)
209-
report = [
210-
(
211-
a,
212-
np.mean(d),
213-
len(d),
214-
np.sum(d),
215-
100.0 * np.sum(d) / total_time,
225+
action_stats = [
226+
TimedActionStats(
227+
action_name=a,
228+
mean_duration=np.mean(d),
229+
num_calls=len(d),
230+
total_duration=np.sum(d),
231+
percentage_of_total_time=100.0 * np.sum(d) / total_time,
216232
)
217-
for a, d in timer.recorded_durations.items()
233+
for a, d in self.recorded_durations.items()
218234
]
219-
report.sort(key=lambda x: x[4], reverse=True)
220-
total_calls = sum(x[2] for x in report)
221-
return report, total_calls, total_time
235+
action_stats.sort(reverse=True)
236+
total_calls = sum(x.num_calls for x in action_stats)
237+
return TimerReport(
238+
timed_action_stats=action_stats,
239+
total_calls=total_calls,
240+
total_duration=total_time,
241+
)
222242

223243

224244
def get_timer_summary(timer: TimerProtocol) -> str:
@@ -231,13 +251,16 @@ def get_timer_summary(timer: TimerProtocol) -> str:
231251
ValueError
232252
If the input Timer has no recorded actions
233253
"""
254+
report: TimerReport = _make_report(timer)
255+
234256
sep: str = os.linesep
235257
output_string = f"Timer Report{sep}"
236258

237-
if len(timer.recorded_durations) == 0:
259+
# Handle empty timer case
260+
if not report.timed_action_stats:
238261
return output_string
239262

240-
max_key = max(len(k) for k in timer.recorded_durations.keys())
263+
max_key = max(len(a.action_name) for a in report.timed_action_stats)
241264

242265
# pyre-fixme[53]: Captured variable `max_key` is not annotated.
243266
def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str:
@@ -252,32 +275,23 @@ def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str
252275
"Total time (s)",
253276
"Percentage %",
254277
)
278+
255279
output_string_len = len(header_string.expandtabs()) - 1
256280
sep_lines = f"{sep}{'-' * output_string_len}"
257281
output_string += sep_lines + header_string + sep_lines
258-
report: _TABLE_DATA
259-
(
260-
report,
261-
total_calls,
262-
total_duration,
263-
) = _make_report(timer)
282+
264283
output_string += log_row(
265-
"Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %"
284+
"Total", "-", f"{report.total_calls:}", f"{report.total_duration:.5}", "100 %"
266285
)
267286
output_string += sep_lines
268-
for (
269-
action,
270-
mean_duration,
271-
num_calls,
272-
total_duration,
273-
duration_per,
274-
) in report:
287+
288+
for action in report.timed_action_stats:
275289
output_string += log_row(
276-
action,
277-
f"{mean_duration:.5}",
278-
f"{num_calls}",
279-
f"{total_duration:.5}",
280-
f"{duration_per:.5}",
290+
action.action_name,
291+
f"{action.mean_duration:.5}",
292+
f"{action.num_calls}",
293+
f"{action.total_duration:.5}",
294+
f"{action.percentage_of_total_time:.5}",
281295
)
282296
output_string += sep_lines
283297

0 commit comments

Comments
 (0)