Skip to content

Commit ce5a73e

Browse files
committed
Fix some zero-overhead checkpointing bugs
Summary: 1. The original code does not utitlize `share_memory=True`, this may cause incorrectness or slowdown. 2. The original code does not pass the correct cpu-offloaded state_dict, which can cause another slowdown or incorrect saving. ghstack-source-id: c04c634 Pull Request resolved: #602
1 parent 9ccc161 commit ce5a73e

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

torchtitan/checkpoint.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,6 @@ def __init__(
271271
self.mp.start()
272272
self.cpu_offload_state_dict = None
273273
self.staging = False
274-
self.staging_state_dict = None
275274
self.staging_id = None
276275
self.staging_stream = torch.cuda.Stream()
277276
else:
@@ -384,7 +383,7 @@ def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
384383
if self.cpu_offload_state_dict is None:
385384
logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
386385
self.cpu_offload_state_dict = _create_cpu_state_dict(
387-
state_dict, pin_memory=True
386+
state_dict, pin_memory=True, share_memory=True
388387
)
389388

390389
logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
@@ -395,7 +394,6 @@ def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
395394
non_blocking=True,
396395
)
397396
self.staging = True
398-
self.staging_state_dict = state_dict
399397
self.staging_id = checkpoint_id
400398

401399
def save(self, curr_step: int, force: bool = False) -> None:
@@ -435,12 +433,19 @@ def maybe_wait_for_staging(self) -> None:
435433
and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
436434
and self.staging
437435
):
438-
logger.debug(f"Waiting for staging, {time.monotonic()=:.2f}.")
439-
self.staging_stream.synchronize()
440-
logger.debug(
441-
f"Sending the state dict to the background process, {time.monotonic()=:.2f}."
442-
)
443-
self.mp_queue_send.put((self.staging_state_dict, self.staging_id))
436+
if not self.staging_stream.query():
437+
self.staging_stream.synchronize()
438+
439+
def sync_func():
440+
self.mp_queue_send.put_nowait(
441+
(self.cpu_offload_state_dict, self.staging_id)
442+
)
443+
444+
# This may be a faster way to do zero-overhead checkpointing staging
445+
# checkpointing but we need more thorough investigation before
446+
# swithing to this method.
447+
# self.my_thread = threading.Thread(target=func).start()
448+
sync_func()
444449
self.staging = False
445450

446451
def load(self, step: int = -1) -> bool:

0 commit comments

Comments
 (0)