Commit 94344e2

diego-urgell authored and facebook-github-bot committed
Move test_timer GPU test to a separate file (#771)
Summary: Pull Request resolved: #771

Reviewed By: JKSenthil

Differential Revision: D55490411

fbshipit-source-id: e836033997ed30da4461eb530ac21993bb2d3022
1 parent 51e1485 commit 94344e2
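
The relocated test stays guarded by skip_if_not_gpu from torchtnt.utils.test_utils, so the new file is skipped cleanly on CPU-only hosts. For readers unfamiliar with the guard, a decorator of this kind can be built on unittest.skipUnless — a minimal sketch under that assumption, not torchtnt's actual implementation:

    import unittest

    import torch

    def skip_if_not_gpu(test_fn):
        """Skip the wrapped test when no CUDA device is available."""
        # Sketch only: delegate to unittest's built-in skip machinery.
        return unittest.skipUnless(
            torch.cuda.is_available(), "Skipping test since no GPU is available"
        )(test_fn)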

2 files changed: +38 -23 lines changed


tests/utils/test_timer.py

Lines changed: 2 additions & 23 deletions
@@ -13,13 +13,11 @@
 from datetime import timedelta
 from random import random
 from unittest import mock
-from unittest.mock import Mock, patch

-import torch
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torchtnt.utils.distributed import spawn_multi_process
-from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
+from torchtnt.utils.test_utils import skip_if_not_distributed
 from torchtnt.utils.timer import (
     BoundedTimer,
     FullSyncPeriodicTimer,
@@ -68,8 +66,7 @@ def test_timer_context_manager_size_bound(self) -> None:
             UPPER_BOUND,
         )

-    @patch("torch.cuda.synchronize")
-    def test_timer_context_manager(self, _) -> None:
+    def test_timer_context_manager(self) -> None:
         """Test the context manager in the timer class"""

         # Generate 3 intervals between 0.5 and 2 seconds
@@ -103,24 +100,6 @@ def test_timer_context_manager(self, _) -> None:
             timer.recorded_durations["action_4"][0], intervals[2]
         )

-    @skip_if_not_gpu
-    @patch("torch.cuda.synchronize")
-    def test_timer_synchronize(self, mock_synchornize: Mock) -> None:
-        """Make sure that torch.cuda.synchronize() is called when GPU is present."""
-
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
-        timer = Timer()
-
-        # Do not explicitly call synchronize, timer must call it for test to pass.
-
-        with timer.time("action_1"):
-            start_event.record()
-            time.sleep(0.5)
-            end_event.record()
-
-        self.assertEqual(mock_synchornize.call_count, 2)
-
     def test_get_timer_summary(self) -> None:
         """Test the get_timer_summary function"""

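With the GPU assertion moved out, test_timer_context_manager no longer needs to patch torch.cuda.synchronize, so the mock decorator, the Mock/patch import, and the now-unused torch import are all dropped. For context, the test exercises the timer's context-manager API, which records one duration per invocation under the given action name. A minimal usage sketch based only on calls visible in this diff (do_work is a hypothetical placeholder):

    import time

    from torchtnt.utils.timer import Timer

    def do_work() -> None:
        """Hypothetical placeholder workload."""
        time.sleep(0.5)

    timer = Timer()
    with timer.time("action_1"):
        do_work()

    # recorded_durations maps each action name to a list of measured durations.
    print(timer.recorded_durations["action_1"])
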
tests/utils/test_timer_gpu.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import time
+import unittest
+from unittest.mock import MagicMock, patch
+
+import torch
+from torchtnt.utils.test_utils import skip_if_not_gpu
+from torchtnt.utils.timer import Timer
+
+
+class TimerGPUTest(unittest.TestCase):
+    @skip_if_not_gpu
+    @patch("torch.cuda.synchronize")
+    def test_timer_synchronize(self, mock_synchronize: MagicMock) -> None:
+        """Make sure that torch.cuda.synchronize() is called when GPU is present."""
+
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        timer = Timer()
+
+        # Do not explicitly call synchronize, timer must call it for test to pass.
+
+        with timer.time("action_1"):
+            start_event.record()
+            time.sleep(0.5)
+            end_event.record()
+
+        self.assertEqual(mock_synchronize.call_count, 2)
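
The call_count assertion encodes the timer's contract: when CUDA is available, it synchronizes once on entry and once on exit of the timed block, so queued GPU kernels are flushed before each clock read and the recorded interval covers the device work. A simplified sketch of that pattern — an illustration of the idea, not torchtnt's actual Timer internals:

    import time
    from contextlib import contextmanager
    from typing import Dict, Generator, List

    import torch

    @contextmanager
    def timed(durations: Dict[str, List[float]], name: str) -> Generator[None, None, None]:
        # Flush pending GPU work so it doesn't leak into the interval.
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # first call seen by the test's mock
        start = time.perf_counter()
        try:
            yield
        finally:
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # second call seen by the test's mock
            durations.setdefault(name, []).append(time.perf_counter() - start)

Because the test patches torch.cuda.synchronize with a MagicMock, the assertion observes the two calls without actually blocking on the device.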
