2929from roboflow .models .object_detection import ObjectDetectionModel
3030from roboflow .models .semantic_segmentation import SemanticSegmentationModel
3131from roboflow .util .general import write_line
32+ from roboflow .util .annotations import amend_data_yaml
3233from roboflow .util .versions import (
34+ get_wrong_dependencies_versions ,
3335 print_warn_for_wrong_dependencies_versions ,
34- warn_for_wrong_dependencies_versions ,
3536)
3637
3738load_dotenv ()
@@ -440,7 +441,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
440441 model_path (str): File path to model weights to be uploaded
441442 """
442443
443- supported_models = ["yolov8" , "yolov5" ]
444+ supported_models = ["yolov8" , "yolov5" , "yolov7-seg" ]
444445
445446 if model_type not in supported_models :
446447 raise (
@@ -463,7 +464,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
463464 [("ultralytics" , "<=" , "8.0.20" )]
464465 )
465466
466- elif model_type == "yolov5" :
467+ elif model_type in [ "yolov5" , "yolov7-seg" ] :
467468 try :
468469 import torch
469470 except ImportError as e :
@@ -510,7 +511,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
510511 "ultralytics_version" : ultralytics .__version__ ,
511512 "model_type" : model_type ,
512513 }
513- elif model_type == "yolov5" :
514+ elif model_type in [ "yolov5" , "yolov7-seg" ] :
514515 # parse from yaml for yolov5
515516
516517 with open (os .path .join (model_path , "opt.yaml" ), "r" ) as stream :
@@ -538,11 +539,19 @@ def deploy(self, model_type: str, model_path: str) -> None:
538539
539540 with zipfile .ZipFile (model_path + "roboflow_deploy.zip" , "w" ) as zipMe :
540541 for file in lista_files :
541- zipMe .write (
542- model_path + file ,
543- arcname = file ,
544- compress_type = zipfile .ZIP_DEFLATED ,
545- )
542+ if os .path .exists (model_path + file ):
543+ zipMe .write (
544+ model_path + file ,
545+ arcname = file ,
546+ compress_type = zipfile .ZIP_DEFLATED ,
547+ )
548+ else :
549+ if file in ["model_artifacts.json" , "state_dict.pt" ]:
550+ raise (
551+ ValueError (
552+ f"File { file } not found. Please make sure to provide a valid model path."
553+ )
554+ )
546555
547556 res = requests .get (
548557 f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } /uploadModel?api_key={ self .__api_key } "
@@ -681,7 +690,7 @@ def __get_format_identifier(self, format):
681690 friendly_formats = {"yolov5" : "yolov5pytorch" , "yolov7" : "yolov7pytorch" }
682691 return friendly_formats .get (format , format )
683692
684- def __reformat_yaml (self , location , format ):
693+ def __reformat_yaml (self , location : str , format : str ):
685694 """
686695 Certain formats seem to require reformatting the downloaded YAML.
687696 It'd be nice if the API did this, but we're doing it in python for now.
@@ -691,28 +700,30 @@ def __reformat_yaml(self, location, format):
691700
692701 :return None:
693702 """
694- if format in ["yolov5pytorch" , "yolov7pytorch" , "yolov8" ]:
695- with open (location + "/data.yaml" ) as file :
696- new_yaml = yaml .safe_load (file )
697- new_yaml ["train" ] = location + new_yaml ["train" ].lstrip (".." )
698- new_yaml ["val" ] = location + new_yaml ["val" ].lstrip (".." )
699-
700- os .remove (location + "/data.yaml" )
701-
702- with open (location + "/data.yaml" , "w" ) as outfile :
703- yaml .dump (new_yaml , outfile )
704-
705- if format == "mt-yolov6" :
706- with open (location + "/data.yaml" ) as file :
707- new_yaml = yaml .safe_load (file )
708- new_yaml ["train" ] = location + new_yaml ["train" ].lstrip ("." )
709- new_yaml ["val" ] = location + new_yaml ["val" ].lstrip ("." )
710- new_yaml ["test" ] = location + new_yaml ["test" ].lstrip ("." )
711-
712- os .remove (location + "/data.yaml" )
703+ data_path = os .path .join (location , "data.yaml" )
704+
705+ def data_yaml_callback (content : dict ) -> dict :
706+ if format == "mt-yolov6" :
707+ content ["train" ] = location + content ["train" ].lstrip ("." )
708+ content ["val" ] = location + content ["val" ].lstrip ("." )
709+ content ["test" ] = location + content ["test" ].lstrip ("." )
710+ if format in ["yolov5pytorch" , "yolov7pytorch" , "yolov8" ]:
711+ content ["train" ] = location + content ["train" ].lstrip (".." )
712+ content ["val" ] = location + content ["val" ].lstrip (".." )
713+ try :
714+ # get_wrong_dependencies_versions raises exception if ultralytics is not installed at all
715+ if format == "yolov8" and not get_wrong_dependencies_versions (
716+ dependencies_versions = [("ultralytics" , ">=" , "8.0.30" )]
717+ ):
718+ content ["train" ] = "train/images"
719+ content ["val" ] = "valid/images"
720+ content ["test" ] = "test/images"
721+ except ModuleNotFoundError :
722+ pass
723+ return content
713724
714- with open ( location + "/data.yaml " , "w" ) as outfile :
715- yaml . dump ( new_yaml , outfile )
725+ if format in [ "yolov5pytorch" , "mt-yolov6 " , "yolov7pytorch" , "yolov8" ] :
726+ amend_data_yaml ( path = data_path , callback = data_yaml_callback )
716727
717728 def __str__ (self ):
718729 """string representation of version object."""
0 commit comments