-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathfuse_multiple_checkpoints.py
More file actions
110 lines (88 loc) · 3.63 KB
/
fuse_multiple_checkpoints.py
File metadata and controls
110 lines (88 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import glob
import os
from typing import List, Optional, Tuple

import click
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
from tqdm import tqdm
def find_deepspeed_checkpoints(root_dir: str, pattern: str = "quarters*.ckpt") -> List[str]:
    """Find DeepSpeed checkpoint directories under ``root_dir`` that are not fused yet.

    ``pattern`` is matched against entry names at any depth below ``root_dir``
    (including directly inside it). Only directories are kept. A match is
    skipped when its fused counterpart already exists; that path is built by
    dropping the last ``.ckpt`` occurrence and appending ``_fused.pt``.

    Args:
        root_dir: Directory to search. It is taken literally: glob
            metacharacters such as ``[``, ``*`` or ``?`` in it are escaped.
        pattern: Glob pattern for checkpoint directory names.

    Returns:
        Sorted list of checkpoint directory paths that have no fused file yet.
    """
    # Escape the root so a directory like "runs/exp[1]" is not interpreted as a
    # character class (which would silently match nothing).
    search_pattern = os.path.join(glob.escape(root_dir), "**", pattern)
    candidates = glob.glob(search_pattern, recursive=True)

    # Sort for a deterministic processing/display order; glob returns entries
    # in filesystem order.
    return sorted(
        path
        for path in candidates
        if os.path.isdir(path)
        and not os.path.exists(path.rsplit(".ckpt", 1)[0] + "_fused.pt")
    )
def format_path(path: str, max_length: int = 100) -> str:
    """Shorten ``path`` for display so it is at most ``max_length`` characters.

    Paths that already fit are returned unchanged. Otherwise the last path
    component is kept whole and the middle of the parent directory is replaced
    by ``...``. If that leaves too little room, the path is truncated from the
    start and prefixed with ``...`` instead.

    Args:
        path: Path to display.
        max_length: Maximum length of the returned string.

    Returns:
        The possibly shortened path; never longer than ``max(max_length, 0)``.
    """
    if len(path) <= max_length:
        return path
    if max_length <= 3:
        # No room for an ellipsis: keep only the last max_length characters.
        return path[len(path) - max(max_length, 0):]

    head, tail = os.path.split(path)
    # Characters available for the parent directory once the "..." marker and
    # the "/" before the tail (4 characters total) are accounted for.
    budget = max_length - len(tail) - 4
    # Fall back to start-truncation when the middle would be too small, or when
    # the head is so short that its two slices would overlap and repeat text.
    if budget < 10 or len(head) <= budget:
        return "..." + path[-(max_length - 3):]

    right = budget // 2  # >= 5 here, so head[-right:] never degenerates to head[0:]
    left = budget - right
    return f"{head[:left]}...{head[-right:]}/{tail}"
def fuse_checkpoint(checkpoint_dir: str) -> Tuple[bool, Optional[str]]:
    """Fuse a single DeepSpeed checkpoint directory into one ``*_fused.pt`` file.

    The output path is ``checkpoint_dir`` with its last ``.ckpt`` occurrence
    dropped and ``_fused.pt`` appended. This is the same rule
    ``find_deepspeed_checkpoints`` uses to detect already-fused checkpoints.

    Args:
        checkpoint_dir: Path to the DeepSpeed checkpoint directory.

    Returns:
        ``(True, output_path)`` on success, ``(False, None)`` on failure.
        Failures (any ``Exception``) are reported on stdout rather than raised,
        so one bad checkpoint does not abort a batch run.
    """
    output_path = checkpoint_dir.rsplit(".ckpt", 1)[0] + "_fused.pt"
    existed_before = os.path.exists(output_path)
    try:
        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_path)
    except Exception as e:  # best-effort batch tool: report and continue
        # tqdm.write keeps the message from being mangled by the active progress bar.
        tqdm.write(f"Error fusing {format_path(checkpoint_dir)}: {e}")
        # A partially written output would make this checkpoint look "already
        # fused" on the next run, so remove anything this call created.
        if not existed_before and os.path.exists(output_path):
            try:
                os.remove(output_path)
            except OSError as cleanup_error:
                tqdm.write(
                    f"Could not remove partial output {format_path(output_path)}: {cleanup_error}"
                )
        return False, None
    return True, output_path
@click.command()
@click.argument("root_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True))
@click.option(
    "--pattern",
    "-p",
    default="quarters_epoch=*.ckpt",
    help="Pattern to match checkpoint directories (default: quarters_epoch=*.ckpt)",
)
def main(root_dir: str, pattern: str) -> None:
    """
    Fuse DeepSpeed checkpoints starting from ROOT_DIR.
    Searches recursively for checkpoint directories matching the pattern
    and converts them to consolidated PyTorch state dictionaries.
    """
    # The docstring above doubles as click's --help text, so it is kept verbatim.
    # Flow: locate unfused checkpoints, fuse each one behind a progress bar,
    # then print a summary of what succeeded and what failed.
    search_root = os.path.abspath(root_dir)
    print(f"Searching for checkpoints in: {search_root}")
    print(f"Using pattern: {pattern}")

    pending = find_deepspeed_checkpoints(search_root, pattern)
    if not pending:
        print("\nNo non-fused checkpoints found.")
        return

    print(f"\n{len(pending)} non-fused checkpoint(s) found:")
    print("\n".join(f"- {format_path(entry)}" for entry in pending))

    print("\nStarting fusion process...")
    fused_pairs = []  # (source checkpoint dir, fused output path) per success
    for source in tqdm(pending, desc="Fusing checkpoints", unit="ckpt"):
        ok, destination = fuse_checkpoint(source)
        if ok:
            fused_pairs.append((source, destination))

    # fuse_checkpoint signals failure through its return value, so every
    # checkpoint that did not land in fused_pairs counts as a failure.
    n_ok = len(fused_pairs)
    n_failed = len(pending) - n_ok

    print("\nFusion Complete!")
    print(f"Successfully fused: {n_ok}")
    print(f"Failed to fuse: {n_failed}")
    print(f"Total processed: {n_ok + n_failed}")

    if not fused_pairs:
        return
    print("\nSuccessfully processed checkpoints:")
    for source, destination in fused_pairs:
        print(f"\nFrom: {format_path(source)}")
        print(f"To: {format_path(destination)}")
if __name__ == "__main__":
    # Run as a script: click parses the command line and invokes the command.
    main()