Skip to content

Commit a026d82

Browse files
committed
chore: Pipeline.checkpoint_dir -> Pipeline.run_dir
1 parent 7ab8571 commit a026d82

File tree

9 files changed

+112
-120
lines changed

9 files changed

+112
-120
lines changed

.github/actions/setup_environment/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name: Setup CI Environment
22
inputs:
33
python-version:
4-
default: "3.10"
4+
default: "3.11"
55
type: string
66

77
runs:

dmlcloud/core/callbacks.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..git import git_diff
1616
from ..util.logging import DevNullIO, experiment_header, general_diagnostics, IORedirector
1717
from ..util.wandb import wandb_is_initialized, wandb_set_startup_timeout
18-
from . import logging as dml_logging
18+
from . import checkpoint as dml_checkpoint, logging as dml_logging
1919
from .distributed import all_gather_object, is_root
2020

2121

@@ -302,27 +302,27 @@ class CheckpointCallback(Callback):
302302
Creates the checkpoint directory and optionally setups io redirection.
303303
"""
304304

305-
def __init__(self, root_path: Union[str, Path], redirect_io: bool = True):
305+
def __init__(self, run_dir: Union[str, Path], redirect_io: bool = True):
306306
"""
307-
Initialize the callback with the given root path.
307+
Initialize the callback with the given path.
308308
309309
Args:
310-
root_path (Union[str, Path]): The root path where the checkpoint directory will be created.
311-
redirect_io (bool, optional): Whether to redirect the IO to a file. Defaults to True.
310+
run_dir: The path to the checkpoint directory.
311+
redirect_io: Whether to redirect the IO to a file. Defaults to True.
312312
"""
313-
self.root_path = Path(root_path)
313+
self.run_dir = Path(run_dir)
314314
self.redirect_io = redirect_io
315315
self.io_redirector = None
316316

317317
def pre_run(self, pipe: 'Pipeline'):
318-
if not pipe.checkpoint_dir.is_valid:
319-
pipe.checkpoint_dir.create()
320-
pipe.checkpoint_dir.save_config(pipe.config)
318+
if not dml_checkpoint.is_valid_checkpoint_dir(self.run_dir):
319+
dml_checkpoint.create_checkpoint_dir(self.run_dir)
320+
dml_checkpoint.save_config(pipe.config, self.run_dir)
321321

322-
self.io_redirector = IORedirector(pipe.checkpoint_dir.log_file)
322+
self.io_redirector = IORedirector(pipe.run_dir / 'log.txt')
323323
self.io_redirector.install()
324324

325-
with open(pipe.checkpoint_dir.path / "environment.txt", 'w') as f:
325+
with open(pipe.run_dir / "environment.txt", 'w') as f:
326326
for k, v in os.environ.items():
327327
f.write(f"{k}={v}\n")
328328

@@ -479,7 +479,7 @@ class DiagnosticsCallback(Callback):
479479
"""
480480

481481
def pre_run(self, pipe):
482-
header = '\n' + experiment_header(pipe.name, pipe.checkpoint_dir, pipe.start_time)
482+
header = '\n' + experiment_header(pipe.name, pipe.run_dir, pipe.start_time)
483483
dml_logging.info(header)
484484

485485
diagnostics = general_diagnostics()
@@ -495,8 +495,8 @@ def post_stage(self, stage):
495495

496496
def post_run(self, pipe):
497497
dml_logging.info(f'Finished training in {pipe.stop_time - pipe.start_time} ({pipe.stop_time})')
498-
if pipe.checkpointing_enabled:
499-
dml_logging.info(f'Outputs have been saved to {pipe.checkpoint_dir}')
498+
if pipe.has_checkpointing:
499+
dml_logging.info(f'Outputs have been saved to {pipe.run_dir}')
500500

501501

502502
class GitDiffCallback(Callback):
@@ -509,8 +509,8 @@ def pre_run(self, pipe):
509509
if diff is None:
510510
return
511511

512-
if pipe.checkpointing_enabled and is_root():
513-
self._save(pipe.checkpoint_dir.path / 'git_diff.txt', diff)
512+
if pipe.has_checkpointing and is_root():
513+
self._save(pipe.run_dir / 'git_diff.txt', diff)
514514

515515
msg = '* GIT-DIFF:\n'
516516
msg += '\n'.join(' ' + line for line in diff.splitlines())
@@ -558,8 +558,8 @@ def pre_run(self, pipe):
558558
msg += '\n'.join(f' - [{i}] {info_str}' for i, info_str in enumerate(info_strings))
559559
dml_logging.info(msg)
560560

561-
if pipe.checkpointing_enabled and is_root():
562-
self._save(pipe.checkpoint_dir.path / 'cuda_devices.json', all_devices)
561+
if pipe.has_checkpointing and is_root():
562+
self._save(pipe.run_dir / 'cuda_devices.json', all_devices)
563563

564564
def _save(self, path, all_devices):
565565
with open(path, 'w') as f:

dmlcloud/core/checkpoint.py

