3535)
3636from transformers .models .gpt2 .tokenization_gpt2_fast import GPT2TokenizerFast
3737from typing_extensions import Annotated
38- from zenml import ArtifactConfig , log_model_metadata , save_artifact , step
38+ from zenml import ArtifactConfig , log_metadata , save_artifact , step
3939from zenml .client import Client
4040from zenml .enums import ArtifactType
4141
@@ -524,8 +524,9 @@ def run_training(args: Configuration, train_data, val_data, hf_token):
524524 try :
525525 if args .push_to_hub :
526526 commit_info = trainer .push_to_hub ()
527- log_model_metadata (
528- metadata = {"trainer_commit_info" : str (commit_info )}
527+ log_metadata (
528+ metadata = {"trainer_commit_info" : str (commit_info )},
529+ infer_model = True
529530 )
530531 else :
531532 trainer .save_model (args .output_dir )
@@ -535,8 +536,9 @@ def run_training(args: Configuration, train_data, val_data, hf_token):
535536 commit_info = trainer .model .push_to_hub (
536537 repo_id = args .output_peft_repo_id , token = hf_token
537538 )
538- log_model_metadata (
539- metadata = {"model_commit_info" : str (commit_info )}
539+ log_metadata (
540+ metadata = {"model_commit_info" : str (commit_info )},
541+ infer_model = True
540542 )
541543 except Exception as e :
542544 print ("Exception while pushing or saving" )
@@ -580,11 +582,15 @@ def merge_and_push(
580582
581583 model_id_merged = f"{ peft_model_id } -merged"
582584 commit_info = tokenizer .push_to_hub (model_id_merged , token = hf_token )
583- log_model_metadata (
584- metadata = {"merged_tokenizer_commit_info" : str (commit_info )}
585+ log_metadata (
586+ metadata = {"merged_tokenizer_commit_info" : str (commit_info )},
587+ infer_model = True
585588 )
586589 commit_info = final_model .push_to_hub (model_id_merged , token = hf_token )
587- log_model_metadata (metadata = {"merged_model_commit_info" : str (commit_info )})
590+ log_metadata (
591+ metadata = {"merged_model_commit_info" : str (commit_info )},
592+ infer_model = True
593+ )
588594
589595
590596@step (output_materializers = {"trainer_obj" : HFTrainerMaterializer })
0 commit comments