12
12
import os
13
13
from collections import defaultdict
14
14
from contextlib import contextmanager
15
+ from dataclasses import dataclass
16
+ from functools import total_ordering
15
17
from time import perf_counter
16
18
from typing import (
17
19
Any ,
22
24
Protocol ,
23
25
runtime_checkable ,
24
26
Sequence ,
25
- Tuple ,
26
27
)
27
28
28
29
import numpy as np
33
34
from torch .distributed .distributed_c10d import Work
34
35
from torchtnt .utils .distributed import PGWrapper
35
36
36
- logger : logging .Logger = logging .getLogger (__name__ )
37
37
38
- _TABLE_ROW = Tuple [str , float , int , float , float ]
39
- _TABLE_DATA = List [_TABLE_ROW ]
38
+ logger : logging .Logger = logging .getLogger (__name__ )
40
39
41
40
42
41
@contextmanager
@@ -69,6 +68,28 @@ def log_elapsed_time(
69
68
logger .info (f"{ action_name } took { interval_time } seconds" )
70
69
71
70
71
@dataclass
class TimedActionStats:
    """Dataclass for storing timed action stats.

    These can be consumed by report generation methods, so metrics should be
    aggregated.

    Ordering (``<``, ``<=``, ``>``, ``>=``) looks only at
    ``percentage_of_total_time``. Equality is the dataclass-generated
    field-by-field comparison. All four ordering methods are written out
    explicitly instead of being derived with ``functools.total_ordering``:
    the derived versions combine ``__le__`` with ``__eq__``, which gives
    contradictory answers for two entries that share a percentage but differ
    in another field (both ``a < b`` and ``b < a`` would be true).
    """

    action_name: str
    mean_duration: float = 0.0
    num_calls: int = 0
    total_duration: float = 0.0
    percentage_of_total_time: float = 0.0

    def __lt__(self, other: "TimedActionStats") -> bool:
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time < other.percentage_of_total_time

    def __le__(self, other: "TimedActionStats") -> bool:
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time <= other.percentage_of_total_time

    def __gt__(self, other: "TimedActionStats") -> bool:
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time > other.percentage_of_total_time

    def __ge__(self, other: "TimedActionStats") -> bool:
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time >= other.percentage_of_total_time
84
+
85
+
86
@dataclass
class TimerReport:
    """Aggregated summary of a timer's recorded actions.

    Built by ``_make_report`` and consumed by ``get_timer_summary`` to render
    the textual report.
    """

    # One entry per recorded action name.
    timed_action_stats: List[TimedActionStats]
    # Sum of ``num_calls`` over all entries in ``timed_action_stats``.
    total_calls: int
    # Sum of every recorded duration across all actions.
    total_duration: float
91
+
92
+
72
93
@runtime_checkable
73
94
class TimerProtocol (Protocol ):
74
95
"""
@@ -194,31 +215,30 @@ def _apply_bounds(self, action_name: str) -> None:
194
215
)
195
216
196
217
197
- def _get_total_time ( timer : TimerProtocol ) -> float :
218
def _make_report(self: TimerProtocol) -> TimerReport:
    """Aggregate a timer's recorded durations into a ``TimerReport``.

    Args:
        self: timer whose ``recorded_durations`` mapping (action name ->
            recorded durations) is summarized. The parameter keeps the name
            ``self`` although this is a module-level function.

    Returns:
        A ``TimerReport`` with one ``TimedActionStats`` per action, ordered
        by descending share of the total time, plus the total number of
        calls and the total duration over all actions.
    """
    # Sum each action's durations exactly once; the per-action totals feed
    # both the grand total and the percentage column.
    per_action_totals = {}
    total_time = 0.0
    for name, durations in self.recorded_durations.items():
        action_total = float(np.sum(np.array(durations)))
        per_action_totals[name] = action_total
        total_time += action_total

    action_stats = []
    for name, durations in self.recorded_durations.items():
        num_calls = len(durations)
        action_total = per_action_totals[name]
        action_stats.append(
            TimedActionStats(
                action_name=name,
                # np.mean of an empty sequence is NaN (with a warning);
                # report 0.0 for an action with no recorded calls.
                mean_duration=float(np.mean(durations)) if num_calls else 0.0,
                num_calls=num_calls,
                total_duration=action_total,
                # Guard the division: when nothing measurable was recorded
                # the share is reported as 0.0 instead of NaN.
                percentage_of_total_time=(
                    100.0 * action_total / total_time if total_time > 0 else 0.0
                ),
            )
        )

    # Explicit key keeps the sort stable on ties and independent of how
    # TimedActionStats implements its comparison operators.
    action_stats.sort(key=lambda s: s.percentage_of_total_time, reverse=True)
    total_calls = sum(s.num_calls for s in action_stats)
    return TimerReport(
        timed_action_stats=action_stats,
        total_calls=total_calls,
        total_duration=total_time,
    )
222
242
223
243
224
244
def get_timer_summary (timer : TimerProtocol ) -> str :
@@ -231,13 +251,16 @@ def get_timer_summary(timer: TimerProtocol) -> str:
231
251
ValueError
232
252
If the input Timer has no recorded actions
233
253
"""
254
+ report : TimerReport = _make_report (timer )
255
+
234
256
sep : str = os .linesep
235
257
output_string = f"Timer Report{ sep } "
236
258
237
- if len (timer .recorded_durations ) == 0 :
259
+ # Handle empty timer case
260
+ if not report .timed_action_stats :
238
261
return output_string
239
262
240
- max_key = max (len (k ) for k in timer . recorded_durations . keys () )
263
+ max_key = max (len (a . action_name ) for a in report . timed_action_stats )
241
264
242
265
# pyre-fixme[53]: Captured variable `max_key` is not annotated.
243
266
def log_row (action : str , mean : str , num_calls : str , total : str , per : str ) -> str :
@@ -252,32 +275,23 @@ def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str
252
275
"Total time (s)" ,
253
276
"Percentage %" ,
254
277
)
278
+
255
279
output_string_len = len (header_string .expandtabs ()) - 1
256
280
sep_lines = f"{ sep } { '-' * output_string_len } "
257
281
output_string += sep_lines + header_string + sep_lines
258
- report : _TABLE_DATA
259
- (
260
- report ,
261
- total_calls ,
262
- total_duration ,
263
- ) = _make_report (timer )
282
+
264
283
output_string += log_row (
265
- "Total" , "-" , f"{ total_calls :} " , f"{ total_duration :.5} " , "100 %"
284
+ "Total" , "-" , f"{ report . total_calls :} " , f"{ report . total_duration :.5} " , "100 %"
266
285
)
267
286
output_string += sep_lines
268
- for (
269
- action ,
270
- mean_duration ,
271
- num_calls ,
272
- total_duration ,
273
- duration_per ,
274
- ) in report :
287
+
288
+ for action in report .timed_action_stats :
275
289
output_string += log_row (
276
- action ,
277
- f"{ mean_duration :.5} " ,
278
- f"{ num_calls } " ,
279
- f"{ total_duration :.5} " ,
280
- f"{ duration_per :.5} " ,
290
+ action . action_name ,
291
+ f"{ action . mean_duration :.5} " ,
292
+ f"{ action . num_calls } " ,
293
+ f"{ action . total_duration :.5} " ,
294
+ f"{ action . percentage_of_total_time :.5} " ,
281
295
)
282
296
output_string += sep_lines
283
297
0 commit comments