|
13 | 13 | from datetime import timedelta
|
14 | 14 | from random import random
|
15 | 15 | from unittest import mock
|
16 |
| -from unittest.mock import Mock, patch |
17 | 16 |
|
18 |
| -import torch |
19 | 17 | import torch.distributed as dist
|
20 | 18 | from pyre_extensions import none_throws
|
21 | 19 | from torchtnt.utils.distributed import spawn_multi_process
|
22 |
| -from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu |
| 20 | +from torchtnt.utils.test_utils import skip_if_not_distributed |
23 | 21 | from torchtnt.utils.timer import (
|
24 | 22 | BoundedTimer,
|
25 | 23 | FullSyncPeriodicTimer,
|
@@ -68,8 +66,7 @@ def test_timer_context_manager_size_bound(self) -> None:
|
68 | 66 | UPPER_BOUND,
|
69 | 67 | )
|
70 | 68 |
|
71 |
| - @patch("torch.cuda.synchronize") |
72 |
| - def test_timer_context_manager(self, _) -> None: |
| 69 | + def test_timer_context_manager(self) -> None: |
73 | 70 | """Test the context manager in the timer class"""
|
74 | 71 |
|
75 | 72 | # Generate 3 intervals between 0.5 and 2 seconds
|
@@ -103,24 +100,6 @@ def test_timer_context_manager(self, _) -> None:
|
103 | 100 | timer.recorded_durations["action_4"][0], intervals[2]
|
104 | 101 | )
|
105 | 102 |
|
106 |
| - @skip_if_not_gpu |
107 |
| - @patch("torch.cuda.synchronize") |
108 |
| - def test_timer_synchronize(self, mock_synchornize: Mock) -> None: |
109 |
| - """Make sure that torch.cuda.synchronize() is called when GPU is present.""" |
110 |
| - |
111 |
| - start_event = torch.cuda.Event(enable_timing=True) |
112 |
| - end_event = torch.cuda.Event(enable_timing=True) |
113 |
| - timer = Timer() |
114 |
| - |
115 |
| - # Do not explicitly call synchronize, timer must call it for test to pass. |
116 |
| - |
117 |
| - with timer.time("action_1"): |
118 |
| - start_event.record() |
119 |
| - time.sleep(0.5) |
120 |
| - end_event.record() |
121 |
| - |
122 |
| - self.assertEqual(mock_synchornize.call_count, 2) |
123 |
| - |
124 | 103 | def test_get_timer_summary(self) -> None:
|
125 | 104 | """Test the get_timer_summary function"""
|
126 | 105 |
|
|
0 commit comments