Lines changed: 46 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import datetime
2-
import logging
32
import secrets
43
from pathlib import Path
54
from typing import Optional
@@ -9,6 +8,17 @@
98
from dmlcloud.slurm import slurm_job_id
109

1110

11+
__all__ = [
12+
'generate_checkpoint_path',
13+
'is_valid_checkpoint_dir',
14+
'create_checkpoint_dir',
15+
'find_slurm_checkpoint',
16+
'read_slurm_id',
17+
'save_config',
18+
'read_config',
19+
]
20+
21+
1222
def sanitize_filename(filename: str) -> str:
1323
return filename.replace('/', '_')
1424

@@ -34,90 +44,55 @@ def generate_checkpoint_path(
3444
return root / f'{name}-{dt}-{generate_id()}'
3545

3646

37-
def find_slurm_checkpoint(root: Path | str) -> Optional[Path]:
38-
root = Path(root)
39-
40-
job_id = slurm_job_id()
41-
if job_id is None:
42-
return None
43-
44-
for child in root.iterdir():
45-
if CheckpointDir(child).is_valid and CheckpointDir(child).slurm_job_id == job_id:
46-
return child
47-
48-
return None
49-
50-
51-
class CheckpointDir:
52-
def __init__(self, path: Path):
53-
self.path = Path(path).resolve()
54-
self.logger = logging.getLogger('dmlcloud')
47+
def is_valid_checkpoint_dir(path: Path) -> bool:
48+
if not path.exists() or not path.is_dir():
49+
return False
5550

56-
@property
57-
def config_file(self) -> Path:
58-
return self.path / 'config.yaml'
51+
if not (path / '.dmlcloud').exists():
52+
return False
5953

60-
@property
61-
def indicator_file(self) -> Path:
62-
return self.path / '.dmlcloud'
54+
return True
6355

64-
@property
65-
def log_file(self) -> Path:
66-
return self.path / 'log.txt'
6756

68-
@property
69-
def slurm_file(self) -> Path:
70-
return self.path / '.slurm-jobid'
57+
def create_checkpoint_dir(path: Path | str, name: Optional[str] = None) -> Path:
58+
path.mkdir(parents=True, exist_ok=True)
59+
(path / '.dmlcloud').touch()
60+
(path / 'log.txt').touch()
61+
if slurm_job_id() is not None:
62+
with open(path / '.slurm-jobid', 'w') as f:
63+
f.write(slurm_job_id())
7164

72-
@property
73-
def exists(self) -> bool:
74-
return self.path.exists()
7565

76-
@property
77-
def is_valid(self) -> bool:
78-
if not self.exists or not self.path.is_dir():
79-
return False
80-
81-
if not self.indicator_file.exists():
82-
return False
66+
def read_slurm_id(path: Path) -> Optional[str]:
67+
if is_valid_checkpoint_dir(path):
68+
return None
8369

84-
return True
70+
if not (path / '.slurm-jobid').exists():
71+
return None
8572

86-
@property
87-
def slurm_job_id(self) -> Optional[str]:
88-
if not self.slurm_file.exists():
89-
return None
73+
with open(path / '.slurm-jobid') as f:
74+
return f.read()
9075

91-
with open(self.slurm_file) as f:
92-
return f.read()
9376

94-
def create(self):
95-
if self.exists:
96-
raise ValueError(f'Checkpoint directory already exists: {self.path}')
77+
def find_slurm_checkpoint(root: Path | str) -> Optional[Path]:
78+
root = Path(root)
9779

98-
self.path.mkdir(parents=True, exist_ok=True)
99-
self.indicator_file.touch()
100-
self.log_file.touch()
101-
if slurm_job_id() is not None:
102-
with open(self.slurm_file, 'w') as f:
103-
f.write(slurm_job_id())
80+
job_id = slurm_job_id()
81+
if job_id is None:
82+
return None
10483

105-
def save_config(self, config: OmegaConf):
106-
if not self.exists:
107-
raise ValueError(f'Checkpoint directory does not exist: {self.path}')
84+
for child in root.iterdir():
85+
if read_slurm_id(child) == job_id:
86+
return child
10887

109-
with open(self.config_file, 'w') as f:
110-
OmegaConf.save(config, f)
88+
return None
11189

112-
def load_config(self) -> OmegaConf:
113-
if not self.is_valid:
114-
raise ValueError(f'Checkpoint directory is not valid: {self.path}')
11590

116-
with open(self.config_file) as f:
117-
return OmegaConf.load(f)
91+
def save_config(config: OmegaConf, run_dir: Path):
92+
with open(run_dir / 'config.yaml', 'w') as f:
93+
OmegaConf.save(config, f)
11894

119-
def __str__(self) -> str:
120-
return str(self.path)
12195

122-
def __repr__(self) -> str:
123-
return f'CheckpointDir({self.path})'
96+
def read_config(run_dir: Path) -> OmegaConf:
97+
with open(run_dir / 'config.yaml') as f:
98+
return OmegaConf.load(f)

dmlcloud/core/pipeline.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22
from datetime import datetime, timedelta
33
from functools import cached_property
4+
from pathlib import Path
45
from typing import Dict, List, Optional, Union
56

67
import torch
@@ -21,7 +22,7 @@
2122
WandbInitCallback,
2223
WandbLoggerCallback,
2324
)
24-
from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
25+
from .checkpoint import find_slurm_checkpoint, generate_checkpoint_path, is_valid_checkpoint_dir
2526
from .distributed import broadcast_object, init, is_root, local_rank
2627
from .stage import Stage
2728

