Skip to content

Commit 07cb243

Browse files
Merge pull request #101 from roboflow/yolov5Upload
Upload YOLOv5 Models to Roboflow Deploy
2 parents e0b8ce3 + 4207cbf commit 07cb243

File tree

6 files changed

+193
-64
lines changed

6 files changed

+193
-64
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from roboflow.core.project import Project
99
from roboflow.core.workspace import Workspace
1010

11-
__version__ = "0.2.27"
11+
__version__ = "0.2.28"
1212

1313

1414
def check_key(api_key, model, notebook, num_retries=0):

roboflow/core/version.py

Lines changed: 74 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
from roboflow.models.instance_segmentation import InstanceSegmentationModel
2626
from roboflow.models.object_detection import ObjectDetectionModel
2727
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
28+
from roboflow.util.versions import (
29+
print_warn_for_wrong_dependencies_versions,
30+
warn_for_wrong_dependencies_versions,
31+
)
2832

2933
load_dotenv()
3034

@@ -153,6 +157,12 @@ def download(self, model_format=None, location=None, overwrite: bool = True):
153157

154158
self.__wait_if_generating()
155159

160+
if model_format == "yolov8":
161+
# we assume the user will want to use yolov8, for now we only support ultralytics=="8.11.0"
162+
print_warn_for_wrong_dependencies_versions(
163+
[("ultralytics", "<=", "8.0.20")]
164+
)
165+
156166
model_format = self.__get_format_identifier(model_format)
157167

158168
if model_format not in self.exports:
@@ -284,14 +294,15 @@ def train(self, speed=None, checkpoint=None) -> bool:
284294

285295
return True
286296

297+
# @warn_for_wrong_dependencies_versions([("ultralytics", "<=", "8.0.20")])
287298
def deploy(self, model_type: str, model_path: str) -> None:
288299
"""Uploads provided weights file to Roboflow
289300
290301
Args:
291302
model_path (str): File path to model weights to be uploaded
292303
"""
293304

294-
supported_models = ["yolov8"]
305+
supported_models = ["yolov8", "yolov5"]
295306

296307
if model_type not in supported_models:
297308
raise (
@@ -300,24 +311,38 @@ def deploy(self, model_type: str, model_path: str) -> None:
300311
)
301312
)
302313

