@@ -419,11 +419,13 @@ def live_plot(epochs, mAP, loss, title=""):
419419 return self.model
420420
421421 # @warn_for_wrong_dependencies_versions([("ultralytics", "==", "8.0.196")])
422- def deploy(self, model_type: str, model_path: str) -> None:
423- """Uploads provided weights file to Roboflow
422+ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> None:
423+ """Uploads provided weights file to Roboflow.
424424
425425 Args:
426- model_path (str): File path to model weights to be uploaded
426+ model_type (str): The type of the model to be deployed.
427+ model_path (str): File path to the model weights to be uploaded.
428+ filename (str, optional): The name of the weights file. Defaults to "weights/best.pt".
427429 """
428430
429431 supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9", "yolonas"]
@@ -432,7 +434,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
432434 raise (ValueError(f"Model type {model_type} not supported. Supported models are" f" {supported_models}"))
433435
434436 if "yolonas" in model_type:
435- self.deploy_yolonas(model_type, model_path)
437+ self.deploy_yolonas(model_type, model_path, filename)
436438 return
437439
438440 if "yolov8" in model_type:
@@ -457,7 +459,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
457459 " Please install it with `pip install torch`"
458460 )
459461
460- model = torch.load(os.path.join(model_path, "weights/best.pt"))
462+ model = torch.load(os.path.join(model_path, filename))
461463
462464 if isinstance(model["model"].names, list):
463465 class_names = model["model"].names
@@ -542,7 +544,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
542544
543545 self.upload_zip(model_type, model_path)
544546
545- def deploy_yolonas(self, model_type: str, model_path: str) -> None:
547+ def deploy_yolonas(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> None:
546548 try:
547549 import torch
548550 except ImportError:
@@ -551,7 +553,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
551553 " Please install it with `pip install torch`"
552554 )
553555
554- model = torch.load(os.path.join(model_path, "weights/best.pt"), map_location="cpu")
556+ model = torch.load(os.path.join(model_path, filename), map_location="cpu")
555557 class_names = model["processing_params"]["class_names"]
556558
557559 opt_path = os.path.join(model_path, "opt.yaml")
@@ -584,7 +586,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
584586 with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
585587 json.dump(model_artifacts, fp)
586588
587- shutil.copy(os.path.join(model_path, "weights/best.pt"), os.path.join(model_path, "state_dict.pt"))
589+ shutil.copy(os.path.join(model_path, filename), os.path.join(model_path, "state_dict.pt"))
588590
589591 list_files = [
590592 "results.json",
@@ -602,7 +604,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
602604 compress_type=zipfile.ZIP_DEFLATED,
603605 )
604606 else:
605- if file in ["model_artifacts.json", "best.pt"]:
607+ if file in ["model_artifacts.json", filename]:
606608 raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))
607609
608610 self.upload_zip(model_type, model_path)
0 commit comments