@@ -105,29 +106,39 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option
105106

106107
self.name = name
107108

108-
self.checkpoint_dir = None
109+
self.run_dir: Path | None = None
109110
self.resumed = None
110111
self.start_time = None
111112
self.stop_time = None
112113
self.current_stage = None
113114

114-
self.wandb = False
115+
self._wandb = False
115116

116117
self.stages = []
117118
self.callbacks = CallbackList()
118119

119120
self.add_callback(DiagnosticsCallback(), CbPriority.DIAGNOSTICS)
120121
self.add_callback(GitDiffCallback(), CbPriority.GIT)
121122
self.add_callback(_ForwardCallback(), CbPriority.OBJECT_METHODS) # methods have priority 0
123+
if self.device.type == 'cuda':
124+
self.add_callback(CudaCallback(), CbPriority.CUDA)
122125

123126
if dist.is_gloo_available():
124127
self.gloo_group = dist.new_group(backend='gloo')
125128
else:
126129
warnings.warn('Gloo backend not available. Barriers will not use custom timeouts.')
127130

128131
@property
129-
def checkpointing_enabled(self):
130-
return self.checkpoint_dir is not None
132+
def has_checkpointing(self):
133+
return self.run_dir is not None
134+
135+
@property
136+
def has_wandb(self):
137+
return self._wandb
138+
139+
@property
140+
def has_tensorboard(self):
141+
return self.has_checkpointing
131142

132143
def add_callback(self, callback: Callback, priority: int = 1):
133144
"""
@@ -157,31 +168,25 @@ def enable_checkpointing(
157168
root: str,
158169
resume: bool = False,
159170
):
160-
if self.checkpointing_enabled:
171+
if self.has_checkpointing:
161172
raise ValueError('Checkpointing already enabled')
162173

163-
path = None
164-
if resume and CheckpointDir(root).is_valid:
165-
path = root
174+
if resume and is_valid_checkpoint_dir(root):
175+
self.run_dir = root
166176
self.resumed = True
167177
elif resume and find_slurm_checkpoint(root):
168-
path = find_slurm_checkpoint(root)
178+
self.run_dir = find_slurm_checkpoint(root)
169179
self.resumed = True
170180

171-
if path is None: # no need for a barrier here, dir creation happens in _pre_run()
181+
if self.run_dir is None: # no need for a barrier here, dir creation happens in _pre_run()
172182
path = generate_checkpoint_path(root=root, name=self.name, creation_time=self.start_time)
173-
path = broadcast_object(path)
183+
self.run_dir = broadcast_object(path)
174184
self.resumed = False
175185

176-
self.checkpoint_dir = CheckpointDir(path)
177-
178186
if is_root():
179-
self.add_callback(CheckpointCallback(self.checkpoint_dir.path), CbPriority.CHECKPOINT)
180-
self.add_callback(CsvCallback(self.checkpoint_dir.path, append_stage_name=True), CbPriority.CSV)
181-
self.add_callback(TensorboardCallback(self.checkpoint_dir.path), CbPriority.TENSORBOARD)
182-
183-
if self.device.type == 'cuda':
184-
self.add_callback(CudaCallback(), CbPriority.CUDA)
187+
self.add_callback(CheckpointCallback(self.run_dir), CbPriority.CHECKPOINT)
188+
self.add_callback(CsvCallback(self.run_dir, append_stage_name=True), CbPriority.CSV)
189+
self.add_callback(TensorboardCallback(self.run_dir), CbPriority.TENSORBOARD)
185190

186191
def enable_wandb(
187192
self,
@@ -192,7 +197,7 @@ def enable_wandb(
192197
startup_timeout: int = 360,
193198
**kwargs,
194199
):
195-
if self.wandb:
200+
if self._wandb:
196201
raise ValueError('Wandb already enabled')
197202

198203
import wandb # import now to avoid potential long import times later on # noqa
@@ -209,7 +214,7 @@ def enable_wandb(
209214
self.add_callback(init_callback, CbPriority.WANDB_INIT)
210215
self.add_callback(WandbLoggerCallback(), CbPriority.WANDB_LOGGER)
211216

212-
self.wandb = True
217+
self._wandb = True
213218

214219
def barrier(self, timeout=None):
215220
if self.gloo_group is None:
@@ -269,7 +274,7 @@ def _pre_run(self):
269274
callback.pre_run(self)
270275

271276
def _resume_run(self):
272-
dml_logging.info(f'Resuming training from checkpoint: {self.checkpoint_dir}')
277+
dml_logging.info(f'Resuming training from checkpoint: {self.run_dir}')
273278
self.resume_run()
274279

275280
def _post_run(self):

0 commit comments

Comments (0)