Commit 23191d4

jaconey authored and facebook-github-bot committed
Add plugin interface and add GPU mem snapshot as an example (#777)
Summary:
Pull Request resolved: #777

WTTS.

* Add basic `Plugin` interface to extend bulk gen functionality:

```
class Plugin(ABC):
    def step(self) -> None: ...
    def shutdown(self) -> None: ...
```

* Add GPU mem snapshot as an example.
* Add `is_started` status to TNT `MemorySnapshotProfiler` and expose `log_memory_snapshot`.

Reviewed By: skcoirz

Differential Revision: D55724465

fbshipit-source-id: 87226f9be2da801119d575a9b43a48109aa86a82
1 parent 4c90a5f commit 23191d4
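
To make the interface concrete, here is a minimal sketch of how a GPU memory snapshot plugin could implement the `Plugin` interface quoted in the summary. The `Plugin` ABC mirrors the snippet above; the `GpuMemSnapshotPlugin` class, its constructor arguments, and its use of `MemorySnapshotProfiler` are illustrative assumptions, not the code added by this commit.

```python
# Sketch only: Plugin is quoted from the commit summary; GpuMemSnapshotPlugin
# and its constructor are hypothetical, assuming MemorySnapshotParams exposes
# start_step/stop_step as in torchtnt.
from abc import ABC, abstractmethod

from torchtnt.utils.memory_snapshot_profiler import (
    MemorySnapshotParams,
    MemorySnapshotProfiler,
)


class Plugin(ABC):
    @abstractmethod
    def step(self) -> None: ...

    @abstractmethod
    def shutdown(self) -> None: ...


class GpuMemSnapshotPlugin(Plugin):
    """Records a CUDA memory snapshot between start_step and stop_step."""

    def __init__(self, output_dir: str, start_step: int, stop_step: int) -> None:
        self._profiler = MemorySnapshotProfiler(
            output_dir=output_dir,
            memory_snapshot_params=MemorySnapshotParams(
                start_step=start_step, stop_step=stop_step
            ),
        )

    def step(self) -> None:
        # Advance the profiler; it starts/stops recording based on the step count.
        self._profiler.step()

    def shutdown(self) -> None:
        # stop() is safe to call unconditionally thanks to the new is_started guard.
        self._profiler.stop()
```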


torchtnt/utils/memory_snapshot_profiler.py (13 additions, 3 deletions)

```diff
@@ -131,6 +131,7 @@ def __init__(
         )

         self.step_num: int = 0
+        self.is_started: bool = False

         if not is_torch_version_geq_2_0():
             raise RuntimeError("CUDA memory snapshot requires torch>=2.0")
@@ -146,20 +147,26 @@ def __init__(
         )

     def start(self) -> None:
+        if self.is_started:
+            return
         if not torch.cuda.is_available():
             logger.warn("CUDA unavailable. Not recording memory history.")
             return

         logger.info("Starting to record memory history.")
         torch.cuda.memory._record_memory_history(max_entries=self.params.max_entries)
+        self.is_started = True

     def stop(self) -> None:
+        if not self.is_started:
+            return
         if not torch.cuda.is_available():
             logger.warn("CUDA unavailable. Not recording memory history.")
             return

         logger.info("Stopping recording memory history.")
         torch.cuda.memory._record_memory_history(enabled=None)
+        self.is_started = False

     def step(self) -> None:
         self.step_num += 1
@@ -169,7 +176,10 @@ def step(self) -> None:
         ):
             self.start()
         if self.params.stop_step is not None and self.step_num == self.params.stop_step:
-            log_memory_snapshot(
-                output_dir=self.output_dir, file_prefix=f"step_{self.step_num}"
-            )
+            self.log_memory_snapshot()
             self.stop()
+
+    def log_memory_snapshot(self) -> None:
+        log_memory_snapshot(
+            output_dir=self.output_dir, file_prefix=f"step_{self.step_num}"
+        )
```
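
For context, a usage sketch of the profiler after this change follows. It assumes the constructor signature `MemorySnapshotProfiler(output_dir=..., memory_snapshot_params=...)` and a `MemorySnapshotParams` with `start_step`/`stop_step` fields, as in torchtnt; the loop, step counts, and paths are illustrative only.

```python
# Illustrative usage sketch (not part of this diff). Assumes
# MemorySnapshotProfiler(output_dir=..., memory_snapshot_params=...) and
# MemorySnapshotParams(start_step=..., stop_step=...) from torchtnt.utils.
from torchtnt.utils.memory_snapshot_profiler import (
    MemorySnapshotParams,
    MemorySnapshotProfiler,
)

profiler = MemorySnapshotProfiler(
    output_dir="/tmp/memory_snapshots",
    memory_snapshot_params=MemorySnapshotParams(start_step=2, stop_step=4),
)

for step in range(10):
    # ... run one training/generation step that allocates CUDA memory ...
    profiler.step()  # starts recording at start_step; dumps a snapshot and stops at stop_step
    if step == 2:
        # log_memory_snapshot() is now a public method and can be called on
        # demand while recording is active.
        profiler.log_memory_snapshot()

# With the new is_started flag, a second stop() (or start()) is a no-op,
# so an unconditional cleanup/shutdown path is safe:
profiler.stop()
```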
