|
19 | 19 | from torchtnt.utils.distributed import spawn_multi_process
|
20 | 20 | from torchtnt.utils.test_utils import skip_if_not_distributed
|
21 | 21 | from torchtnt.utils.timer import (
|
| 22 | + AggregatedTimer, |
22 | 23 | BoundedTimer,
|
23 | 24 | FullSyncPeriodicTimer,
|
24 | 25 | get_durations_histogram,
|
@@ -235,6 +236,69 @@ def test_get_recorded_durations_table(self) -> None:
|
235 | 236 | == "\n| Name | p50 | p90 |\n|:-------|------:|------:|\n| op | 1 | 2 |"
|
236 | 237 | )
|
237 | 238 |
|
| 239 | + def test_aggregated_timer_statistics(self) -> None: |
| 240 | + """Test that AggregatedTimer maintains correct statistics""" |
| 241 | + timer = AggregatedTimer() |
| 242 | + |
| 243 | + # Time an action multiple times with known durations |
| 244 | + sleep_times = [1.0, 1.5, 2, 2.5, 1.0] |
| 245 | + for sleep_time in sleep_times: |
| 246 | + with timer.time("test_action"): |
| 247 | + time.sleep(sleep_time) |
| 248 | + |
| 249 | + with timer.time("fast_action"): |
| 250 | + time.sleep(0.05) |
| 251 | + |
| 252 | + # Generate report to calculate mean_duration |
| 253 | + report = timer._make_report() |
| 254 | + # test_action should be first |
| 255 | + test_action_report = report.timed_action_stats[0] |
| 256 | + fast_action_report = report.timed_action_stats[1] |
| 257 | + |
| 258 | + # Verify for test_action |
| 259 | + self.assertEqual(test_action_report.num_calls, 5) |
| 260 | + expected_total = sum(sleep_times) |
| 261 | + self.assert_within_tolerance( |
| 262 | + test_action_report.total_duration, expected_total, 20 |
| 263 | + ) |
| 264 | + self.assert_within_tolerance( |
| 265 | + test_action_report.mean_duration, expected_total / 5 |
| 266 | + ) |
| 267 | + |
| 268 | + # Verify for fast_action |
| 269 | + self.assertEqual(fast_action_report.num_calls, 1) |
| 270 | + self.assert_within_tolerance(fast_action_report.total_duration, 0.05) |
| 271 | + self.assertGreater( |
| 272 | + test_action_report.mean_duration, fast_action_report.mean_duration |
| 273 | + ) |
| 274 | + |
| 275 | + def test_aggregated_timer_multiple_actions(self) -> None: |
| 276 | + timer = AggregatedTimer() |
| 277 | + |
| 278 | + actions = ["dataloader", "forward", "backward", "optimizer"] |
| 279 | + call_counts = [10, 10, 10, 5] |
| 280 | + |
| 281 | + for action, count in zip(actions, call_counts): |
| 282 | + for _ in range(count): |
| 283 | + with timer.time(action): |
| 284 | + time.sleep(0.001) |
| 285 | + |
| 286 | + # Check that all actions are recorded |
| 287 | + self.assertEqual(len(timer._aggregate_stats), 4) |
| 288 | + |
| 289 | + for action, expected_count in zip(actions, call_counts): |
| 290 | + self.assertEqual(timer._aggregate_stats[action].num_calls, expected_count) |
| 291 | + |
| 292 | + # Check report |
| 293 | + report = timer._make_report() |
| 294 | + self.assertEqual(len(report.timed_action_stats), 4) |
| 295 | + self.assertEqual(report.total_calls, sum(call_counts)) |
| 296 | + |
| 297 | + total_percentage = sum( |
| 298 | + stats.percentage_of_total_time for stats in report.timed_action_stats |
| 299 | + ) |
| 300 | + self.assert_within_tolerance(total_percentage, 100.0, 1) |
| 301 | + |
238 | 302 |
|
239 | 303 | class FullSyncPeriodicTimerTest(unittest.TestCase):
|
240 | 304 | @classmethod
|
|
0 commit comments