
Commit f16b245

[ckpt] fix: prevent data loss when max_ckpt_to_keep=1 (#4873)
### What does this PR do?

Fixes a data loss bug when `max_ckpt_to_keep=1`: the old checkpoint was deleted **before** the new save completed. If the save fails (disk full, crash, etc.), all checkpoints are lost. The fix ensures the previous checkpoint is preserved until the new one is successfully saved. Also consolidates duplicated cleanup logic from the FSDP/Megatron managers into `BaseCheckpointManager`.

**Trade-off:** With `max_ckpt_to_keep=1`, there is now temporary additional storage overhead during saves: two checkpoints exist briefly until the old one is deleted after the new save completes. This is the expected behavior to guarantee data safety.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: https://github.com/volcengine/verl/pulls?q=is%3Apr+max_ckpt_to_keep
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)

### Test

Added CPU unit tests in `tests/utils/ckpt/test_checkpoint_cleanup_on_cpu.py` covering:

- `max_ckpt_to_keep=1` preserves the checkpoint before save (regression test)
- `max_ckpt_to_keep=1` deletes the old checkpoint after a successful save
- `max_ckpt_to_keep=2` keeps a safety buffer
- `max_ckpt_to_keep=0` (unlimited) keeps all checkpoints
- Full save-cycle simulation

### API and Usage Example

No API changes. The existing `max_ckpt_to_keep` parameter now works correctly.

### Design & Code Changes

**New methods in `BaseCheckpointManager`:**

- `ensure_checkpoint_capacity(max_ckpt_to_keep)`: called before save; keeps `max - 1` checkpoints as a safety buffer (does nothing when `max=1`)
- `register_checkpoint(new_path, max_ckpt_to_keep)`: called after save; registers the path and enforces the retention limit

**Changes to subclasses:**

- `FSDPCheckpointManager`: replaced inline cleanup logic with calls to the base class methods
- `MegatronCheckpointManager`: same refactor for both the sync and async save paths

A minimal usage sketch of this call ordering follows the checklist below.

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). *(N/A - no user-facing changes)*
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code.
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
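Below is a minimal sketch of the call ordering the two new methods are designed for. It is illustrative only: `save_to_disk` is a hypothetical stand-in for the actual FSDP/Megatron save logic, and rank checks and error handling are omitted.

```python
# Sketch of the safe save cycle (illustrative; `save_to_disk` is hypothetical).
def save_checkpoint_safely(manager, local_path: str, max_ckpt_to_keep: int):
    # Pre-save cleanup keeps (max_ckpt_to_keep - 1) checkpoints as a safety
    # buffer; with max_ckpt_to_keep=1 it is a no-op, so the previous
    # checkpoint survives if the save below fails.
    manager.ensure_checkpoint_capacity(max_ckpt_to_keep)

    save_to_disk(local_path)  # the potentially failing save

    # Only after a successful save is the new path registered and the
    # retention limit enforced, deleting the now-redundant old checkpoint.
    manager.register_checkpoint(local_path, max_ckpt_to_keep)
```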
1 parent bc73797 commit f16b245

File tree

4 files changed: +206 -64 lines changed

- tests/utils/ckpt/test_checkpoint_cleanup_on_cpu.py
- verl/utils/checkpoint/checkpoint_manager.py
- verl/utils/checkpoint/fsdp_checkpoint_manager.py
- verl/utils/checkpoint/megatron_checkpoint_manager.py
tests/utils/ckpt/test_checkpoint_cleanup_on_cpu.py

Lines changed: 139 additions & 0 deletions

@@ -0,0 +1,139 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import tempfile
+
+import pytest
+
+
+class TestCheckpointCleanupLogic:
+    """Tests for checkpoint cleanup methods in BaseCheckpointManager."""
+
+    @pytest.fixture(autouse=True)
+    def setup(self):
+        """Set up test fixtures."""
+        self.test_dir = tempfile.mkdtemp()
+        yield
+        shutil.rmtree(self.test_dir, ignore_errors=True)
+
+    @pytest.fixture
+    def manager(self, monkeypatch):
+        """Create a minimal BaseCheckpointManager for testing."""
+        import torch.distributed
+
+        monkeypatch.setattr(torch.distributed, "get_rank", lambda: 0)
+        monkeypatch.setattr(torch.distributed, "get_world_size", lambda: 1)
+
+        from verl.utils.checkpoint.checkpoint_manager import BaseCheckpointManager
+
+        class MockModel:
+            pass
+
+        class MockOptimizer:
+            pass
+
+        return BaseCheckpointManager(
+            model=MockModel(),
+            optimizer=MockOptimizer(),
+            lr_scheduler=None,
+            processing_class=None,
+            checkpoint_config=None,
+        )
+
+    def _create_checkpoint_dir(self, step: int) -> str:
+        """Create a mock checkpoint directory."""
+        path = os.path.join(self.test_dir, f"global_step_{step}")
+        os.makedirs(path, exist_ok=True)
+        with open(os.path.join(path, "checkpoint.txt"), "w") as f:
+            f.write(f"step={step}")
+        return path
+
+    def test_max_ckpt_1_preserves_existing_before_save(self, manager):
+        """
+        Regression test: max_ckpt_to_keep=1 must NOT delete existing checkpoint before save.
+        """
+        ckpt_100 = self._create_checkpoint_dir(100)
+        manager.previous_saved_paths = [ckpt_100]
+
+        manager.ensure_checkpoint_capacity(max_ckpt_to_keep=1)
+
+        assert os.path.exists(ckpt_100), "Bug: checkpoint deleted before save!"
+        assert manager.previous_saved_paths == [ckpt_100]
+
+    def test_max_ckpt_1_deletes_old_after_save(self, manager):
+        """After save succeeds, old checkpoint should be deleted."""
+        ckpt_100 = self._create_checkpoint_dir(100)
+        manager.previous_saved_paths = [ckpt_100]
+
+        ckpt_200 = self._create_checkpoint_dir(200)
+        manager.register_checkpoint(ckpt_200, max_ckpt_to_keep=1)
+
+        assert not os.path.exists(ckpt_100)
+        assert os.path.exists(ckpt_200)
+        assert manager.previous_saved_paths == [ckpt_200]
+
+    def test_max_ckpt_2_keeps_one_before_save(self, manager):
+        """With max_ckpt_to_keep=2, pre-save cleanup keeps 1 checkpoint."""
+        ckpt_100 = self._create_checkpoint_dir(100)
+        ckpt_200 = self._create_checkpoint_dir(200)
+        manager.previous_saved_paths = [ckpt_100, ckpt_200]
+
+        manager.ensure_checkpoint_capacity(max_ckpt_to_keep=2)
+
+        assert not os.path.exists(ckpt_100)
+        assert os.path.exists(ckpt_200)
+        assert len(manager.previous_saved_paths) == 1
+
+    def test_max_ckpt_0_keeps_all(self, manager):
+        """max_ckpt_to_keep=0 means unlimited - no deletions."""
+        ckpt_100 = self._create_checkpoint_dir(100)
+        ckpt_200 = self._create_checkpoint_dir(200)
+        manager.previous_saved_paths = [ckpt_100, ckpt_200]

+        manager.ensure_checkpoint_capacity(max_ckpt_to_keep=0)
+        ckpt_300 = self._create_checkpoint_dir(300)
+        manager.register_checkpoint(ckpt_300, max_ckpt_to_keep=0)
+
+        assert os.path.exists(ckpt_100)
+        assert os.path.exists(ckpt_200)
+        assert os.path.exists(ckpt_300)
+        assert len(manager.previous_saved_paths) == 3
+
+    def test_full_save_cycle_max_ckpt_1(self, manager):
+        """Simulate multiple save cycles with max_ckpt_to_keep=1."""
+        # First save
+        manager.ensure_checkpoint_capacity(1)
+        ckpt_100 = self._create_checkpoint_dir(100)
+        manager.register_checkpoint(ckpt_100, 1)
+        assert manager.previous_saved_paths == [ckpt_100]
+
+        # Second save - existing checkpoint must survive pre-save
+        manager.ensure_checkpoint_capacity(1)
+        assert os.path.exists(ckpt_100), "Bug: checkpoint deleted before save!"
+
+        ckpt_200 = self._create_checkpoint_dir(200)
+        manager.register_checkpoint(ckpt_200, 1)
+        assert not os.path.exists(ckpt_100)
+        assert manager.previous_saved_paths == [ckpt_200]
+
+        # Third save
+        manager.ensure_checkpoint_capacity(1)
+        assert os.path.exists(ckpt_200), "Bug: checkpoint deleted before save!"
+
+        ckpt_300 = self._create_checkpoint_dir(300)
+        manager.register_checkpoint(ckpt_300, 1)
+        assert not os.path.exists(ckpt_200)
+        assert manager.previous_saved_paths == [ckpt_300]

verl/utils/checkpoint/checkpoint_manager.py

Lines changed: 30 additions & 0 deletions
@@ -141,6 +141,36 @@ def remove_previous_save_local_path(self, path):
                 continue
             shutil.rmtree(abs_path, ignore_errors=True)
 
+    def ensure_checkpoint_capacity(self, max_ckpt_to_keep: int):
+        """
+        Remove old checkpoints to make room for a new one, keeping a safety buffer.
+
+        With max_ckpt_to_keep=1, this does nothing - we keep the existing checkpoint
+        until the new save completes successfully (handled by register_checkpoint).
+        For max_ckpt_to_keep >= 2, we keep (max_ckpt_to_keep - 1) checkpoints before save.
+        """
+        if not (max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 1):
+            return
+        if len(self.previous_saved_paths) >= max_ckpt_to_keep:
+            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1
+            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])
+            self.previous_saved_paths = self.previous_saved_paths[keep_start:]
+
+    def register_checkpoint(self, new_path: str, max_ckpt_to_keep: int):
+        """
+        Register a successfully saved checkpoint and enforce retention limit.
+
+        Adds the new checkpoint path to tracking and removes excess old
+        checkpoints beyond max_ckpt_to_keep.
+        """
+        self.previous_saved_paths.append(new_path)
+        if not (max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0):
+            return
+        if len(self.previous_saved_paths) > max_ckpt_to_keep:
+            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep
+            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])
+            self.previous_saved_paths = self.previous_saved_paths[keep_start:]
+
     @staticmethod
     def get_rng_state():
         rng_state = {
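For reference, a small walk-through of the retention arithmetic these two methods use. The `global_step_*` paths are hypothetical and plain lists stand in for `previous_saved_paths`; nothing is deleted here, this only traces the slicing.

```python
# Illustrative arithmetic only (hypothetical paths, no files are touched).
previous_saved_paths = ["global_step_100", "global_step_200", "global_step_300"]
max_ckpt_to_keep = 3

# ensure_checkpoint_capacity: before the save, keep (max_ckpt_to_keep - 1) checkpoints.
keep_start = len(previous_saved_paths) - max_ckpt_to_keep + 1  # = 1
stale = previous_saved_paths[:keep_start]                      # ["global_step_100"] would be removed
previous_saved_paths = previous_saved_paths[keep_start:]       # two remain as the safety buffer

# register_checkpoint: after the save, append the new path and enforce the full limit.
previous_saved_paths.append("global_step_400")                 # three tracked, limit is three
assert len(previous_saved_paths) <= max_ckpt_to_keep           # so nothing further is deleted
```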

verl/utils/checkpoint/fsdp_checkpoint_manager.py

Lines changed: 4 additions & 12 deletions
@@ -201,17 +201,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
         # record the previous global step
         self.previous_global_step = global_step
 
-        # remove previous local_path, only rank 0 should do this
-        if (
-            self.rank == 0
-            and max_ckpt_to_keep
-            and isinstance(max_ckpt_to_keep, int)
-            and max_ckpt_to_keep > 0
-            and len(self.previous_saved_paths) >= max_ckpt_to_keep
-        ):
-            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1
-            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])
-            self.previous_saved_paths = self.previous_saved_paths[keep_start:]
+        if self.rank == 0:
+            self.ensure_checkpoint_capacity(max_ckpt_to_keep)
 
         local_path = local_mkdir_safe(local_path)
         torch.distributed.barrier()
@@ -367,4 +358,5 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
         # wait for rank0 to dump hf_model to local
         torch.distributed.barrier()
 
-        self.previous_saved_paths.append(local_path)
+        if self.rank == 0:
+            self.register_checkpoint(local_path, max_ckpt_to_keep)

verl/utils/checkpoint/megatron_checkpoint_manager.py

Lines changed: 33 additions & 52 deletions
@@ -414,17 +414,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
         # record the previous global step
         self.previous_global_step = global_step
 
-        # remove previous local_path
-        if (
-            not self.checkpoint_config.async_save
-            and max_ckpt_to_keep
-            and isinstance(max_ckpt_to_keep, int)
-            and max_ckpt_to_keep > 0
-            and len(self.previous_saved_paths) >= max_ckpt_to_keep
-        ):
-            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1
-            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])
-            self.previous_saved_paths = self.previous_saved_paths[keep_start:]
+        if not self.checkpoint_config.async_save:
+            self.ensure_checkpoint_capacity(max_ckpt_to_keep)
 
         local_path = local_mkdir_safe(local_path)
         dist_checkpoint_path = get_dist_checkpoint_path(local_path)
@@ -646,46 +637,37 @@ def finalize_save_fn():
                 hdfs_io.copy(src=hf_config_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True)
 
             # update latest_checkpointed_iteration.txt when async_save is True
-            if not self.checkpoint_config.async_save:
-                return
-
-            head_node = None
-            nodes = api.list_nodes()
-            for node in nodes:
-                if node.is_head_node:
-                    head_node = node
-                    break
-
-            current_node_id = ray.get_runtime_context().get_node_id()
-            ray_local_world_size = int(os.getenv("RAY_LOCAL_WORLD_SIZE", -1))
-            if ray_local_world_size == -1:
-                nnodes = int(os.getenv("NNODES", 1))
-                ray_local_world_size = torch.distributed.get_world_size() / nnodes
-
-            if head_node is not None and head_node.node_id == current_node_id and self.rank % ray_local_world_size == 0:
-                log_with_rank(
-                    f"Update latest_checkpointed_iteration.txt to step {global_step}",
-                    rank=self.rank,
-                    logger=logger,
-                )
-                local_latest_checkpointed_iteration = os.path.join(
-                    os.path.dirname(os.path.dirname(local_path)), "latest_checkpointed_iteration.txt"
-                )
-                with open(local_latest_checkpointed_iteration, "w") as f:
-                    f.write(str(global_step))
-
-            # remove previous local_path
-            self.previous_saved_paths.append(local_path)
-
-            if (
-                max_ckpt_to_keep
-                and isinstance(max_ckpt_to_keep, int)
-                and max_ckpt_to_keep > 0
-                and len(self.previous_saved_paths) > max_ckpt_to_keep
-            ):
-                keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep
-                self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])
-                self.previous_saved_paths = self.previous_saved_paths[keep_start:]
+            if self.checkpoint_config.async_save:
+                head_node = None
+                nodes = api.list_nodes()
+                for node in nodes:
+                    if node.is_head_node:
+                        head_node = node
+                        break
+
+                current_node_id = ray.get_runtime_context().get_node_id()
+                ray_local_world_size = int(os.getenv("RAY_LOCAL_WORLD_SIZE", -1))
+                if ray_local_world_size == -1:
+                    nnodes = int(os.getenv("NNODES", 1))
+                    ray_local_world_size = torch.distributed.get_world_size() / nnodes
+
+                if (
+                    head_node is not None
+                    and head_node.node_id == current_node_id
+                    and self.rank % ray_local_world_size == 0
+                ):
+                    log_with_rank(
+                        f"Update latest_checkpointed_iteration.txt to step {global_step}",
+                        rank=self.rank,
+                        logger=logger,
+                    )
+                    local_latest_checkpointed_iteration = os.path.join(
+                        os.path.dirname(os.path.dirname(local_path)), "latest_checkpointed_iteration.txt"
+                    )
+                    with open(local_latest_checkpointed_iteration, "w") as f:
+                        f.write(str(global_step))
+
+            self.register_checkpoint(local_path, max_ckpt_to_keep)
 
         if self.checkpoint_config.async_save:
             assert async_save_request is not None, "Async save request should not be None when using async save."
@@ -695,4 +677,3 @@ def finalize_save_fn():
             async_calls.schedule_async_request(async_save_request)
         else:
             finalize_save_fn()
-            self.previous_saved_paths.append(local_path)
