1 file changed: +13 −3 lines changed
lines changed Original file line number Diff line number Diff line change @@ -131,6 +131,7 @@ def __init__(
131
131
)
132
132
133
133
self .step_num : int = 0
134
+ self .is_started : bool = False
134
135
135
136
if not is_torch_version_geq_2_0 ():
136
137
raise RuntimeError ("CUDA memory snapshot requires torch>=2.0" )
@@ -146,20 +147,26 @@ def __init__(
146
147
)
147
148
148
149
def start(self) -> None:
    """Begin recording CUDA memory allocation history.

    Idempotent: returns immediately if recording has already been started.
    Logs a warning and records nothing when CUDA is unavailable.
    """
    if self.is_started:
        return
    if not torch.cuda.is_available():
        # logger.warn is a deprecated alias (since Python 3.3); use warning.
        logger.warning("CUDA unavailable. Not recording memory history.")
        return

    logger.info("Starting to record memory history.")
    torch.cuda.memory._record_memory_history(max_entries=self.params.max_entries)
    self.is_started = True
def stop(self) -> None:
    """Stop recording CUDA memory allocation history.

    Idempotent: returns immediately if recording was never started.
    Logs a warning and does nothing when CUDA is unavailable.
    """
    if not self.is_started:
        return
    if not torch.cuda.is_available():
        # logger.warn is a deprecated alias (since Python 3.3); use warning.
        logger.warning("CUDA unavailable. Not recording memory history.")
        return

    logger.info("Stopping recording memory history.")
    # enabled=None disables further history recording in torch.
    torch.cuda.memory._record_memory_history(enabled=None)
    self.is_started = False
def step (self ) -> None :
165
172
self .step_num += 1
@@ -169,7 +176,10 @@ def step(self) -> None:
169
176
):
170
177
self .start ()
171
178
if self .params .stop_step is not None and self .step_num == self .params .stop_step :
172
- log_memory_snapshot (
173
- output_dir = self .output_dir , file_prefix = f"step_{ self .step_num } "
174
- )
179
+ self .log_memory_snapshot ()
175
180
self .stop ()
181
+
182
def log_memory_snapshot(self) -> None:
    """Dump a memory snapshot to ``self.output_dir``, tagged with the current step."""
    file_prefix = f"step_{self.step_num}"
    log_memory_snapshot(output_dir=self.output_dir, file_prefix=file_prefix)
You can’t perform that action at this time.
0 commit comments