2626from roboflow .models .instance_segmentation import InstanceSegmentationModel
2727from roboflow .models .object_detection import ObjectDetectionModel
2828from roboflow .models .semantic_segmentation import SemanticSegmentationModel
29+ from roboflow .util .annotations import amend_data_yaml
2930from roboflow .util .versions import (
31+ get_wrong_dependencies_versions ,
3032 print_warn_for_wrong_dependencies_versions ,
31- warn_for_wrong_dependencies_versions ,
3233)
3334
3435load_dotenv ()
@@ -549,7 +550,7 @@ def __get_format_identifier(self, format):
549550 friendly_formats = {"yolov5" : "yolov5pytorch" , "yolov7" : "yolov7pytorch" }
550551 return friendly_formats .get (format , format )
551552
552- def __reformat_yaml (self , location , format ):
553+ def __reformat_yaml (self , location : str , format : str ):
553554 """
554555 Certain formats seem to require reformatting the downloaded YAML.
555556 It'd be nice if the API did this, but we're doing it in python for now.
@@ -559,28 +560,29 @@ def __reformat_yaml(self, location, format):
559560
560561 :return None:
561562 """
562- if format in ["yolov5pytorch" , "yolov7pytorch" , "yolov8" ]:
563- with open (location + "/data.yaml" ) as file :
564- new_yaml = yaml .safe_load (file )
565- new_yaml ["train" ] = location + new_yaml ["train" ].lstrip (".." )
566- new_yaml ["val" ] = location + new_yaml ["val" ].lstrip (".." )
567-
568- os .remove (location + "/data.yaml" )
569-
570- with open (location + "/data.yaml" , "w" ) as outfile :
571- yaml .dump (new_yaml , outfile )
572-
573- if format == "mt-yolov6" :
574- with open (location + "/data.yaml" ) as file :
575- new_yaml = yaml .safe_load (file )
576- new_yaml ["train" ] = location + new_yaml ["train" ].lstrip ("." )
577- new_yaml ["val" ] = location + new_yaml ["val" ].lstrip ("." )
578- new_yaml ["test" ] = location + new_yaml ["test" ].lstrip ("." )
579-
580- os .remove (location + "/data.yaml" )
563+ data_path = os .path .join (location , "data.yaml" )
564+
565+ def callback (content : dict ) -> dict :
566+ if format == "mt-yolov6" :
567+ content ["train" ] = location + content ["train" ].lstrip ("." )
568+ content ["val" ] = location + content ["val" ].lstrip ("." )
569+ content ["test" ] = location + content ["test" ].lstrip ("." )
570+ if format in ["yolov5pytorch" , "yolov7pytorch" , "yolov8" ]:
571+ content ["train" ] = location + content ["train" ].lstrip (".." )
572+ content ["val" ] = location + content ["val" ].lstrip (".." )
573+ try :
574+ # get_wrong_dependencies_versions raises exception if ultralytics is not installed at all
575+ if not get_wrong_dependencies_versions (
576+ dependencies_versions = [("ultralytics" , ">=" , "8.0.30" )]
577+ ):
578+ content ["train" ] = "train/images"
579+ content ["val" ] = "valid/images"
580+ content ["test" ] = "test/images"
581+ except ModuleNotFoundError :
582+ pass
583+ return content
581584
582- with open (location + "/data.yaml" , "w" ) as outfile :
583- yaml .dump (new_yaml , outfile )
585+ amend_data_yaml (path = data_path , callback = callback )
584586
585587 def __str__ (self ):
586588 """string representation of version object."""
0 commit comments