Skip to content

Commit 5b27095

Browse files
authored
Merge pull request #39 from transformerlab/add/get-checkpoints-artifacts-fn
add methods to fetch checkpoints and artifacts paths
2 parents 1cc80f8 + 36134cf commit 5b27095

File tree

4 files changed

+135
-2
lines changed

4 files changed

+135
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "transformerlab"
7-
version = "0.0.31"
7+
version = "0.0.32"
88
description = "Python SDK for Transformer Lab"
99
readme = "README.md"
1010
requires-python = ">=3.10"

scripts/examples/test_script.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,36 @@ def train():
5757
lab.update_progress(10 + (i + 1) * 10)
5858
print(f"Iteration {i + 1}/8")
5959

60-
# Method 3: Initialize wandb during training (common pattern)
60+
# Save fake checkpoint every 2 iterations
61+
if (i + 1) % 2 == 0:
62+
checkpoint_file = os.path.join(training_config["output_dir"], f"checkpoint_epoch_{i + 1}.txt")
63+
with open(checkpoint_file, "w") as f:
64+
f.write(f"Fake checkpoint for epoch {i + 1}\n")
65+
f.write(f"Model state: iteration_{i + 1}\n")
66+
f.write(f"Loss: {0.5 - (i + 1) * 0.05:.3f}\n")
67+
f.write(f"Accuracy: {0.6 + (i + 1) * 0.04:.3f}\n")
68+
f.write(f"Timestamp: {datetime.now()}\n")
69+
70+
# Save checkpoint using lab facade
71+
saved_checkpoint_path = lab.save_checkpoint(checkpoint_file, f"epoch_{i + 1}_checkpoint.txt")
72+
lab.log(f"Saved checkpoint: {saved_checkpoint_path}")
73+
74+
# Save some fake artifacts
75+
artifact_file = os.path.join(training_config["output_dir"], f"training_metrics_epoch_{i + 1}.json")
76+
with open(artifact_file, "w") as f:
77+
f.write('{\n')
78+
f.write(f' "epoch": {i + 1},\n')
79+
f.write(f' "loss": {0.5 - (i + 1) * 0.05:.3f},\n')
80+
f.write(f' "accuracy": {0.6 + (i + 1) * 0.04:.3f},\n')
81+
f.write(f' "learning_rate": {2e-5},\n')
82+
f.write(f' "batch_size": {8},\n')
83+
f.write(f' "timestamp": "{datetime.now().isoformat()}"\n')
84+
f.write('}\n')
85+
86+
# Save artifact using lab facade
87+
saved_artifact_path = lab.save_artifact(artifact_file, f"metrics_epoch_{i + 1}.json")
88+
lab.log(f"Saved artifact: {saved_artifact_path}")
89+
6190
if i == 3: # Initialize wandb halfway through training
6291
try:
6392
import wandb
@@ -97,6 +126,30 @@ def train():
97126
training_duration = end_time - start_time
98127
lab.log(f"Training completed in {training_duration}")
99128

129+
# Save final artifacts
130+
final_model_file = os.path.join(training_config["output_dir"], "final_model_summary.txt")
131+
with open(final_model_file, "w") as f:
132+
f.write("Final Model Summary\n")
133+
f.write("==================\n")
134+
f.write(f"Training Duration: {training_duration}\n")
135+
f.write("Final Loss: 0.15\n")
136+
f.write("Final Accuracy: 0.92\n")
137+
f.write(f"Model: {training_config['model_name']}\n")
138+
f.write(f"Dataset: {training_config['dataset']}\n")
139+
f.write(f"Completed at: {end_time}\n")
140+
141+
# Save final model as artifact
142+
final_model_path = lab.save_artifact(final_model_file, "final_model_summary.txt")
143+
lab.log(f"Saved final model summary: {final_model_path}")
144+
145+
# Save training configuration as artifact
146+
config_file = os.path.join(training_config["output_dir"], "training_config.json")
147+
import json
148+
with open(config_file, "w") as f:
149+
json.dump(training_config, f, indent=2)
150+
151+
config_artifact_path = lab.save_artifact(config_file, "training_config.json")
152+
lab.log(f"Saved training config: {config_artifact_path}")
100153
# Get the captured wandb URL from job data for reporting
101154
job_data = lab.job.get_job_data()
102155
captured_wandb_url = job_data.get("wandb_run_url", "None")

src/lab/job.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,58 @@ def get_next_queued_job(cls):
235235
return queued_jobs[0][1]
236236
return None
237237

238+
def get_checkpoints_dir(self):
239+
"""
240+
Get the checkpoints directory path for this job.
241+
"""
242+
return dirs.get_job_checkpoints_dir(self.id)
243+
244+
def get_artifacts_dir(self):
245+
"""
246+
Get the artifacts directory path for this job.
247+
"""
248+
return dirs.get_job_artifacts_dir(self.id)
249+
250+
def get_checkpoint_paths(self):
251+
"""
252+
Get list of checkpoint file paths for this job.
253+
Returns list of checkpoint paths from job_data or scans directory.
254+
"""
255+
try:
256+
# Scan the checkpoints directory
257+
checkpoints_dir = self.get_checkpoints_dir()
258+
if os.path.exists(checkpoints_dir):
259+
checkpoint_files = []
260+
for item in os.listdir(checkpoints_dir):
261+
item_path = os.path.join(checkpoints_dir, item)
262+
if os.path.isfile(item_path):
263+
checkpoint_files.append(item_path)
264+
return sorted(checkpoint_files)
265+
266+
return []
267+
except Exception:
268+
return []
269+
270+
271+
def get_artifact_paths(self):
272+
"""
273+
Get list of artifact file paths for this job.
274+
Returns list of artifact paths from job_data or scans directory.
275+
"""
276+
try:
277+
# Scan the artifacts directory
278+
artifacts_dir = self.get_artifacts_dir()
279+
if os.path.exists(artifacts_dir):
280+
artifact_files = []
281+
for item in os.listdir(artifacts_dir):
282+
item_path = os.path.join(artifacts_dir, item)
283+
if os.path.isfile(item_path):
284+
artifact_files.append(item_path)
285+
return sorted(artifact_files)
286+
except Exception:
287+
return []
288+
return []
289+
238290
def delete(self):
239291
"""
240292
Mark this job as deleted.

src/lab/lab_facade.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,34 @@ def job(self) -> Job:
293293
self._ensure_initialized()
294294
return self._job # type: ignore[return-value]
295295

296+
def get_checkpoints_dir(self) -> str:
297+
"""
298+
Get the checkpoints directory path for the current job.
299+
"""
300+
self._ensure_initialized()
301+
return self._job.get_checkpoints_dir() # type: ignore[union-attr]
302+
303+
def get_artifacts_dir(self) -> str:
304+
"""
305+
Get the artifacts directory path for the current job.
306+
"""
307+
self._ensure_initialized()
308+
return self._job.get_artifacts_dir() # type: ignore[union-attr]
309+
310+
def get_checkpoint_paths(self) -> list[str]:
311+
"""
312+
Get list of checkpoint file paths for the current job.
313+
"""
314+
self._ensure_initialized()
315+
return self._job.get_checkpoint_paths() # type: ignore[union-attr]
316+
317+
def get_artifact_paths(self) -> list[str]:
318+
"""
319+
Get list of artifact file paths for the current job.
320+
"""
321+
self._ensure_initialized()
322+
return self._job.get_artifact_paths() # type: ignore[union-attr]
323+
296324
@property
297325
def experiment(self) -> Experiment:
298326
self._ensure_initialized()

0 commit comments

Comments
 (0)