|
7 | 7 |
|
8 | 8 | # pyre-strict
|
9 | 9 |
|
10 |
| -import os |
11 |
| -import tempfile |
12 | 10 | import unittest
|
13 | 11 |
|
14 |
| -import torch |
15 |
| -from torchtnt.utils.device import get_device_from_env |
16 | 12 | from torchtnt.utils.oom import (
|
17 | 13 | _bytes_to_mb_gb,
|
18 | 14 | is_out_of_cpu_memory,
|
19 | 15 | is_out_of_cuda_memory,
|
20 | 16 | is_out_of_memory_error,
|
21 |
| - log_memory_snapshot, |
22 | 17 | )
|
23 |
| -from torchtnt.utils.test_utils import skip_if_not_gpu |
24 |
| -from torchtnt.utils.version import is_torch_version_geq_2_0 |
25 | 18 |
|
26 | 19 |
|
27 | 20 | class OomTest(unittest.TestCase):
|
@@ -56,37 +49,6 @@ def test_is_out_of_memory_error(self) -> None:
|
56 | 49 | not_oom_error = RuntimeError("RuntimeError: blah")
|
57 | 50 | self.assertFalse(is_out_of_memory_error(not_oom_error))
|
58 | 51 |
|
59 |
| - @skip_if_not_gpu |
60 |
| - @unittest.skipUnless( |
61 |
| - condition=bool(is_torch_version_geq_2_0()), |
62 |
| - reason="This test needs changes from PyTorch 2.0 to run.", |
63 |
| - ) |
64 |
| - def test_log_memory_snapshot(self) -> None: |
65 |
| - with tempfile.TemporaryDirectory() as temp_dir: |
66 |
| - # Record history |
67 |
| - torch.cuda.memory._record_memory_history(enabled="all", max_entries=10000) |
68 |
| - |
69 |
| - # initialize device & allocate memory for tensors |
70 |
| - device = get_device_from_env() |
71 |
| - a = torch.rand((1024, 1024), device=device) |
72 |
| - b = torch.rand((1024, 1024), device=device) |
73 |
| - _ = (a + b) * (a - b) |
74 |
| - |
75 |
| - # save a snapshot |
76 |
| - log_memory_snapshot(temp_dir, "foo") |
77 |
| - |
78 |
| - # Check if the corresponding files exist |
79 |
| - save_dir = os.path.join(temp_dir, "foo_rank0") |
80 |
| - |
81 |
| - pickle_dump_path = os.path.join(save_dir, "snapshot.pickle") |
82 |
| - self.assertTrue(os.path.exists(pickle_dump_path)) |
83 |
| - |
84 |
| - trace_path = os.path.join(save_dir, "trace_plot.html") |
85 |
| - self.assertTrue(os.path.exists(trace_path)) |
86 |
| - |
87 |
| - segment_plot_path = os.path.join(save_dir, "segment_plot.html") |
88 |
| - self.assertTrue(os.path.exists(segment_plot_path)) |
89 |
| - |
90 | 52 | def test_bytes_to_mb_gb(self) -> None:
|
91 | 53 | bytes_to_mb_test_cases = [
|
92 | 54 | (0, "0.0 MB"),
|
|
0 commit comments