Skip to content

Commit c56f6a4

Browse files
authored
Merge pull request #113 from roboflow/fix/make-roboflow-datasets-work-with-yolov8-codebase-once-again
🧪 ready for colab tests
2 parents 9246b1a + a3c42c3 commit c56f6a4

File tree

4 files changed

+40
-25
lines changed

4 files changed

+40
-25
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from roboflow.core.project import Project
99
from roboflow.core.workspace import Workspace
1010

11-
__version__ = "0.2.29"
11+
__version__ = "0.2.30"
1212

1313

1414
def check_key(api_key, model, notebook, num_retries=0):

roboflow/core/version.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
from roboflow.models.instance_segmentation import InstanceSegmentationModel
2727
from roboflow.models.object_detection import ObjectDetectionModel
2828
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
29+
from roboflow.util.annotations import amend_data_yaml
2930
from roboflow.util.versions import (
31+
get_wrong_dependencies_versions,
3032
print_warn_for_wrong_dependencies_versions,
31-
warn_for_wrong_dependencies_versions,
3233
)
3334

3435
load_dotenv()
@@ -549,7 +550,7 @@ def __get_format_identifier(self, format):
549550
friendly_formats = {"yolov5": "yolov5pytorch", "yolov7": "yolov7pytorch"}
550551
return friendly_formats.get(format, format)
551552

552-
def __reformat_yaml(self, location, format):
553+
def __reformat_yaml(self, location: str, format: str):
553554
"""
554555
Certain formats seem to require reformatting the downloaded YAML.
555556
It'd be nice if the API did this, but we're doing it in python for now.
@@ -559,28 +560,29 @@ def __reformat_yaml(self, location, format):
559560
560561
:return None:
561562
"""
562-
if format in ["yolov5pytorch", "yolov7pytorch", "yolov8"]:
563-
with open(location + "/data.yaml") as file:
564-
new_yaml = yaml.safe_load(file)
565-
new_yaml["train"] = location + new_yaml["train"].lstrip("..")
566-
new_yaml["val"] = location + new_yaml["val"].lstrip("..")
567-
568-
os.remove(location + "/data.yaml")
569-
570-
with open(location + "/data.yaml", "w") as outfile:
571-
yaml.dump(new_yaml, outfile)
572-
573-
if format == "mt-yolov6":
574-
with open(location + "/data.yaml") as file:
575-
new_yaml = yaml.safe_load(file)
576-
new_yaml["train"] = location + new_yaml["train"].lstrip(".")
577-
new_yaml["val"] = location + new_yaml["val"].lstrip(".")
578-
new_yaml["test"] = location + new_yaml["test"].lstrip(".")
579-
580-
os.remove(location + "/data.yaml")
563+
data_path = os.path.join(location, "data.yaml")
564+
565+
def callback(content: dict) -> dict:
566+
if format == "mt-yolov6":
567+
content["train"] = location + content["train"].lstrip(".")
568+
content["val"] = location + content["val"].lstrip(".")
569+
content["test"] = location + content["test"].lstrip(".")
570+
if format in ["yolov5pytorch", "yolov7pytorch", "yolov8"]:
571+
content["train"] = location + content["train"].lstrip("..")
572+
content["val"] = location + content["val"].lstrip("..")
573+
try:
574+
# get_wrong_dependencies_versions raises exception if ultralytics is not installed at all
575+
if not get_wrong_dependencies_versions(
576+
dependencies_versions=[("ultralytics", ">=", "8.0.30")]
577+
):
578+
content["train"] = "train/images"
579+
content["val"] = "valid/images"
580+
content["test"] = "test/images"
581+
except ModuleNotFoundError:
582+
pass
583+
return content
581584

582-
with open(location + "/data.yaml", "w") as outfile:
583-
yaml.dump(new_yaml, outfile)
585+
amend_data_yaml(path=data_path, callback=callback)
584586

585587
def __str__(self):
586588
"""string representation of version object."""

roboflow/util/annotations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import os
2+
from typing import Callable
3+
4+
import yaml
5+
6+
7+
def amend_data_yaml(path: str, callback: Callable[[dict], dict]):
8+
with open(path) as source:
9+
content = yaml.safe_load(source)
10+
content = callback(content)
11+
os.remove(path)
12+
with open(path, "w") as target:
13+
yaml.dump(content, target)

roboflow/util/versions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def get_wrong_dependencies_versions(
88
dependencies_versions: List[Tuple[str, str, str]]
99
) -> List[Tuple[str, str, str, str]]:
1010
"""
11-
Get a list of missmatching dependencies with current version installed.
11+
Get a list of mismatching dependencies with current version installed.
1212
E.g., assuming we pass `get_wrong_dependencies_versions([("torch", "==", "1.2.0")]), we will check if the current version of `torch` is `==1.2.0`. If not, we will return `[("torch", "==", "1.2.0", "<current_installed_version>")]
1313
1414
We support `<=`, `==`, `>=`

0 commit comments

Comments
 (0)