Commit 410e94f

Merge pull request #40 from transformerlab/add/model-checkpoint-management
Add/model checkpoint management
2 parents fb3d106 + f63930a commit 410e94f

3 files changed: +278 -1 lines changed

scripts/examples/test_script.py

Lines changed: 14 additions & 0 deletions
@@ -164,6 +164,19 @@ def train():
     except Exception:
         pass
 
+    # Save the trained model
+    model_dir = os.path.join(training_config["output_dir"], "final_model")
+    os.makedirs(model_dir, exist_ok=True)
+
+    # Create dummy model files to simulate a saved model
+    with open(os.path.join(model_dir, "config.json"), "w") as f:
+        f.write('{"model": "SmolLM-135M-Instruct", "params": 135000000}')
+    with open(os.path.join(model_dir, "pytorch_model.bin"), "w") as f:
+        f.write("dummy binary model data")
+
+    saved_path = lab.save_model(model_dir, name="trained_model")
+    lab.log(f"✅ Model saved to job models directory: {saved_path}")
+
     print("Complete")
 
     # Complete the job in TransformerLab via facade
@@ -176,6 +189,7 @@ def train():
         "output_dir": os.path.join(
            training_config["output_dir"], f"final_model_{lab.job.id}"
         ),
+        "saved_model_path": saved_path,
         "wandb_url": captured_wandb_url,
     }

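For reference, a minimal sketch of how a script could also exercise the optional provenance arguments that the new save_model API accepts. The parent model ID and pipeline tag below are illustrative values, not part of this commit:

saved_path = lab.save_model(
    model_dir,
    name="trained_model",
    parent_model="HuggingFaceTB/SmolLM-135M-Instruct",  # assumed parent model ID
    pipeline_tag="text-generation",  # supplying this skips the HuggingFace lookup
)
lab.log(f"Model saved with provenance hints: {saved_path}")
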
src/lab/lab_facade.py

Lines changed: 135 additions & 1 deletion
@@ -1,13 +1,14 @@
 from __future__ import annotations
 
+import time
 from typing import Optional, Dict, Any
 import os
 import shutil
 
 from .experiment import Experiment
 from .job import Job
 from . import dirs
-
+from .model import Model as ModelService
 
 class Lab:
     """
@@ -177,6 +178,139 @@ def save_checkpoint(self, source_path: str, name: Optional[str] = None) -> str:
 
         return dest
 
+    def save_model(self, source_path: str, name: Optional[str] = None, architecture: Optional[str] = None, pipeline_tag: Optional[str] = None, parent_model: Optional[str] = None) -> str:
+        """
+        Save a model file or directory to the workspace models directory.
+        The model will automatically appear in the Model Zoo's Local Models list.
+
+        Args:
+            source_path: Path to the model file or directory to save
+            name: Optional name for the model. If not provided, uses the source basename.
+                The final model name will be prefixed with the job_id for uniqueness.
+            architecture: Optional architecture string. If not provided, will attempt to
+                detect it from config.json for directory-based models.
+            pipeline_tag: Optional pipeline tag. If not provided and parent_model is given,
+                will attempt to fetch it from the parent model on HuggingFace.
+            parent_model: Optional parent model name/ID for provenance tracking.
+
+        Returns:
+            The destination path on disk.
+        """
+        self._ensure_initialized()
+        if not isinstance(source_path, str) or source_path.strip() == "":
+            raise ValueError("source_path must be a non-empty string")
+        src = os.path.abspath(source_path)
+        if not os.path.exists(src):
+            raise FileNotFoundError(f"Model source does not exist: {src}")
+
+        job_id = self._job.id  # type: ignore[union-attr]
+
+        # Determine base name with job_id prefix for uniqueness
+        if isinstance(name, str) and name.strip() != "":
+            base_name = f"{job_id}_{name}"
+        else:
+            base_name = f"{job_id}_{os.path.basename(src)}"
+
+        # Save to main workspace models directory for Model Zoo visibility
+        models_dir = dirs.get_models_dir()
+        dest = os.path.join(models_dir, base_name)
+
+        # Create parent directories
+        os.makedirs(os.path.dirname(dest), exist_ok=True)
+
+        # Copy file or directory
+        if os.path.isdir(src):
+            if os.path.exists(dest):
+                shutil.rmtree(dest)
+            shutil.copytree(src, dest)
+        else:
+            shutil.copy2(src, dest)
+
+        # Create Model metadata so it appears in Model Zoo
+        try:
+            model_service = ModelService(base_name)
+
+            # Use provided architecture or detect it
+            if architecture is None:
+                architecture = model_service.detect_architecture(dest)
+
+            # Handle pipeline tag logic
+            if pipeline_tag is None and parent_model is not None:
+                # Try to fetch pipeline tag from parent model
+                pipeline_tag = model_service.fetch_pipeline_tag(parent_model)
+
+            # Determine model_filename for single-file models
+            model_filename = "" if os.path.isdir(dest) else os.path.basename(dest)
+
+            # Prepare json_data with basic info
+            json_data = {
+                "job_id": job_id,
+                "description": f"Model generated by job {job_id}",
+            }
+
+            # Add pipeline tag to json_data if provided
+            if pipeline_tag is not None:
+                json_data["pipeline_tag"] = pipeline_tag
+
+            # Use the Model class's generate_model_json method to create metadata
+            model_service.generate_model_json(
+                architecture=architecture,
+                model_filename=model_filename,
+                json_data=json_data,
+            )
+            self.log(f"Model saved to Model Zoo as '{base_name}'")
+        except Exception as e:
+            self.log(f"Warning: Model saved but metadata creation failed: {str(e)}")
+
+        # Create provenance data
+        try:
+            # Create MD5 checksums for all model files
+            md5_objects = model_service.create_md5_checksums(dest)
+
+            # Prepare provenance metadata from job data
+            job_data = self._job.get_job_data()
+
+            provenance_metadata = {
+                "job_id": job_id,
+                "model_name": parent_model or job_data.get("model_name"),
+                "model_architecture": architecture,
+                "input_model": parent_model,
+                "dataset": job_data.get("dataset"),
+                "adaptor_name": job_data.get("adaptor_name", None),
+                "parameters": job_data.get("_config", {}),
+                "start_time": job_data.get("start_time", ""),
+                "end_time": time.strftime("%Y-%m-%d %H:%M:%S"),
+                "md5_checksums": md5_objects,
+            }
+
+            # Create the _tlab_provenance.json file
+            provenance_file = model_service.create_provenance_file(
+                model_path=dest,
+                model_name=base_name,
+                model_architecture=architecture,
+                md5_objects=md5_objects,
+                provenance_data=provenance_metadata,
+            )
+            self.log(f"Provenance file created at: {provenance_file}")
+        except Exception as e:
+            self.log(f"Warning: Model saved but provenance creation failed: {str(e)}")
+
+        # Track in job_data
+        try:
+            job_data = self._job.get_job_data()
+            model_list = []
+            if isinstance(job_data, dict):
+                existing = job_data.get("models", [])
+                if isinstance(existing, list):
+                    model_list = existing
+            model_list.append(dest)
+            self._job.update_job_data_field("models", model_list)
+        except Exception:
+            pass
+
+        return dest
+
     def error(
         self,
         message: str = "",

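To make the naming and copy behavior concrete, here is a hedged sketch of what save_model produces, assuming an already-initialized Lab facade whose current job id is "42" (an illustrative value):

# Assumes the Lab facade is initialized and the current job id is "42" (illustrative).
dest = lab.save_model("/tmp/run/final_model", name="trained_model")
# dest == os.path.join(dirs.get_models_dir(), "42_trained_model")
# Directory sources are copied with shutil.copytree (an existing destination
# is removed first); single files are copied with shutil.copy2. Metadata and
# provenance failures are logged as warnings and never raise.
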
src/lab/model.py

Lines changed: 129 additions & 0 deletions
@@ -65,6 +65,135 @@ def import_model(self, model_name, model_path):
         """
         self.generate_model_json(model_name, model_path)
 
+    def detect_architecture(self, model_path: str) -> str:
+        """
+        Detect the model architecture from a model directory's config.json.
+
+        Args:
+            model_path: Path to the model directory or file
+
+        Returns:
+            The model architecture (e.g., 'LlamaForCausalLM') or 'Unknown' if not found
+        """
+        architecture = "Unknown"
+
+        if os.path.isdir(model_path):
+            config_path = os.path.join(model_path, "config.json")
+            if os.path.exists(config_path):
+                try:
+                    with open(config_path, "r") as f:
+                        config = json.load(f)
+                        architectures = config.get("architectures", [])
+                        if architectures:
+                            architecture = architectures[0]
+                except Exception:
+                    pass
+
+        return architecture
+
+    def fetch_pipeline_tag(self, parent_model: str) -> str | None:
+        """
+        Fetch the pipeline tag from a parent model on HuggingFace.
+
+        Args:
+            parent_model: The HuggingFace model ID to fetch the pipeline tag from
+
+        Returns:
+            The pipeline tag string if found, None otherwise
+        """
+        try:
+            from huggingface_hub import HfApi
+
+            api = HfApi()
+            model_info = api.model_info(parent_model)
+            return model_info.pipeline_tag
+        except Exception as e:
+            print(f"Could not fetch pipeline tag from parent model '{parent_model}': {type(e).__name__}: {e}")
+            return None
+
+    def create_md5_checksums(self, model_path: str) -> list:
+        """
+        Create MD5 checksums for all files in the model directory.
+
+        Args:
+            model_path: Path to the model directory
+
+        Returns:
+            List of dicts with 'file_path' and 'md5_hash' keys
+        """
+        import hashlib
+
+        def compute_md5(file_path):
+            md5 = hashlib.md5()
+            with open(file_path, "rb") as f:
+                while chunk := f.read(8192):
+                    md5.update(chunk)
+            return md5.hexdigest()
+
+        md5_objects = []
+
+        if not os.path.isdir(model_path):
+            print(f"Model path '{model_path}' is not a directory, skipping MD5 checksum creation")
+            return md5_objects
+
+        for root, _, files in os.walk(model_path):
+            for file in files:
+                file_path = os.path.join(root, file)
+                try:
+                    md5_hash = compute_md5(file_path)
+                    md5_objects.append({"file_path": file_path, "md5_hash": md5_hash})
+                except Exception as e:
+                    print(f"Warning: Could not compute MD5 for {file_path}: {str(e)}")
+
+        return md5_objects
+
+    def create_provenance_file(self, model_path: str, model_name: str | None = None, model_architecture: str | None = None,
+                               md5_objects: list | None = None, provenance_data: dict | None = None) -> str:
+        """
+        Create a _tlab_provenance.json file containing model provenance data.
+
+        Args:
+            model_path: Path to the model directory
+            model_name: Name of the model
+            model_architecture: Architecture of the model
+            md5_objects: List of MD5 checksums from create_md5_checksums()
+            provenance_data: Optional dict with additional provenance data. Expected keys include:
+                - job_id: ID of the job that created this model
+                - input_model: Name of the base/parent model used
+                - dataset: Name of the dataset used for training
+                - adaptor_name: Name of the adapter if applicable
+                - parameters: Training configuration parameters
+                - start_time: When training/processing started
+
+        Returns:
+            Path to the created provenance file
+        """
+        import time
+
+        # Start with base provenance data matching the structure from train.py
+        final_provenance = {
+            "model_name": model_name,
+            "model_architecture": model_architecture,
+            "job_id": None,
+            "input_model": None,
+            "dataset": None,
+            "adaptor_name": None,
+            "parameters": None,
+            "start_time": "",
+            "end_time": time.strftime("%Y-%m-%d %H:%M:%S"),
+            "md5_checksums": md5_objects,
+        }
+
+        # Merge in any additional provenance data provided
+        if provenance_data and isinstance(provenance_data, dict):
+            final_provenance.update(provenance_data)
+
+        # Write provenance to file
+        provenance_path = os.path.join(model_path, "_tlab_provenance.json")
+        with open(provenance_path, "w") as f:
+            json.dump(final_provenance, f, indent=2)
+
+        return provenance_path
+
     def generate_model_json(
         self,
         architecture: str,

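The helpers on Model can also be composed directly, outside the facade. A minimal sketch, assuming the module is importable as lab.model and using illustrative names and paths throughout:

from lab.model import Model  # assumed import path for src/lab/model.py

model = Model("42_trained_model")  # illustrative model name
model_path = "/workspace/models/42_trained_model"  # assumed location on disk

architecture = model.detect_architecture(model_path)  # "Unknown" without a config.json
md5_objects = model.create_md5_checksums(model_path)  # [] when model_path is a file
provenance_path = model.create_provenance_file(
    model_path=model_path,
    model_name="42_trained_model",
    model_architecture=architecture,
    md5_objects=md5_objects,
    provenance_data={"job_id": "42"},  # merged over the base fields
)
# Network-dependent; returns None on any failure:
pipeline_tag = model.fetch_pipeline_tag("HuggingFaceTB/SmolLM-135M-Instruct")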