@@ -125,17 +125,22 @@ def save_artifact(
125125 when type="evals" or type="dataset"
126126 name: Optional name for the artifact. If not provided, uses source basename
127127 or generates a default name for DataFrames. When type="dataset",
128- this is used as the dataset_id.
128+ this is used as the dataset_id. When type="model", this is used as the model name
129+ (will be prefixed with job_id for uniqueness).
129130 type: Optional type of artifact.
130131 - If "evals", saves to eval_results directory and updates job data accordingly.
131132 - If "dataset", saves as a dataset and tracks dataset_id in job data.
133+ - If "model", saves to workspace models directory and creates Model Zoo metadata.
132134 - Otherwise saves to artifacts directory.
133135 config: Optional configuration dict.
134136 When type="evals", can contain column mappings under "evals" key, e.g.:
135137 {"evals": {"input": "input_col", "output": "output_col",
136138 "expected_output": "expected_col", "score": "score_col"}}
137139 When type="dataset", can contain:
138140 {"dataset": {...metadata...}, "suffix": "...", "is_image": bool}
141+ When type="model", can contain:
142+ {"model": {"architecture": "...", "pipeline_tag": "...", "parent_model": "..."}}
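143+ or top-level keys: {"architecture": "...", "pipeline_tag": "...", "parent_model": "..."} (when both forms are given, a non-empty value under "model" takes precedence over the top-level key)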
139144
140145 Returns:
141146 The destination path on disk.
@@ -278,6 +283,150 @@ def save_artifact(
278283 self .log (f"Evaluation results saved to '{ dest } '" )
279284 return dest
280285
286+ # Handle file path input when type="model"
287+ if type == "model" :
288+ if not isinstance (source_path , str ) or source_path .strip () == "" :
289+ raise ValueError ("source_path must be a non-empty string when type='model'" )
290+ src = os .path .abspath (source_path )
291+ if not os .path .exists (src ):
292+ raise FileNotFoundError (f"Model source does not exist: { src } " )
293+
294+ # Get model-specific parameters from config
295+ model_config = {}
296+ architecture = None
297+ pipeline_tag = None
298+ parent_model = None
299+
300+ if config and isinstance (config , dict ):
301+ # Check for model config in nested dict
302+ if "model" in config and isinstance (config ["model" ], dict ):
303+ model_config = config ["model" ]
304+ # Also allow top-level keys for convenience
305+ if "architecture" in config :
306+ architecture = config ["architecture" ]
307+ if "pipeline_tag" in config :
308+ pipeline_tag = config ["pipeline_tag" ]
309+ if "parent_model" in config :
310+ parent_model = config ["parent_model" ]
311+
312+ # Override with nested model config if present
313+ if model_config :
314+ architecture = model_config .get ("architecture" ) or architecture
315+ pipeline_tag = model_config .get ("pipeline_tag" ) or pipeline_tag
316+ parent_model = model_config .get ("parent_model" ) or parent_model
317+
318+ # Determine base name with job_id prefix for uniqueness
319+ if isinstance (name , str ) and name .strip () != "" :
320+ base_name = f"{ job_id } _{ name } "
321+ else :
322+ base_name = f"{ job_id } _{ os .path .basename (src )} "
323+
324+ # Save to main workspace models directory for Model Zoo visibility
325+ models_dir = dirs .get_models_dir ()
326+ dest = os .path .join (models_dir , base_name )
327+
328+ # Create parent directories
329+ os .makedirs (os .path .dirname (dest ), exist_ok = True )
330+
331+ # Copy file or directory
332+ if os .path .isdir (src ):
333+ if os .path .exists (dest ):
334+ shutil .rmtree (dest )
335+ shutil .copytree (src , dest )
336+ else :
337+ shutil .copy2 (src , dest )
338+
339+ # Initialize model service for metadata and provenance creation
340+ model_service = ModelService (base_name )
341+
342+ # Create Model metadata so it appears in Model Zoo
343+ try :
344+ # Use provided architecture or detect it
345+ if architecture is None :
346+ architecture = model_service .detect_architecture (dest )
347+
348+ # Handle pipeline tag logic
349+ if pipeline_tag is None and parent_model is not None :
350+ # Try to fetch pipeline tag from parent model
351+ pipeline_tag = model_service .fetch_pipeline_tag (parent_model )
352+
353+ # Determine model_filename for single-file models
354+ model_filename = "" if os .path .isdir (dest ) else os .path .basename (dest )
355+
356+ # Prepare json_data with basic info
357+ json_data = {
358+ "job_id" : job_id ,
359+ "description" : f"Model generated by job { job_id } " ,
360+ }
361+
362+ # Add pipeline tag to json_data if provided
363+ if pipeline_tag is not None :
364+ json_data ["pipeline_tag" ] = pipeline_tag
365+
366+ # Use the Model class's generate_model_json method to create metadata
367+ model_service .generate_model_json (
368+ architecture = architecture ,
369+ model_filename = model_filename ,
370+ json_data = json_data
371+ )
372+ self .log (f"Model saved to Model Zoo as '{ base_name } '" )
373+ except Exception as e :
374+ self .log (f"Warning: Model saved but metadata creation failed: { str (e )} " )
375+ # Try to detect architecture for provenance even if metadata creation failed
376+ if architecture is None :
377+ try :
378+ architecture = model_service .detect_architecture (dest )
379+ except Exception :
380+ pass
381+
382+ # Create provenance data
383+ try :
384+ # Create MD5 checksums for all model files
385+ md5_objects = model_service .create_md5_checksums (dest )
386+
387+ # Prepare provenance metadata from job data
388+ job_data = self ._job .get_job_data ()
389+
390+ provenance_metadata = {
391+ "job_id" : job_id ,
392+ "model_name" : parent_model or job_data .get ("model_name" ),
393+ "model_architecture" : architecture ,
394+ "input_model" : parent_model ,
395+ "dataset" : job_data .get ("dataset" ),
396+ "adaptor_name" : job_data .get ("adaptor_name" , None ),
397+ "parameters" : job_data .get ("_config" , {}),
398+ "start_time" : job_data .get ("start_time" , time .strftime ("%Y-%m-%d %H:%M:%S" , time .gmtime ())),
399+ "end_time" : time .strftime ("%Y-%m-%d %H:%M:%S" , time .gmtime ()),
400+ "md5_checksums" : md5_objects ,
401+ }
402+
403+ # Create the _tlab_provenance.json file
404+ provenance_file = model_service .create_provenance_file (
405+ model_path = dest ,
406+ model_name = base_name ,
407+ model_architecture = architecture ,
408+ md5_objects = md5_objects ,
409+ provenance_data = provenance_metadata
410+ )
411+ self .log (f"Provenance file created at: { provenance_file } " )
412+ except Exception as e :
413+ self .log (f"Warning: Model saved but provenance creation failed: { str (e )} " )
414+
415+ # Track in job_data
416+ try :
417+ job_data = self ._job .get_job_data ()
418+ model_list = []
419+ if isinstance (job_data , dict ):
420+ existing = job_data .get ("models" , [])
421+ if isinstance (existing , list ):
422+ model_list = existing
423+ model_list .append (dest )
424+ self ._job .update_job_data_field ("models" , model_list )
425+ except Exception :
426+ pass
427+
428+ return dest
429+
281430 # Handle file path input (original behavior)
282431 if not isinstance (source_path , str ) or source_path .strip () == "" :
283432 raise ValueError ("source_path must be a non-empty string" )
@@ -472,6 +621,9 @@ def save_model(self, source_path: str, name: Optional[str] = None, architecture:
472621 Save a model file or directory to the workspace models directory.
473622 The model will automatically appear in the Model Zoo's Local Models list.
474623
624+ This method is a convenience wrapper around save_artifact with type="model".
625+ For new code, consider using save_artifact directly with type="model".
626+
475627 Args:
476628 source_path: Path to the model file or directory to save
477629 name: Optional name for the model. If not provided, uses source basename.
@@ -485,120 +637,22 @@ def save_model(self, source_path: str, name: Optional[str] = None, architecture:
485637 Returns:
486638 The destination path on disk.
487639 """
488- self ._ensure_initialized ()
489- if not isinstance (source_path , str ) or source_path .strip () == "" :
490- raise ValueError ("source_path must be a non-empty string" )
491- src = os .path .abspath (source_path )
492- if not os .path .exists (src ):
493- raise FileNotFoundError (f"Model source does not exist: { src } " )
494-
495- job_id = self ._job .id # type: ignore[union-attr]
496-
497- # Determine base name with job_id prefix for uniqueness
498- if isinstance (name , str ) and name .strip () != "" :
499- base_name = f"{ job_id } _{ name } "
500- else :
501- base_name = f"{ job_id } _{ os .path .basename (src )} "
502-
503- # Save to main workspace models directory for Model Zoo visibility
504- models_dir = dirs .get_models_dir ()
505- dest = os .path .join (models_dir , base_name )
640+ # Build config dict from parameters
641+ config = {}
642+ if architecture is not None :
643+ config ["architecture" ] = architecture
644+ if pipeline_tag is not None :
645+ config ["pipeline_tag" ] = pipeline_tag
646+ if parent_model is not None :
647+ config ["parent_model" ] = parent_model
506648
507- # Create parent directories
508- os .makedirs (os .path .dirname (dest ), exist_ok = True )
509-
510- # Copy file or directory
511- if os .path .isdir (src ):
512- if os .path .exists (dest ):
513- shutil .rmtree (dest )
514- shutil .copytree (src , dest )
515- else :
516- shutil .copy2 (src , dest )
517-
518- # Create Model metadata so it appears in Model Zoo
519- try :
520- model_service = ModelService (base_name )
521-
522- # Use provided architecture or detect it
523- if architecture is None :
524- architecture = model_service .detect_architecture (dest )
525-
526- # Handle pipeline tag logic
527- if pipeline_tag is None and parent_model is not None :
528- # Try to fetch pipeline tag from parent model
529- pipeline_tag = model_service .fetch_pipeline_tag (parent_model )
530- # Determine model_filename for single-file models
531- model_filename = "" if os .path .isdir (dest ) else os .path .basename (dest )
532-
533- # Prepare json_data with basic info
534- json_data = {
535- "job_id" : job_id ,
536- "description" : f"Model generated by job { job_id } " ,
537- }
538-
539- # Add pipeline tag to json_data if provided
540- if pipeline_tag is not None :
541- json_data ["pipeline_tag" ] = pipeline_tag
542-
543- # Use the Model class's generate_model_json method to create metadata
544- model_service .generate_model_json (
545- architecture = architecture ,
546- model_filename = model_filename ,
547- json_data = json_data
548- )
549- self .log (f"Model saved to Model Zoo as '{ base_name } '" )
550- except Exception as e :
551- self .log (f"Warning: Model saved but metadata creation failed: { str (e )} " )
552-
553- # Create provenance data
554- try :
555- # Create MD5 checksums for all model files
556- md5_objects = model_service .create_md5_checksums (dest )
557-
558- # Prepare provenance metadata from job data
559- job_data = self ._job .get_job_data ()
560-
561- provenance_metadata = {
562- "job_id" : job_id ,
563- "model_name" : parent_model or job_data .get ("model_name" ),
564- "model_architecture" : architecture ,
565- "input_model" : parent_model ,
566- "dataset" : job_data .get ("dataset" ),
567- "adaptor_name" : job_data .get ("adaptor_name" , None ),
568- "parameters" : job_data .get ("_config" , {}),
569- "start_time" : job_data .get ("start_time" , time .strftime ("%Y-%m-%d %H:%M:%S" , time .gmtime ())),
570- "end_time" : time .strftime ("%Y-%m-%d %H:%M:%S" , time .gmtime ()),
571- "md5_checksums" : md5_objects ,
572-
573-
574- }
575-
576- # Create the _tlab_provenance.json file
577- provenance_file = model_service .create_provenance_file (
578- model_path = dest ,
579- model_name = base_name ,
580- model_architecture = architecture ,
581- md5_objects = md5_objects ,
582- provenance_data = provenance_metadata
583- )
584- self .log (f"Provenance file created at: { provenance_file } " )
585- except Exception as e :
586- self .log (f"Warning: Model saved but provenance creation failed: { str (e )} " )
587-
588- # Track in job_data
589- try :
590- job_data = self ._job .get_job_data ()
591- model_list = []
592- if isinstance (job_data , dict ):
593- existing = job_data .get ("models" , [])
594- if isinstance (existing , list ):
595- model_list = existing
596- model_list .append (dest )
597- self ._job .update_job_data_field ("models" , model_list )
598- except Exception :
599- pass
600-
601- return dest
649+ # Use save_artifact with type="model"
650+ return self .save_artifact (
651+ source_path = source_path ,
652+ name = name ,
653+ type = "model" ,
654+ config = config if config else None
655+ )
602656
603657 def error (
604658 self ,
0 commit comments