@@ -441,16 +441,18 @@ def deploy(self, model_type: str, model_path: str) -> None:
441441 model_path (str): File path to model weights to be uploaded
442442 """
443443
444- supported_models = ["yolov8" , " yolov5" , "yolov7-seg" ]
444+ supported_models = ["yolov5" , "yolov7-seg" , "yolov8 " ]
445445
446- if model_type not in supported_models :
446+ if not any (
447+ supported_model in model_type for supported_model in supported_models
448+ ):
447449 raise (
448450 ValueError (
449451 f"Model type { model_type } not supported. Supported models are { supported_models } "
450452 )
451453 )
452454
453- if model_type == "yolov8" :
455+ if "yolov8" in model_type :
454456 try :
455457 import torch
456458 import ultralytics
@@ -464,7 +466,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
464466 [("ultralytics" , "<=" , "8.0.20" )]
465467 )
466468
467- elif model_type in [ "yolov5" , "yolov7-seg" ] :
469+ elif "yolov5" in model_type or "yolov7" in model_type :
468470 try :
469471 import torch
470472 except ImportError as e :
@@ -483,16 +485,22 @@ def deploy(self, model_type: str, model_path: str) -> None:
483485 class_names .sort (key = lambda x : x [0 ])
484486 class_names = [x [1 ] for x in class_names ]
485487
486- if model_type == "yolov8" :
488+ if "yolov8" in model_type :
487489 # try except for backwards compatibility with older versions of ultralytics
490+ if "-cls" in model_type :
491+ nc = model ["model" ].yaml ["nc" ]
492+ args = model ["train_args" ]
493+ else :
494+ nc = model ["model" ].nc
495+ args = model ["model" ].args
488496 try :
489497 model_artifacts = {
490498 "names" : class_names ,
491499 "yaml" : model ["model" ].yaml ,
492- "nc" : model [ "model" ]. nc ,
500+ "nc" : nc ,
493501 "args" : {
494502 k : val
495- for k , val in model [ "model" ]. args .items ()
503+ for k , val in args .items ()
496504 if ((k == "model" ) or (k == "imgsz" ) or (k == "batch" ))
497505 },
498506 "ultralytics_version" : ultralytics .__version__ ,
@@ -502,33 +510,39 @@ def deploy(self, model_type: str, model_path: str) -> None:
502510 model_artifacts = {
503511 "names" : class_names ,
504512 "yaml" : model ["model" ].yaml ,
505- "nc" : model [ "model" ]. nc ,
513+ "nc" : nc ,
506514 "args" : {
507515 k : val
508- for k , val in model [ "model" ]. args .__dict__ .items ()
516+ for k , val in args .__dict__ .items ()
509517 if ((k == "model" ) or (k == "imgsz" ) or (k == "batch" ))
510518 },
511519 "ultralytics_version" : ultralytics .__version__ ,
512520 "model_type" : model_type ,
513521 }
514- elif model_type in [ "yolov5" , "yolov7-seg" ] :
522+ elif "yolov5" in model_type or "yolov7" in model_type :
515523 # parse from yaml for yolov5
516524
517525 with open (os .path .join (model_path , "opt.yaml" ), "r" ) as stream :
518526 opts = yaml .safe_load (stream )
519527
520528 model_artifacts = {
521529 "names" : class_names ,
522- "yaml" : model ["model" ].yaml ,
523530 "nc" : model ["model" ].nc ,
524- "args" : {"imgsz" : opts ["imgsz" ], "batch" : opts ["batch_size" ]},
531+ "args" : {
532+ "imgsz" : opts ["imgsz" ] if "imgsz" in opts else opts ["img_size" ],
533+ "batch" : opts ["batch_size" ],
534+ },
525535 "model_type" : model_type ,
526536 }
537+ if hasattr (model ["model" ], "yaml" ):
538+ model_artifacts ["yaml" ] = model ["model" ].yaml
527539
528- with open (model_path + "model_artifacts.json" , "w" ) as fp :
540+ with open (os . path . join ( model_path , "model_artifacts.json" ) , "w" ) as fp :
529541 json .dump (model_artifacts , fp )
530542
531- torch .save (model ["model" ].state_dict (), model_path + "state_dict.pt" )
543+ torch .save (
544+ model ["model" ].state_dict (), os .path .join (model_path , "state_dict.pt" )
545+ )
532546
533547 lista_files = [
534548 "results.csv" ,
@@ -537,11 +551,13 @@ def deploy(self, model_type: str, model_path: str) -> None:
537551 "state_dict.pt" ,
538552 ]
539553
540- with zipfile .ZipFile (model_path + "roboflow_deploy.zip" , "w" ) as zipMe :
554+ with zipfile .ZipFile (
555+ os .path .join (model_path , "roboflow_deploy.zip" ), "w"
556+ ) as zipMe :
541557 for file in lista_files :
542- if os .path .exists (model_path + file ):
558+ if os .path .exists (os . path . join ( model_path , file ) ):
543559 zipMe .write (
544- model_path + file ,
560+ os . path . join ( model_path , file ) ,
545561 arcname = file ,
546562 compress_type = zipfile .ZIP_DEFLATED ,
547563 )
@@ -554,7 +570,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
554570 )
555571
556572 res = requests .get (
557- f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } /uploadModel?api_key={ self .__api_key } "
573+ f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } /uploadModel?api_key={ self .__api_key } &modelType= { model_type } "
558574 )
559575 try :
560576 if res .status_code == 429 :
@@ -569,7 +585,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
569585
570586 res = requests .put (
571587 res .json ()["url" ],
572- data = open (os .path .join (model_path + "roboflow_deploy.zip" ), "rb" ),
588+ data = open (os .path .join (model_path , "roboflow_deploy.zip" ), "rb" ),
573589 )
574590 try :
575591 res .raise_for_status ()
0 commit comments