@@ -544,7 +544,7 @@ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best
544544
545545 self .upload_zip (model_type , model_path )
546546
547- def deploy_yolonas (self , model_type : str , model_path : str ) -> None :
547+ def deploy_yolonas (self , model_type : str , model_path : str , filename : str = "weights/best.pt" ) -> None :
548548 try :
549549 import torch
550550 except ImportError :
@@ -553,7 +553,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
553553 " Please install it with `pip install torch`"
554554 )
555555
556- model = torch .load (os .path .join (model_path , "weights/best.pt" ), map_location = "cpu" )
556+ model = torch .load (os .path .join (model_path , filename ), map_location = "cpu" )
557557 class_names = model ["processing_params" ]["class_names" ]
558558
559559 opt_path = os .path .join (model_path , "opt.yaml" )
@@ -586,7 +586,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
586586 with open (os .path .join (model_path , "model_artifacts.json" ), "w" ) as fp :
587587 json .dump (model_artifacts , fp )
588588
589- shutil .copy (os .path .join (model_path , "weights/best.pt" ), os .path .join (model_path , "state_dict.pt" ))
589+ shutil .copy (os .path .join (model_path , filename ), os .path .join (model_path , "state_dict.pt" ))
590590
591591 list_files = [
592592 "results.json" ,
@@ -604,7 +604,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
604604 compress_type = zipfile .ZIP_DEFLATED ,
605605 )
606606 else :
607- if file in ["model_artifacts.json" , "best.pt" ]:
607+ if file in ["model_artifacts.json" , filename ]:
608608 raise (ValueError (f"File { file } not found. Please make sure to provide a" " valid model path." ))
609609
610610 self .upload_zip (model_type , model_path )
0 commit comments