2626from roboflow .models .instance_segmentation import InstanceSegmentationModel
2727from roboflow .models .object_detection import ObjectDetectionModel
2828from roboflow .models .semantic_segmentation import SemanticSegmentationModel
29+ from roboflow .util .annotations import amend_data_yaml
2930from roboflow .util .versions import (
31+ get_wrong_dependencies_versions ,
3032 print_warn_for_wrong_dependencies_versions ,
31- warn_for_wrong_dependencies_versions ,
3233)
3334
3435load_dotenv ()
@@ -311,7 +312,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
311312 model_path (str): File path to model weights to be uploaded
312313 """
313314
314- supported_models = ["yolov8" , "yolov5" ]
315+ supported_models = ["yolov8" , "yolov5" , "yolov7-seg" ]
315316
316317 if model_type not in supported_models :
317318 raise (
@@ -334,7 +335,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
334335 [("ultralytics" , "<=" , "8.0.20" )]
335336 )
336337
337- elif model_type == "yolov5" :
338+ elif model_type in [ "yolov5" , "yolov7-seg" ] :
338339 try :
339340 import torch
340341 except ImportError as e :
@@ -378,7 +379,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
378379 "ultralytics_version" : ultralytics .__version__ ,
379380 "model_type" : model_type ,
380381 }
381- elif model_type == "yolov5" :
382+ elif model_type in [ "yolov5" , "yolov7-seg" ] :
382383 # parse from yaml for yolov5
383384
384385 with open (os .path .join (model_path , "opt.yaml" ), "r" ) as stream :
@@ -406,11 +407,19 @@ def deploy(self, model_type: str, model_path: str) -> None:
406407
407408 with zipfile .ZipFile (model_path + "roboflow_deploy.zip" , "w" ) as zipMe :
408409 for file in lista_files :
409- zipMe .write (
410- model_path + file ,
411- arcname = file ,
412- compress_type = zipfile .ZIP_DEFLATED ,
413- )
410+ if os .path .exists (model_path + file ):
411+ zipMe .write (
412+ model_path + file ,
413+ arcname = file ,
414+ compress_type = zipfile .ZIP_DEFLATED ,
415+ )
416+ else :
417+ if file in ["model_artifacts.json" , "state_dict.pt" ]:
418+ raise (
419+ ValueError (
420+ f"File { file } not found. Please make sure to provide a valid model path."
421+ )
422+ )
414423
415424 res = requests .get (
416425 f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } /uploadModel?api_key={ self .__api_key } "
@@ -549,7 +558,7 @@ def __get_format_identifier(self, format):
549558 friendly_formats = {"yolov5" : "yolov5pytorch" , "yolov7" : "yolov7pytorch" }
550559 return friendly_formats .get (format , format )
551560
552- def __reformat_yaml (self , location , format ):
561+ def __reformat_yaml (self , location : str , format : str ):
553562 """
554563 Certain formats seem to require reformatting the downloaded YAML.
555564 It'd be nice if the API did this, but we're doing it in python for now.
@@ -559,28 +568,29 @@ def __reformat_yaml(self, location, format):
559568
560569 :return None:
561570 """
562- if format in ["yolov5pytorch" , "yolov7pytorch" , "yolov8" ]:
563- with open (location + "/data.yaml" ) as file :
564- new_yaml = yaml .safe_load (file )
565- new_yaml ["train" ] = location + new_yaml ["train" ].lstrip (".." )
566- new_yaml ["val" ] = location + new_yaml ["val" ].lstrip (".." )
567-
568- os .remove (location + "/data.yaml" )
569-
570- with open (location + "/data.yaml" , "w" ) as outfile :
571- yaml .dump (new_yaml , outfile )
572-
573- if format == "mt-yolov6" :
574- with open (location + "/data.yaml" ) as file :
575- new_yaml = yaml .safe_load (file )
576- new_yaml ["train" ] = location + new_yaml ["train" ].lstrip ("." )
577- new_yaml ["val" ] = location + new_yaml ["val" ].lstrip ("." )
578- new_yaml ["test" ] = location + new_yaml ["test" ].lstrip ("." )
579-
580- os .remove (location + "/data.yaml" )
571+ data_path = os .path .join (location , "data.yaml" )
572+
573+ def callback (content : dict ) -> dict :
574+ if format == "mt-yolov6" :
575+ content ["train" ] = location + content ["train" ].lstrip ("." )
576+ content ["val" ] = location + content ["val" ].lstrip ("." )
577+ content ["test" ] = location + content ["test" ].lstrip ("." )
578+ if format in ["yolov5pytorch" , "yolov7pytorch" , "yolov8" ]:
579+ content ["train" ] = location + content ["train" ].lstrip (".." )
580+ content ["val" ] = location + content ["val" ].lstrip (".." )
581+ try :
582+ # get_wrong_dependencies_versions raises exception if ultralytics is not installed at all
583+ if not get_wrong_dependencies_versions (
584+ dependencies_versions = [("ultralytics" , ">=" , "8.0.30" )]
585+ ):
586+ content ["train" ] = "train/images"
587+ content ["val" ] = "valid/images"
588+ content ["test" ] = "test/images"
589+ except ModuleNotFoundError :
590+ pass
591+ return content
581592
582- with open (location + "/data.yaml" , "w" ) as outfile :
583- yaml .dump (new_yaml , outfile )
593+ amend_data_yaml (path = data_path , callback = callback )
584594
585595 def __str__ (self ):
586596 """string representation of version object."""
0 commit comments