Skip to content

Commit 4055a0e

Browse files
authored
Merge pull request #407 from roboflow/rodrigo/fix-cli-format-export
Fixing format export for CLI/python sdk
2 parents f561f05 + f1a4dc3 commit 4055a0e

File tree

4 files changed

+61
-17
lines changed

4 files changed

+61
-17
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from roboflow.models import CLIPModel, GazeModel # noqa: F401
1616
from roboflow.util.general import write_line
1717

18-
__version__ = "1.2.5"
18+
__version__ = "1.2.6"
1919

2020

2121
def check_key(api_key, model, notebook, num_retries=0):

roboflow/core/version.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from roboflow.util.annotations import amend_data_yaml
3535
from roboflow.util.general import write_line
3636
from roboflow.util.model_processor import process
37-
from roboflow.util.versions import get_wrong_dependencies_versions, normalize_yolo_model_type
37+
from roboflow.util.versions import get_model_format, get_wrong_dependencies_versions, normalize_yolo_model_type
3838

3939
if TYPE_CHECKING:
4040
import numpy as np
@@ -244,7 +244,7 @@ def download(self, model_format=None, location=None, overwrite: bool = False):
244244

245245
return Dataset(self.name, self.version, model_format, os.path.abspath(location))
246246

247-
def export(self, model_format=None):
247+
def export(self, model_format=None) -> bool | None:
248248
"""
249249
Ask the Roboflow API to generate a version's dataset in a given format so that it can be downloaded via the `download()` method.
250250
@@ -254,7 +254,7 @@ def export(self, model_format=None):
254254
model_format (str): A format to use for downloading
255255
256256
Returns:
257-
True
257+
True if the export was successful, RuntimeError if the export failed
258258
259259
Raises:
260260
RuntimeError: If the Roboflow API returns an error with a helpful JSON body
@@ -316,18 +316,7 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
316316

317317
self.__wait_if_generating()
318318

319-
train_model_format = "yolov5pytorch"
320-
321-
if self.type == TYPE_CLASSICATION:
322-
train_model_format = "folder"
323-
324-
if self.type == TYPE_INSTANCE_SEGMENTATION:
325-
train_model_format = "yolov5pytorch"
326-
327-
if self.type == TYPE_SEMANTIC_SEGMENTATION:
328-
train_model_format = "png-mask-semantic"
329-
330-
# if classification
319+
train_model_format = get_model_format(model_type)
331320
if train_model_format not in self.exports:
332321
self.export(train_model_format)
333322

roboflow/util/versions.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,40 @@ def normalize_yolo_model_type(model_type: str) -> str:
9595
model_type = model_type.replace("yolo11", "yolov11")
9696
model_type = model_type.replace("yolo12", "yolov12")
9797
return model_type
98+
99+
100+
def get_model_format(model_type: str) -> str:
101+
"""
102+
Get the model format for a given model type.
103+
Args:
104+
model_type (str): The model type to get the format for.
105+
106+
Returns:
107+
str: The model format.
108+
109+
Example:
110+
>>> get_model_format("yolov5v6n")
111+
"yolov5pytorch"
112+
>>> get_model_format("rfdetr-nano")
113+
"coco"
114+
>>> get_model_format("yolov11n")
115+
"yolov5pytorch"
116+
"""
117+
# Prefixes extrated from modelRegistry.js in roboflow.
118+
model_formats = {
119+
"yolo": "yolov5pytorch",
120+
"pali": "jsonl",
121+
"flor": "jsonl",
122+
"qwen": "jsonl",
123+
"smol": "jsonl",
124+
"vit-b": "folder",
125+
"resn": "folder",
126+
"rfdetr": "coco",
127+
"rf-detr": "coco",
128+
"deep": "png-mask-semantic",
129+
}
130+
131+
for prefix, format in model_formats.items():
132+
if prefix in model_type:
133+
return format
134+
return "yolov5pytorch"

tests/util/test_versions.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
from importlib import import_module
33

4-
from roboflow.util.versions import get_wrong_dependencies_versions
4+
from roboflow.util.versions import get_model_format, get_wrong_dependencies_versions
55

66

77
class TestVersions(unittest.TestCase):
@@ -23,3 +23,21 @@ def test_wrong_dependencies_versions(self):
2323
wrong_dependencies_versions = get_wrong_dependencies_versions([test])
2424
is_correct_dep = len(wrong_dependencies_versions) == 0
2525
self.assertEqual(is_correct_dep, expected_result)
26+
27+
28+
class TestGetModelFormat(unittest.TestCase):
29+
def test_get_model_format_with_various_ids(self):
30+
cases = [
31+
("yolov5v2s", "yolov5pytorch"),
32+
("yolov11n", "yolov5pytorch"),
33+
("rf-detr-nas-parent", "coco"),
34+
("rfdetr-nano", "coco"),
35+
("vit-base-patch16-224-in21k", "folder"),
36+
("resnet14", "folder"),
37+
("resenet38", "yolov5pytorch"),
38+
("invlid-type", "yolov5pytorch"),
39+
]
40+
41+
for model_type, expected_format in cases:
42+
with self.subTest(model_type=model_type):
43+
self.assertEqual(get_model_format(model_type), expected_format)

0 commit comments

Comments
 (0)