Skip to content

Commit d956000

Browse files
diego-urgell and facebook-github-bot
authored and committed
Separate OOM GPU test into a dedicated file (#766)
Summary: Pull Request resolved: #766
Reviewed By: galrotem
Differential Revision: D55440455
fbshipit-source-id: 5d65824716b80c5c941c13efd4376bdca50e496d
1 parent 8e93a51 commit d956000

File tree

2 files changed

+52
-38
lines changed

2 files changed

+52
-38
lines changed

tests/utils/test_oom.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,14 @@
77

88
# pyre-strict
99

10-
import os
11-
import tempfile
1210
import unittest
1311

14-
import torch
15-
from torchtnt.utils.device import get_device_from_env
1612
from torchtnt.utils.oom import (
1713
_bytes_to_mb_gb,
1814
is_out_of_cpu_memory,
1915
is_out_of_cuda_memory,
2016
is_out_of_memory_error,
21-
log_memory_snapshot,
2217
)
23-
from torchtnt.utils.test_utils import skip_if_not_gpu
24-
from torchtnt.utils.version import is_torch_version_geq_2_0
2518

2619

2720
class OomTest(unittest.TestCase):
@@ -56,37 +49,6 @@ def test_is_out_of_memory_error(self) -> None:
5649
not_oom_error = RuntimeError("RuntimeError: blah")
5750
self.assertFalse(is_out_of_memory_error(not_oom_error))
5851

59-
@skip_if_not_gpu
60-
@unittest.skipUnless(
61-
condition=bool(is_torch_version_geq_2_0()),
62-
reason="This test needs changes from PyTorch 2.0 to run.",
63-
)
64-
def test_log_memory_snapshot(self) -> None:
65-
with tempfile.TemporaryDirectory() as temp_dir:
66-
# Record history
67-
torch.cuda.memory._record_memory_history(enabled="all", max_entries=10000)
68-
69-
# initialize device & allocate memory for tensors
70-
device = get_device_from_env()
71-
a = torch.rand((1024, 1024), device=device)
72-
b = torch.rand((1024, 1024), device=device)
73-
_ = (a + b) * (a - b)
74-
75-
# save a snapshot
76-
log_memory_snapshot(temp_dir, "foo")
77-
78-
# Check if the corresponding files exist
79-
save_dir = os.path.join(temp_dir, "foo_rank0")
80-
81-
pickle_dump_path = os.path.join(save_dir, "snapshot.pickle")
82-
self.assertTrue(os.path.exists(pickle_dump_path))
83-
84-
trace_path = os.path.join(save_dir, "trace_plot.html")
85-
self.assertTrue(os.path.exists(trace_path))
86-
87-
segment_plot_path = os.path.join(save_dir, "segment_plot.html")
88-
self.assertTrue(os.path.exists(segment_plot_path))
89-
9052
def test_bytes_to_mb_gb(self) -> None:
9153
bytes_to_mb_test_cases = [
9254
(0, "0.0 MB"),

tests/utils/test_oom_gpu.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import os
11+
import tempfile
12+
import unittest
13+
14+
import torch
15+
from torchtnt.utils.device import get_device_from_env
16+
from torchtnt.utils.oom import log_memory_snapshot
17+
18+
from torchtnt.utils.test_utils import skip_if_not_gpu
19+
from torchtnt.utils.version import is_torch_version_geq_2_0
20+
21+
22+
class OomGPUTest(unittest.TestCase):
    """Tests for ``torchtnt.utils.oom.log_memory_snapshot`` that allocate on a device.

    Kept separate from the other OOM tests; every test here is decorated with
    ``skip_if_not_gpu``.
    """

    @skip_if_not_gpu
    @unittest.skipUnless(
        condition=bool(is_torch_version_geq_2_0()),
        reason="This test needs changes from PyTorch 2.0 to run.",
    )
    def test_log_memory_snapshot(self) -> None:
        """Check that ``log_memory_snapshot`` produces its expected output files.

        The test enables CUDA memory-history recording, performs a few tensor
        allocations on the device returned by ``get_device_from_env``, calls
        ``log_memory_snapshot(temp_dir, "foo")`` and then asserts that
        ``snapshot.pickle``, ``trace_plot.html`` and ``segment_plot.html`` exist
        under ``<temp_dir>/foo_rank0``.
        """
        expected_files = ("snapshot.pickle", "trace_plot.html", "segment_plot.html")

        with tempfile.TemporaryDirectory() as temp_dir:
            # Record history
            torch.cuda.memory._record_memory_history(enabled="all", max_entries=10000)
            try:
                # initialize device & allocate memory for tensors
                device = get_device_from_env()
                a = torch.rand((1024, 1024), device=device)
                b = torch.rand((1024, 1024), device=device)
                _ = (a + b) * (a - b)

                # save a snapshot
                log_memory_snapshot(temp_dir, "foo")
            finally:
                # Recording is a process-wide allocator setting. Switch it off
                # again so it does not stay enabled for tests that run after
                # this one, even when the steps above raise.
                torch.cuda.memory._record_memory_history(enabled=None)

            # Check if the corresponding files exist
            save_dir = os.path.join(temp_dir, "foo_rank0")
            for file_name in expected_files:
                file_path = os.path.join(save_dir, file_name)
                self.assertTrue(
                    os.path.exists(file_path),
                    f"expected snapshot artifact is missing: {file_path}",
                )

0 commit comments

Comments
 (0)