Skip to content

Commit 614526a

Browse files
diego-urgell authored and facebook-github-bot committed
Separate memory_snapshot_profiler GPU tests (#765)
Summary: Pull Request resolved: #765
Reviewed By: galrotem
Differential Revision: D55439083
fbshipit-source-id: 862eded7e67996574783aa81d6581582a45be02a
1 parent d956000 commit 614526a

File tree

2 files changed

+65
-43
lines changed

2 files changed

+65
-43
lines changed

tests/utils/test_memory_snapshot_profiler.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,63 +7,20 @@
77

88
# pyre-strict
99

10-
import os
1110
import tempfile
1211
import unittest
1312

14-
import torch
15-
from torchtnt.utils.device import get_device_from_env
1613
from torchtnt.utils.memory_snapshot_profiler import (
1714
MemorySnapshotParams,
1815
MemorySnapshotProfiler,
1916
)
20-
from torchtnt.utils.test_utils import skip_if_not_gpu
2117
from torchtnt.utils.version import is_torch_version_geq_2_0
2218

2319

2420
class MemorySnapshotProfilerTest(unittest.TestCase):
2521

2622
torch_version_geq_2_0: bool = is_torch_version_geq_2_0()
2723

28-
@skip_if_not_gpu
29-
@unittest.skipUnless(
30-
condition=torch_version_geq_2_0,
31-
reason="This test needs changes from PyTorch 2.0 to run.",
32-
)
33-
def test_stop_step(self) -> None:
34-
"""Test that a memory snapshot is saved when stop_step is reached."""
35-
with tempfile.TemporaryDirectory() as temp_dir:
36-
memory_snapshot_profiler = MemorySnapshotProfiler(
37-
output_dir=temp_dir,
38-
memory_snapshot_params=MemorySnapshotParams(start_step=0, stop_step=2),
39-
)
40-
41-
# initialize device & allocate memory for tensors
42-
device = get_device_from_env()
43-
a = torch.rand((1024, 1024), device=device)
44-
b = torch.rand((1024, 1024), device=device)
45-
_ = (a + b) * (a - b)
46-
47-
memory_snapshot_profiler.step()
48-
49-
# Check if the corresponding files exist
50-
save_dir = os.path.join(temp_dir, "step_2_rank0")
51-
52-
pickle_dump_path = os.path.join(save_dir, "snapshot.pickle")
53-
trace_path = os.path.join(save_dir, "trace_plot.html")
54-
segment_plot_path = os.path.join(save_dir, "segment_plot.html")
55-
56-
# after first step files do not exist
57-
self.assertFalse(os.path.exists(pickle_dump_path))
58-
self.assertFalse(os.path.exists(trace_path))
59-
self.assertFalse(os.path.exists(segment_plot_path))
60-
61-
# after second step stop_step is reached and files should exist
62-
memory_snapshot_profiler.step()
63-
self.assertTrue(os.path.exists(pickle_dump_path))
64-
self.assertTrue(os.path.exists(trace_path))
65-
self.assertTrue(os.path.exists(segment_plot_path))
66-
6724
@unittest.skipUnless(
6825
condition=torch_version_geq_2_0,
6926
reason="This test needs changes from PyTorch 2.0 to run.",
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import os
11+
import tempfile
12+
import unittest
13+
14+
import torch
15+
from torchtnt.utils.device import get_device_from_env
16+
from torchtnt.utils.memory_snapshot_profiler import (
17+
MemorySnapshotParams,
18+
MemorySnapshotProfiler,
19+
)
20+
from torchtnt.utils.test_utils import skip_if_not_gpu
21+
from torchtnt.utils.version import is_torch_version_geq_2_0
22+
23+
24+
class MemorySnapshotProfilerGPUTest(unittest.TestCase):
    """Tests for ``MemorySnapshotProfiler`` that are skipped unless a GPU is available."""

    # Computed once at class-definition time so the skip decorator below can use it.
    torch_version_geq_2_0: bool = is_torch_version_geq_2_0()

    @skip_if_not_gpu
    @unittest.skipUnless(
        condition=torch_version_geq_2_0,
        reason="This test needs changes from PyTorch 2.0 to run.",
    )
    def test_stop_step(self) -> None:
        """Check that snapshot artifacts only show up once ``stop_step`` is hit."""
        artifact_names = ("snapshot.pickle", "trace_plot.html", "segment_plot.html")

        with tempfile.TemporaryDirectory() as temp_dir:
            snapshot_params = MemorySnapshotParams(start_step=0, stop_step=2)
            profiler = MemorySnapshotProfiler(
                output_dir=temp_dir,
                memory_snapshot_params=snapshot_params,
            )

            # Pick the device and run a little tensor arithmetic on it so that
            # some memory is allocated while the profiler exists.
            device = get_device_from_env()
            lhs = torch.rand((1024, 1024), device=device)
            rhs = torch.rand((1024, 1024), device=device)
            _ = (lhs + rhs) * (lhs - rhs)

            # The test expects the artifacts in a directory named after the
            # stop step and rank 0.
            save_dir = os.path.join(temp_dir, "step_2_rank0")
            expected_paths = [os.path.join(save_dir, name) for name in artifact_names]

            # First step: stop_step has not been reached, nothing is on disk yet.
            profiler.step()
            for artifact_path in expected_paths:
                self.assertFalse(os.path.exists(artifact_path))

            # Second step: stop_step is reached, every artifact must exist.
            profiler.step()
            for artifact_path in expected_paths:
                self.assertTrue(os.path.exists(artifact_path))

0 commit comments

Comments
 (0)