303-
try:
304-
import torch
305-
import ultralytics
306-
except ImportError as e:
307-
raise (
308-
"The ultralytics python package is required to deploy yolov8 models. Please install it with `pip install ultralytics`"
314+
if model_type == "yolov8":
315+
try:
316+
import torch
317+
import ultralytics
318+
319+
except ImportError as e:
320+
raise (
321+
"The ultralytics python package is required to deploy yolov8 models. Please install it with `pip install ultralytics`"
322+
)
323+
324+
print_warn_for_wrong_dependencies_versions(
325+
[("ultralytics", "<=", "8.0.20")]
309326
)
310327

311-
# add logic to save torch state dict safely
312-
if model_type == "yolov8":
313-
model = torch.load(os.path.join(model_path + "weights/best.pt"))
328+
elif model_type == "yolov5":
329+
try:
330+
import torch
331+
except ImportError as e:
332+
raise (
333+
"The torch python package is required to deploy yolov5 models. Please install it with `pip install torch`"
334+
)
335+
336+
model = torch.load(os.path.join(model_path, "weights/best.pt"))
314337

315-
class_names = []
316-
for i, val in enumerate(model["model"].names):
317-
class_names.append((val, model["model"].names[val]))
318-
class_names.sort(key=lambda x: x[0])
319-
class_names = [x[1] for x in class_names]
338+
class_names = []
339+
for i, val in enumerate(model["model"].names):
340+
class_names.append((val, model["model"].names[val]))
341+
class_names.sort(key=lambda x: x[0])
342+
class_names = [x[1] for x in class_names]
320343

344+
if model_type == "yolov8":
345+
# try except for backwards compatibility with older versions of ultralytics
321346
try:
322347
model_artifacts = {
323348
"names": class_names,
@@ -344,29 +369,39 @@ def deploy(self, model_type: str, model_path: str) -> None:
344369
"ultralytics_version": ultralytics.__version__,
345370
"model_type": model_type,
346371
}
347-
348-
with open(os.path.join(model_path + "model_artifacts.json", "w")) as fp:
349-
json.dump(model_artifacts, fp)
350-
351-
torch.save(
352-
model["model"].state_dict(), os.path.join(model_path + "state_dict.pt")
353-
)
354-
355-
lista_files = [
356-
"results.csv",
357-
"results.png",
358-
"model_artifacts.json",
359-
"state_dict.pt",
360-
]
361-
with zipfile.ZipFile(
362-
os.path.join(model_path + "roboflow_deploy.zip", "w")
363-
) as zipMe:
364-
for file in lista_files:
365-
zipMe.write(
366-
os.path.join(model_path + file),
367-
arcname=file,
368-
compress_type=zipfile.ZIP_DEFLATED,
369-
)
372+
elif model_type == "yolov5":
373+
# parse from yaml for yolov5
374+
375+
with open(os.path.join(model_path, "opt.yaml"), "r") as stream:
376+
opts = yaml.safe_load(stream)
377+
378+
model_artifacts = {
379+
"names": class_names,
380+
"yaml": model["model"].yaml,
381+
"nc": model["model"].nc,
382+
"args": {"imgsz": opts["imgsz"], "batch": opts["batch_size"]},
383+
"model_type": model_type,
384+
}
385+
386+
with open(model_path + "model_artifacts.json", "w") as fp:
387+
json.dump(model_artifacts, fp)
388+
389+
torch.save(model["model"].state_dict(), model_path + "state_dict.pt")
390+
391+
lista_files = [
392+
"results.csv",
393+
"results.png",
394+
"model_artifacts.json",
395+
"state_dict.pt",
396+
]
397+
398+
with zipfile.ZipFile(model_path + "roboflow_deploy.zip", "w") as zipMe:
399+
for file in lista_files:
400+
zipMe.write(
401+
model_path + file,
402+
arcname=file,
403+
compress_type=zipfile.ZIP_DEFLATED,
404+
)
370405

371406
res = requests.get(
372407
f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}"
@@ -384,7 +419,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
384419

385420
res = requests.put(
386421
res.json()["url"],
387-
data=open(os.path.join(model_path + "roboflow_deploy.zip", "rb")),
422+
data=open(os.path.join(model_path + "roboflow_deploy.zip"), "rb"),
388423
)
389424
try:
390425
res.raise_for_status()
@@ -404,8 +439,6 @@ def deploy(self, model_type: str, model_path: str) -> None:
404439
except Exception as e:
405440
print(f"An error occured when uploading the model: {e}")
406441

407-
# torch.load("state_dict.pt", weights_only=True)
408-
409442
def __download_zip(self, link, location, format):
410443
"""
411444
Download a dataset's zip file from the given URL and save it in the desired location

roboflow/util/versions.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from importlib import import_module
2+
from typing import List, Tuple
3+
4+
5+
def get_wrong_dependencies_versions(
6+
dependencies_versions: List[Tuple[str, str, str]]
7+
) -> List[Tuple[str, str, str, str]]:
8+
"""
9+
Get a list of missmatching dependencies with current version installed.
10+
E.g., assuming we pass `get_wrong_dependencies_versions([("torch", "==", "1.2.0")]), we will check if the current version of `torch` is `==1.2.0`. If not, we will return `[("torch", "==", "1.2.0", "<current_installed_version>")]
11+
12+
We support `<=`, `==`, `>=`
13+
14+
Args:
15+
dependencies_versions (List[Tuple[str, str]]): List of dependencies we want to check, [("<package_name>", "<version_number_to_check")]
16+
17+
Returns:
18+
List[Tuple[str, str, str]]: List of dependencies with wrong version, [("<package_name>", "<version_number_to_check", "<current_version>")]
19+
"""
20+
wrong_dependencies_versions = []
21+
# from e.g. "1.12.0" -> 1120
22+
parse_version_to_int = lambda x: int(x.replace(".", ""))
23+
order_funcs = {
24+
"==": lambda x, y: x == y,
25+
">=": lambda x, y: x >= y,
26+
"<=": lambda x, y: x <= y,
27+
}
28+
for dependency, order, version in dependencies_versions:
29+
module = import_module(dependency)
30+
module_version = module.__version__
31+
if order not in order_funcs:
32+
raise ValueError(
33+
f"order={order} not supported, please use `{', '.join(order_funcs.keys())}`"
34+
)
35+
is_okay = order_funcs[order](
36+
parse_version_to_int(module_version), parse_version_to_int(version)
37+
)
38+
if not is_okay:
39+
wrong_dependencies_versions.append(
40+
(dependency, order, version, module_version)
41+
)
42+
return wrong_dependencies_versions
43+
44+
45+
def print_warn_for_wrong_dependencies_versions(
46+
dependencies_versions: List[Tuple[str, str, str]]
47+
):
48+
wrong_dependencies_versions = get_wrong_dependencies_versions(dependencies_versions)
49+
for (dependency, order, version, module_version) in wrong_dependencies_versions:
50+
print(
51+
f"Dependency {dependency}{order}{version} is required but found version={module_version}, to fix: `pip install {dependency}{order}{version}`"
52+
)
53+
54+
55+
def warn_for_wrong_dependencies_versions(
56+
dependencies_versions: List[Tuple[str, str, str]]
57+
):
58+
"""
59+
Decorator to print a warning based on dependencies versions. E.g.
60+
61+
```python
62+
@warn_for_wrong_dependencies_versions([("torch", "==", "1.2.0")])
63+
def foo(x):
64+
# I only work with torch `1.2.0` but another one is installed
65+
print(f"foo {x}")
66+
```
67+
68+
prints:
69+
70+
```
71+
Dependency torch==1.2.0 is required but found version=1.13.1, to fix: `pip install torch==1.2.0`
72+
```
73+
74+
Args:
75+
dependencies_versions (List[Tuple[str, str]]): List of dependencies we want to check, [("<package_name>", "<version_number_to_check")]
76+
"""
77+
78+
def _inner(func):
79+
def _wrapper(*args, **kwargs):
80+
print_warn_for_wrong_dependencies_versions(dependencies_versions)
81+
func(*args, **kwargs)
82+
83+
return _wrapper
84+
85+
return _inner

setup.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
with open("README.md", "r") as fh:
1212
long_description = fh.read()
1313

14+
# we are using the packages in `requirements.txt` for now,
15+
# not 100% ideal but will do
16+
with open("requirements.txt", "r") as fh:
17+
install_requires = fh.read().split('\n')
18+
1419
setuptools.setup(
1520
name="roboflow",
1621
version=version,
@@ -20,28 +25,7 @@
2025
long_description=long_description,
2126
long_description_content_type="text/markdown",
2227
url="https://github.com/roboflow-ai/roboflow-python",
23-
install_requires=[
24-
"certifi==2022.12.7",
25-
"chardet==4.0.0",
26-
"cycler==0.10.0",
27-
"glob2",
28-
"idna==2.10",
29-
"kiwisolver>=1.3.1",
30-
"matplotlib",
31-
"numpy>=1.18.5",
32-
"opencv-python-headless>=4.5.1.48",
33-
"Pillow>=7.1.2",
34-
"pyparsing==2.4.7",
35-
"python-dateutil",
36-
"python-dotenv",
37-
"requests",
38-
"requests_toolbelt",
39-
"six",
40-
"urllib3==1.26.6",
41-
"tqdm>=4.41.0",
42-
"PyYAML>=5.3.1",
43-
"wget",
44-
],
28+
install_requires=install_requires,
4529
packages=find_packages(exclude=("tests",)),
4630
extras_require={
4731
"dev": ["flake8", "black==22.3.0", "isort", "responses", "twine", "wheel"],
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "0.3.0"

tests/util/test_versions.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from importlib import import_module
2+
from roboflow.util.versions import get_wrong_dependencies_versions
3+
import unittest
4+
import sys
5+
from pathlib import Path
6+
7+
class TestVersions(unittest.TestCase):
8+
9+
def test_wrong_dependencies_versions(self):
10+
module_path = "tests.util.dummy_module"
11+
module_version = import_module(module_path).__version__
12+
tests = [
13+
("tests.util.dummy_module", "==", module_version),
14+
("tests.util.dummy_module", "<=", "0.2.0"),
15+
("tests.util.dummy_module", "<=", "1.0.0"),
16+
("tests.util.dummy_module", ">=", "0.1.0"),
17+
("tests.util.dummy_module", ">=", "0.6.0")
18+
19+
]
20+
# true if dep is correc
21+
expected_results = [True, False, True, True, False]
22+
23+
for (test, expected_result) in zip(tests, expected_results):
24+
wrong_dependencies_versions = get_wrong_dependencies_versions([test])
25+
is_correct_dep = len(wrong_dependencies_versions) == 0
26+
self.assertEqual(is_correct_dep, expected_result)

0 commit comments

Comments
 (0)