
Commit 3aa9f05

Merge pull request #51 from transformerlab/add/train-resume-chceckpoint
Update trl script to be able to resume training from checkpoint
2 parents 08c4636 + bd2c69c commit 3aa9f05

3 files changed: +98 additions, -19 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "transformerlab"
-version = "0.0.43"
+version = "0.0.44"
 description = "Python SDK for Transformer Lab"
 readme = "README.md"
 requires-python = ">=3.10"

scripts/examples/trl_train_script.py

Lines changed: 21 additions & 10 deletions

@@ -5,6 +5,7 @@
 """
 
 import os
+import argparse
 from datetime import datetime
 from time import sleep
 from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
@@ -122,6 +123,11 @@ def train_with_trl(quick_test=True):
     lab.init()
     lab.set_config(training_config)
 
+    # Check if we should resume from a checkpoint
+    checkpoint = lab.get_checkpoint_to_resume()
+    if checkpoint:
+        lab.log(f"📁 Resuming training from checkpoint: {checkpoint}")
+
     # Log start time
     start_time = datetime.now()
     mode = "Quick test" if quick_test else "Full training"
@@ -162,17 +168,17 @@ def train_with_trl(quick_test=True):
     lab.log("Loading model and tokenizer...")
     try:
         from transformers import AutoTokenizer, AutoModelForCausalLM
-
+
         model_name = training_config["model_name"]
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(model_name)
-
+
         # Add pad token if it doesn't exist
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
         lab.log(f"Loaded model: {model_name}")
-
+
     except ImportError:
         lab.log("⚠️ Transformers not available, skipping real training")
         lab.finish("Training skipped - transformers not available")
@@ -207,6 +213,8 @@ def train_with_trl(quick_test=True):
         remove_unused_columns=False,
         push_to_hub=False,
         dataset_text_field="text",  # Move dataset_text_field to SFTConfig
+        resume_from_checkpoint=checkpoint if checkpoint else None,
+        bf16=False,  # Disable bf16 for compatibility with older GPUs
         # Enable automatic checkpoint saving
         save_total_limit=3,  # Keep only the last 3 checkpoints to save disk space
         save_strategy="steps",  # Save checkpoints every save_steps
@@ -440,15 +448,18 @@ def train_with_trl(quick_test=True):
 
 
 if __name__ == "__main__":
-    import sys
 
-    # Check if user wants full training or quick test
-    quick_test = False  # Default to quick test
-    if len(sys.argv) > 1 and sys.argv[1] == "--quick-training":
-        quick_test = True
+    parser = argparse.ArgumentParser(description="Train a model with automatic checkpoint resume support.")
+    parser.add_argument("--quick-training", action="store_true", help="Run in quick test mode")
+
+    args = parser.parse_args()
+
+    quick_test = args.quick_training
+
+    if quick_test:
         print("🚀 Running quick test mode...")
     else:
-        print("🚀 Running full training mode (use --quick-training for quick test)...")
-
+        print("🚀 Running full training mode...")
+
     result = train_with_trl(quick_test=quick_test)
     print("Training result:", result)

src/lab/lab_facade.py

Lines changed: 76 additions & 8 deletions

@@ -91,6 +91,73 @@ def update_progress(self, progress: int) -> None:
         # Check for wandb URL on every progress update
         self._check_and_capture_wandb_url()
 
+    # ------------- checkpoint resume support -------------
+    def get_checkpoint_to_resume(self) -> Optional[str]:
+        """
+        Get the checkpoint path to resume training from.
+
+        This method checks for checkpoint resume information stored in the job data
+        when resuming training from a checkpoint.
+
+        Returns:
+            Optional[str]: The full path to the checkpoint to resume from, or None if no
+                checkpoint resume is requested.
+        """
+        if not self._job:
+            return None
+
+        job_data = self._job.get_job_data()
+        if not job_data:
+            return None
+
+        parent_job_id = job_data.get('parent_job_id')
+        checkpoint_name = job_data.get('resumed_from_checkpoint')
+
+        if not parent_job_id or not checkpoint_name:
+            return None
+
+        # Build the checkpoint path from parent job's checkpoints directory
+        checkpoint_path = self.get_parent_job_checkpoint_path(parent_job_id, checkpoint_name)
+
+        # Verify the checkpoint exists
+        if checkpoint_path and os.path.exists(checkpoint_path):
+            return checkpoint_path
+
+        return None
+
+    def get_parent_job_checkpoint_path(self, parent_job_id: str, checkpoint_name: str) -> Optional[str]:
+        """
+        Get the full path to a checkpoint from a parent job.
+
+        This is a helper method that constructs the path to a specific checkpoint
+        from a parent job's checkpoints directory.
+
+        Args:
+            parent_job_id (str): The ID of the parent job that created the checkpoint
+            checkpoint_name (str): The name of the checkpoint file or directory
+
+        Returns:
+            Optional[str]: The full path to the checkpoint, or None if it doesn't exist
+        """
+        try:
+            checkpoints_dir = dirs.get_job_checkpoints_dir(parent_job_id)
+            checkpoint_path = os.path.join(checkpoints_dir, checkpoint_name)
+
+            # Security check: ensure the checkpoint path is within the checkpoints directory
+            checkpoint_path_normalized = os.path.normpath(checkpoint_path)
+            checkpoints_dir_normalized = os.path.normpath(checkpoints_dir)
+
+            if not checkpoint_path_normalized.startswith(checkpoints_dir_normalized + os.sep):
+                return None
+
+            if os.path.exists(checkpoint_path_normalized):
+                return checkpoint_path_normalized
+
+            return None
+        except Exception as e:
+            print(f"Error getting parent job checkpoint path: {str(e)}")
+            return None
+
     # ------------- completion -------------
     def finish(
         self,
@@ -506,8 +573,8 @@ def save_dataset(self, df, dataset_id: str, additional_metadata: Optional[Dict[s
         try:
             if hasattr(df, "to_pandas") and callable(getattr(df, "to_pandas")):
                 df = df.to_pandas()
-        except Exception:
-            pass
+        except Exception as e:
+            print(f"Warning: Failed to convert dataset to pandas DataFrame: {str(e)}")
 
         # Prepare dataset directory
         dataset_id_safe = dataset_id.strip()
@@ -562,16 +629,17 @@ def save_dataset(self, df, dataset_id: str, additional_metadata: Optional[Dict[s
             )
         except Exception as e:
             # Do not fail the save if metadata write fails; log to job data
+            print(f"Warning: Failed to create dataset metadata: {str(e)}")
             try:
                 self._job.update_job_data_field("dataset_metadata_error", str(e))  # type: ignore[union-attr]
-            except Exception:
-                pass
+            except Exception as e2:
+                print(f"Warning: Failed to log dataset metadata error: {str(e2)}")
 
         # Track dataset on the job for provenance
         try:
             self._job.update_job_data_field("dataset_id", dataset_id_safe)  # type: ignore[union-attr]
-        except Exception:
-            pass
+        except Exception as e:
+            print(f"Warning: Failed to track dataset in job_data: {str(e)}")
 
         self.log(f"Dataset saved to '{output_path}' and registered as generated dataset '{dataset_id_safe}'")
         return output_path
@@ -615,8 +683,8 @@ def save_checkpoint(self, source_path: str, name: Optional[str] = None) -> str:
             ckpt_list.append(dest)
             self._job.update_job_data_field("checkpoints", ckpt_list)
             self._job.update_job_data_field("latest_checkpoint", dest)
-        except Exception:
-            pass
+        except Exception as e:
+            print(f"Warning: Failed to track checkpoint in job_data: {str(e)}")
 
         return dest
 
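
The security check in get_parent_job_checkpoint_path() is the usual normpath-plus-prefix containment test against directory traversal. A standalone sketch under hypothetical paths (dirs.get_job_checkpoints_dir() is replaced by a hard-coded directory and the existence check is dropped so it runs anywhere); only the normpath/startswith logic mirrors the new facade code.

# sketch_checkpoint_containment.py -- illustration only, not part of the commit
import os
from typing import Optional


def resolve_checkpoint(checkpoints_dir: str, checkpoint_name: str) -> Optional[str]:
    """Join and normalize the path, then refuse anything that escapes checkpoints_dir."""
    checkpoint_path = os.path.normpath(os.path.join(checkpoints_dir, checkpoint_name))
    checkpoints_dir = os.path.normpath(checkpoints_dir)

    if not checkpoint_path.startswith(checkpoints_dir + os.sep):
        return None  # e.g. checkpoint_name contained "../" and left the directory

    # The facade additionally requires os.path.exists(); skipped here so the
    # sketch runs without creating any files on disk.
    return checkpoint_path


if __name__ == "__main__":
    base = "/workspace/jobs/41/checkpoints"  # hypothetical checkpoints directory
    print(resolve_checkpoint(base, "checkpoint-500"))     # kept: stays inside the directory
    print(resolve_checkpoint(base, "../../secrets.txt"))  # rejected: traversal escapes, returns None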
