Skip to content

Commit 608367d

Browse files
committed
merge save_model into save_artifact
1 parent 664680c commit 608367d

File tree

1 file changed

+168
-114
lines changed

1 file changed

+168
-114
lines changed

src/lab/lab_facade.py

Lines changed: 168 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -125,17 +125,22 @@ def save_artifact(
125125
when type="evals" or type="dataset"
126126
name: Optional name for the artifact. If not provided, uses source basename
127127
or generates a default name for DataFrames. When type="dataset",
128-
this is used as the dataset_id.
128+
this is used as the dataset_id. When type="model", this is used as the model name
129+
(will be prefixed with job_id for uniqueness).
129130
type: Optional type of artifact.
130131
- If "evals", saves to eval_results directory and updates job data accordingly.
131132
- If "dataset", saves as a dataset and tracks dataset_id in job data.
133+
- If "model", saves to workspace models directory and creates Model Zoo metadata.
132134
- Otherwise saves to artifacts directory.
133135
config: Optional configuration dict.
134136
When type="evals", can contain column mappings under "evals" key, e.g.:
135137
{"evals": {"input": "input_col", "output": "output_col",
136138
"expected_output": "expected_col", "score": "score_col"}}
137139
When type="dataset", can contain:
138140
{"dataset": {...metadata...}, "suffix": "...", "is_image": bool}
141+
When type="model", can contain:
142+
{"model": {"architecture": "...", "pipeline_tag": "...", "parent_model": "..."}}
143+
or top-level keys: {"architecture": "...", "pipeline_tag": "...", "parent_model": "..."} (non-empty values under "model" take precedence over top-level keys)
139144
140145
Returns:
141146
The destination path on disk.
@@ -278,6 +283,150 @@ def save_artifact(
278283
self.log(f"Evaluation results saved to '{dest}'")
279284
return dest
280285

286+
# Handle file path input when type="model"
287+
if type == "model":
288+
if not isinstance(source_path, str) or source_path.strip() == "":
289+
raise ValueError("source_path must be a non-empty string when type='model'")
290+
src = os.path.abspath(source_path)
291+
if not os.path.exists(src):
292+
raise FileNotFoundError(f"Model source does not exist: {src}")
293+
294+
# Get model-specific parameters from config
295+
model_config = {}
296+
architecture = None
297+
pipeline_tag = None
298+
parent_model = None
299+
300+
if config and isinstance(config, dict):
301+
# Check for model config in nested dict
302+
if "model" in config and isinstance(config["model"], dict):
303+
model_config = config["model"]
304+
# Also allow top-level keys for convenience
305+
if "architecture" in config:
306+
architecture = config["architecture"]
307+
if "pipeline_tag" in config:
308+
pipeline_tag = config["pipeline_tag"]
309+
if "parent_model" in config:
310+
parent_model = config["parent_model"]
311+
312+
# Override with nested model config if present
313+
if model_config:
314+
architecture = model_config.get("architecture") or architecture
315+
pipeline_tag = model_config.get("pipeline_tag") or pipeline_tag
316+
parent_model = model_config.get("parent_model") or parent_model
317+
318+
# Determine base name with job_id prefix for uniqueness
319+
if isinstance(name, str) and name.strip() != "":
320+
base_name = f"{job_id}_{name}"
321+
else:
322+
base_name = f"{job_id}_{os.path.basename(src)}"
323+
324+
# Save to main workspace models directory for Model Zoo visibility
325+
models_dir = dirs.get_models_dir()
326+
dest = os.path.join(models_dir, base_name)
327+
328+
# Create parent directories
329+
os.makedirs(os.path.dirname(dest), exist_ok=True)
330+
331+
# Copy file or directory
332+
if os.path.isdir(src):
333+
if os.path.exists(dest):
334+
shutil.rmtree(dest)
335+
shutil.copytree(src, dest)
336+
else:
337+
shutil.copy2(src, dest)
338+
339+
# Initialize model service for metadata and provenance creation
340+
model_service = ModelService(base_name)
341+
342+
# Create Model metadata so it appears in Model Zoo
343+
try:
344+
# Use provided architecture or detect it
345+
if architecture is None:
346+
architecture = model_service.detect_architecture(dest)
347+
348+
# Handle pipeline tag logic
349+
if pipeline_tag is None and parent_model is not None:
350+
# Try to fetch pipeline tag from parent model
351+
pipeline_tag = model_service.fetch_pipeline_tag(parent_model)
352+
353+
# Determine model_filename for single-file models
354+
model_filename = "" if os.path.isdir(dest) else os.path.basename(dest)
355+
356+
# Prepare json_data with basic info
357+
json_data = {
358+
"job_id": job_id,
359+
"description": f"Model generated by job {job_id}",
360+
}
361+
362+
# Add pipeline tag to json_data if provided
363+
if pipeline_tag is not None:
364+
json_data["pipeline_tag"] = pipeline_tag
365+
366+
# Use ModelService.generate_model_json to create the metadata
367+
model_service.generate_model_json(
368+
architecture=architecture,
369+
model_filename=model_filename,
370+
json_data=json_data
371+
)
372+
self.log(f"Model saved to Model Zoo as '{base_name}'")
373+
except Exception as e:
374+
self.log(f"Warning: Model saved but metadata creation failed: {str(e)}")
375+
# Try to detect architecture for provenance even if metadata creation failed
376+
if architecture is None:
377+
try:
378+
architecture = model_service.detect_architecture(dest)
379+
except Exception:
380+
pass
381+
382+
# Create provenance data
383+
try:
384+
# Create MD5 checksums for all model files
385+
md5_objects = model_service.create_md5_checksums(dest)
386+
387+
# Prepare provenance metadata from job data
388+
job_data = self._job.get_job_data()
389+
390+
provenance_metadata = {
391+
"job_id": job_id,
392+
"model_name": parent_model or job_data.get("model_name"),
393+
"model_architecture": architecture,
394+
"input_model": parent_model,
395+
"dataset": job_data.get("dataset"),
396+
"adaptor_name": job_data.get("adaptor_name", None),
397+
"parameters": job_data.get("_config", {}),
398+
"start_time": job_data.get("start_time", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())),
399+
"end_time": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
400+
"md5_checksums": md5_objects,
401+
}
402+
403+
# Create the _tlab_provenance.json file
404+
provenance_file = model_service.create_provenance_file(
405+
model_path=dest,
406+
model_name=base_name,
407+
model_architecture=architecture,
408+
md5_objects=md5_objects,
409+
provenance_data=provenance_metadata
410+
)
411+
self.log(f"Provenance file created at: {provenance_file}")
412+
except Exception as e:
413+
self.log(f"Warning: Model saved but provenance creation failed: {str(e)}")
414+
415+
# Track in job_data
416+
try:
417+
job_data = self._job.get_job_data()
418+
model_list = []
419+
if isinstance(job_data, dict):
420+
existing = job_data.get("models", [])
421+
if isinstance(existing, list):
422+
model_list = existing
423+
model_list.append(dest)
424+
self._job.update_job_data_field("models", model_list)
425+
except Exception:
426+
pass
427+
428+
return dest
429+
281430
# Handle file path input (original behavior)
282431
if not isinstance(source_path, str) or source_path.strip() == "":
283432
raise ValueError("source_path must be a non-empty string")
@@ -472,6 +621,9 @@ def save_model(self, source_path: str, name: Optional[str] = None, architecture:
472621
Save a model file or directory to the workspace models directory.
473622
The model will automatically appear in the Model Zoo's Local Models list.
474623
624+
This method is a convenience wrapper around save_artifact with type="model".
625+
For new code, consider using save_artifact directly with type="model".
626+
475627
Args:
476628
source_path: Path to the model file or directory to save
477629
name: Optional name for the model. If not provided, uses source basename.
@@ -485,120 +637,22 @@ def save_model(self, source_path: str, name: Optional[str] = None, architecture:
485637
Returns:
486638
The destination path on disk.
487639
"""
488-
self._ensure_initialized()
489-
if not isinstance(source_path, str) or source_path.strip() == "":
490-
raise ValueError("source_path must be a non-empty string")
491-
src = os.path.abspath(source_path)
492-
if not os.path.exists(src):
493-
raise FileNotFoundError(f"Model source does not exist: {src}")
494-
495-
job_id = self._job.id # type: ignore[union-attr]
496-
497-
# Determine base name with job_id prefix for uniqueness
498-
if isinstance(name, str) and name.strip() != "":
499-
base_name = f"{job_id}_{name}"
500-
else:
501-
base_name = f"{job_id}_{os.path.basename(src)}"
502-
503-
# Save to main workspace models directory for Model Zoo visibility
504-
models_dir = dirs.get_models_dir()
505-
dest = os.path.join(models_dir, base_name)
640+
# Build config dict from parameters
641+
config = {}
642+
if architecture is not None:
643+
config["architecture"] = architecture
644+
if pipeline_tag is not None:
645+
config["pipeline_tag"] = pipeline_tag
646+
if parent_model is not None:
647+
config["parent_model"] = parent_model
506648

507-
# Create parent directories
508-
os.makedirs(os.path.dirname(dest), exist_ok=True)
509-
510-
# Copy file or directory
511-
if os.path.isdir(src):
512-
if os.path.exists(dest):
513-
shutil.rmtree(dest)
514-
shutil.copytree(src, dest)
515-
else:
516-
shutil.copy2(src, dest)
517-
518-
# Create Model metadata so it appears in Model Zoo
519-
try:
520-
model_service = ModelService(base_name)
521-
522-
# Use provided architecture or detect it
523-
if architecture is None:
524-
architecture = model_service.detect_architecture(dest)
525-
526-
# Handle pipeline tag logic
527-
if pipeline_tag is None and parent_model is not None:
528-
# Try to fetch pipeline tag from parent model
529-
pipeline_tag = model_service.fetch_pipeline_tag(parent_model)
530-
# Determine model_filename for single-file models
531-
model_filename = "" if os.path.isdir(dest) else os.path.basename(dest)
532-
533-
# Prepare json_data with basic info
534-
json_data = {
535-
"job_id": job_id,
536-
"description": f"Model generated by job {job_id}",
537-
}
538-
539-
# Add pipeline tag to json_data if provided
540-
if pipeline_tag is not None:
541-
json_data["pipeline_tag"] = pipeline_tag
542-
543-
# Use the Model class's generate_model_json method to create metadata
544-
model_service.generate_model_json(
545-
architecture=architecture,
546-
model_filename=model_filename,
547-
json_data=json_data
548-
)
549-
self.log(f"Model saved to Model Zoo as '{base_name}'")
550-
except Exception as e:
551-
self.log(f"Warning: Model saved but metadata creation failed: {str(e)}")
552-
553-
# Create provenance data
554-
try:
555-
# Create MD5 checksums for all model files
556-
md5_objects = model_service.create_md5_checksums(dest)
557-
558-
# Prepare provenance metadata from job data
559-
job_data = self._job.get_job_data()
560-
561-
provenance_metadata = {
562-
"job_id": job_id,
563-
"model_name": parent_model or job_data.get("model_name"),
564-
"model_architecture": architecture,
565-
"input_model": parent_model,
566-
"dataset": job_data.get("dataset"),
567-
"adaptor_name": job_data.get("adaptor_name", None),
568-
"parameters": job_data.get("_config", {}),
569-
"start_time": job_data.get("start_time", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())),
570-
"end_time": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
571-
"md5_checksums": md5_objects,
572-
573-
574-
}
575-
576-
# Create the _tlab_provenance.json file
577-
provenance_file = model_service.create_provenance_file(
578-
model_path=dest,
579-
model_name=base_name,
580-
model_architecture=architecture,
581-
md5_objects=md5_objects,
582-
provenance_data=provenance_metadata
583-
)
584-
self.log(f"Provenance file created at: {provenance_file}")
585-
except Exception as e:
586-
self.log(f"Warning: Model saved but provenance creation failed: {str(e)}")
587-
588-
# Track in job_data
589-
try:
590-
job_data = self._job.get_job_data()
591-
model_list = []
592-
if isinstance(job_data, dict):
593-
existing = job_data.get("models", [])
594-
if isinstance(existing, list):
595-
model_list = existing
596-
model_list.append(dest)
597-
self._job.update_job_data_field("models", model_list)
598-
except Exception:
599-
pass
600-
601-
return dest
649+
# Use save_artifact with type="model"
650+
return self.save_artifact(
651+
source_path=source_path,
652+
name=name,
653+
type="model",
654+
config=config if config else None
655+
)
602656

603657
def error(
604658
self,

0 commit comments

Comments
 (0)