11import copy
22import json
33import os
4+ import shutil
45import sys
56import time
67import zipfile
@@ -425,11 +426,15 @@ def deploy(self, model_type: str, model_path: str) -> None:
425426 model_path (str): File path to model weights to be uploaded
426427 """
427428
428- supported_models = ["yolov5" , "yolov7-seg" , "yolov8" , "yolov9" ]
429+ supported_models = ["yolov5" , "yolov7-seg" , "yolov8" , "yolov9" , "yolonas" ]
429430
430431 if not any (supported_model in model_type for supported_model in supported_models ):
431432 raise (ValueError (f"Model type { model_type } not supported. Supported models are" f" { supported_models } " ))
432433
434+ if "yolonas" in model_type :
435+ self .deploy_yolonas (model_type , model_path )
436+ return
437+
433438 if "yolov8" in model_type :
434439 try :
435440 import torch
@@ -516,15 +521,15 @@ def deploy(self, model_type: str, model_path: str) -> None:
516521
517522 torch .save (model ["model" ].state_dict (), os .path .join (model_path , "state_dict.pt" ))
518523
519- lista_files = [
524+ list_files = [
520525 "results.csv" ,
521526 "results.png" ,
522527 "model_artifacts.json" ,
523528 "state_dict.pt" ,
524529 ]
525530
526531 with zipfile .ZipFile (os .path .join (model_path , "roboflow_deploy.zip" ), "w" ) as zipMe :
527- for file in lista_files :
532+ for file in list_files :
528533 if os .path .exists (os .path .join (model_path , file )):
529534 zipMe .write (
530535 os .path .join (model_path , file ),
@@ -535,6 +540,74 @@ def deploy(self, model_type: str, model_path: str) -> None:
535540 if file in ["model_artifacts.json" , "state_dict.pt" ]:
536541 raise (ValueError (f"File { file } not found. Please make sure to provide a" " valid model path." ))
537542
543+ self .upload_zip (model_type , model_path )
544+
545+ def deploy_yolonas (self , model_type : str , model_path : str ) -> None :
546+ try :
547+ import torch
548+ except ImportError :
549+ raise (
550+ "The torch python package is required to deploy yolonas models."
551+ " Please install it with `pip install torch`"
552+ )
553+
554+ model = torch .load (os .path .join (model_path , "weights/best.pt" ), map_location = "cpu" )
555+ class_names = model ["processing_params" ]["class_names" ]
556+
557+ opt_path = os .path .join (model_path , "opt.yaml" )
558+ if not os .path .exists (opt_path ):
559+ raise RuntimeError (
560+ f"You must create an opt.yaml file at { os .path .join (model_path , '' )} of the format:\n "
561+ f"imgsz: <resolution of model>\n "
562+ f"batch_size: <batch size of inference model>\n "
563+ f"architecture: <one of [yolo_nas_s, yolo_nas_m, yolo_nas_l]."
564+ f"s, m, l refer to small, medium, large architecture sizes, respectively>\n "
565+ )
566+ with open (os .path .join (model_path , "opt.yaml" ), "r" ) as stream :
567+ opts = yaml .safe_load (stream )
568+ required_keys = ["imgsz" , "batch_size" , "architecture" ]
569+ for key in required_keys :
570+ if key not in opts :
571+ raise RuntimeError (f"{ opt_path } lacks required key { key } . Required keys: { required_keys } " )
572+
573+ model_artifacts = {
574+ "names" : class_names ,
575+ "nc" : len (class_names ),
576+ "args" : {
577+ "imgsz" : opts ["imgsz" ] if "imgsz" in opts else opts ["img_size" ],
578+ "batch" : opts ["batch_size" ],
579+ "architecture" : opts ["architecture" ],
580+ },
581+ "model_type" : model_type ,
582+ }
583+
584+ with open (os .path .join (model_path , "model_artifacts.json" ), "w" ) as fp :
585+ json .dump (model_artifacts , fp )
586+
587+ shutil .copy (os .path .join (model_path , "weights/best.pt" ), os .path .join (model_path , "state_dict.pt" ))
588+
589+ list_files = [
590+ "results.json" ,
591+ "results.png" ,
592+ "model_artifacts.json" ,
593+ "state_dict.pt" ,
594+ ]
595+
596+ with zipfile .ZipFile (os .path .join (model_path , "roboflow_deploy.zip" ), "w" ) as zipMe :
597+ for file in list_files :
598+ if os .path .exists (os .path .join (model_path , file )):
599+ zipMe .write (
600+ os .path .join (model_path , file ),
601+ arcname = file ,
602+ compress_type = zipfile .ZIP_DEFLATED ,
603+ )
604+ else :
605+ if file in ["model_artifacts.json" , "best.pt" ]:
606+ raise (ValueError (f"File { file } not found. Please make sure to provide a" " valid model path." ))
607+
608+ self .upload_zip (model_type , model_path )
609+
610+ def upload_zip (self , model_type : str , model_path : str ):
538611 res = requests .get (
539612 f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } "
540613 f"/uploadModel?api_key={ self .__api_key } &modelType={ model_type } &nocache=true"
0 commit comments