Skip to content

Commit 8e86288

Browse files
Merge branch 'main' into json-response-trt-fix
2 parents 9e2b77e + b5495c1 commit 8e86288

File tree

5 files changed

+95
-7
lines changed

5 files changed

+95
-7
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from roboflow.core.project import Project
99
from roboflow.core.workspace import Workspace
1010

11-
__version__ = "0.2.24"
12-
11+
__version__ = "0.2.25"
1312

1413
def check_key(api_key, model, notebook, num_retries=0):
1514
if type(api_key) is not str:

roboflow/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
API_URL = os.getenv("API_URL", "https://api.roboflow.com")
1010
APP_URL = os.getenv("APP_URL", "https://app.roboflow.com")
11+
UNIVERSE_URL = os.getenv("UNIVERSE_URL", "https://universe.roboflow.com")
1112
INSTANCE_SEGMENTATION_URL = os.getenv(
1213
"INSTANCE_SEGMENTATION_URL", "https://outline.roboflow.com"
1314
)

roboflow/core/project.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def versions(self):
9393
local=None,
9494
workspace=self.__workspace,
9595
project=self.__project_name,
96+
public=self.public,
9697
)
9798
version_array.append(version_object)
9899
return version_array
@@ -251,6 +252,7 @@ def version(self, version_number, local=None):
251252
local=local,
252253
workspace=self.__workspace,
253254
project=self.__project_name,
255+
public=self.public,
254256
)
255257
return vers
256258

roboflow/core/version.py

Lines changed: 89 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212

1313
from roboflow.config import (
1414
API_URL,
15+
APP_URL,
1516
DEMO_KEYS,
1617
TYPE_CLASSICATION,
1718
TYPE_INSTANCE_SEGMENTATION,
1819
TYPE_OBJECT_DETECTION,
1920
TYPE_SEMANTIC_SEGMENTATION,
21+
UNIVERSE_URL,
2022
)
2123
from roboflow.core.dataset import Dataset
2224
from roboflow.models.classification import ClassificationModel
@@ -39,6 +41,7 @@ def __init__(
3941
local,
4042
workspace,
4143
project,
44+
public,
4245
):
4346
if api_key in DEMO_KEYS:
4447
if api_key == "coco-128-sample":
@@ -70,6 +73,7 @@ def __init__(
7073
self.model_format = model_format
7174
self.workspace = workspace
7275
self.project = project
76+
self.public = public
7377
if "exports" in version_dict.keys():
7478
self.exports = version_dict["exports"]
7579
else:
@@ -136,12 +140,13 @@ def __wait_if_generating(self, recurse=False):
136140
sys.stdout.flush()
137141
return
138142

139-
def download(self, model_format=None, location=None):
143+
def download(self, model_format=None, location=None, overwrite: bool = True):
140144
"""
141145
Download and extract a ZIP of a version's dataset in a given format
142146
143147
:param model_format: A format to use for downloading
144148
:param location: An optional path for saving the file
149+
:param overwrite: An optional flag to prevent dataset overwrite when dataset is already downloaded
145150
146151
:return: Dataset
147152
"""
@@ -157,6 +162,10 @@ def download(self, model_format=None, location=None):
157162

158163
if location is None:
159164
location = self.__get_download_location()
165+
if os.path.exists(location) and not overwrite:
166+
return Dataset(
167+
self.name, self.version, model_format, os.path.abspath(location)
168+
)
160169

161170
if self.__api_key == "coco-128-sample":
162171
link = "https://app.roboflow.com/ds/n9QwXwUK42?key=NnVCe2yMxP"
@@ -275,12 +284,70 @@ def train(self, speed=None, checkpoint=None) -> bool:
275284

276285
return True
277286

278-
def upload_model(self, model_path: str) -> None:
287+
def deploy(self, model_type: str, model_path: str) -> None:
279288
"""Uploads provided weights file to Roboflow
280289
281290
Args:
282291
model_path (str): File path to model weights to be uploaded
283292
"""
293+
294+
supported_models = ["yolov8"]
295+
296+
if model_type not in supported_models:
297+
raise (
298+
ValueError(
299+
f"Model type {model_type} not supported. Supported models are {supported_models}"
300+
)
301+
)
302+
303+
try:
304+
import torch
305+
import ultralytics
306+
except ImportError as e:
307+
raise (
308+
"The ultralytics python package is required to deploy yolov8 models. Please install it with `pip install ultralytics`"
309+
)
310+
311+
# add logic to save torch state dict safely
312+
if model_type == "yolov8":
313+
model = torch.load(model_path + "weights/best.pt")
314+
315+
class_names = []
316+
for i, val in enumerate(model["model"].names):
317+
class_names.append((val, model["model"].names[val]))
318+
class_names.sort(key=lambda x: x[0])
319+
class_names = [x[1] for x in class_names]
320+
321+
model_artifacts = {
322+
"names": class_names,
323+
"yaml": model["model"].yaml,
324+
"nc": model["model"].nc,
325+
"args": {
326+
k: val for k, val in model["model"].args.items() if k != "hydra"
327+
},
328+
"ultralytics_version": ultralytics.__version__,
329+
"model_type": model_type,
330+
}
331+
332+
with open(model_path + "model_artifacts.json", "w") as fp:
333+
json.dump(model_artifacts, fp)
334+
335+
torch.save(model["model"].state_dict(), model_path + "state_dict.pt")
336+
337+
lista_files = [
338+
"results.csv",
339+
"results.png",
340+
"model_artifacts.json",
341+
"state_dict.pt",
342+
]
343+
with zipfile.ZipFile(model_path + "roboflow_deploy.zip", "w") as zipMe:
344+
for file in lista_files:
345+
zipMe.write(
346+
model_path + file,
347+
arcname=file,
348+
compress_type=zipfile.ZIP_DEFLATED,
349+
)
350+
284351
res = requests.get(
285352
f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}"
286353
)
@@ -294,13 +361,30 @@ def upload_model(self, model_path: str) -> None:
294361
except Exception as e:
295362
print(f"An error occured when getting the model upload URL: {e}")
296363
return
297-
res = requests.put(res.json()["url"], data=open(model_path, "rb"))
364+
365+
res = requests.put(
366+
res.json()["url"], data=open(model_path + "roboflow_deploy.zip", "rb")
367+
)
298368
try:
299369
res.raise_for_status()
300-
print("Model uploaded")
370+
371+
if self.public:
372+
print(
373+
f"View the status of your deployment at: {APP_URL}/{self.workspace}/{self.project}/deploy/{self.version}"
374+
)
375+
print(
376+
f"Share your model with the world at: {UNIVERSE_URL}/{self.workspace}/{self.project}/model/{self.version}"
377+
)
378+
else:
379+
print(
380+
f"View the status of your deployment at: {APP_URL}/{self.workspace}/{self.project}/deploy/{self.version}"
381+
)
382+
301383
except Exception as e:
302384
print(f"An error occured when uploading the model: {e}")
303385

386+
# torch.load("state_dict.pt", weights_only=True)
387+
304388
def __download_zip(self, link, location, format):
305389
"""
306390
Download a dataset's zip file from the given URL and save it in the desired location
@@ -408,7 +492,7 @@ def __reformat_yaml(self, location, format):
408492
409493
:return None:
410494
"""
411-
if format in ["yolov5pytorch", "yolov7pytorch"]:
495+
if format in ["yolov5pytorch", "yolov7pytorch", "yolov8"]:
412496
with open(location + "/data.yaml") as file:
413497
new_yaml = yaml.safe_load(file)
414498
new_yaml["train"] = location + new_yaml["train"].lstrip("..")

tests/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
def get_version(
66
api_key="test-api-key",
77
project_name="Test Project Name",
8+
public=True,
89
version_number="1",
910
type=TYPE_OBJECT_DETECTION,
1011
workspace_name="Test Workspace Name",
@@ -46,4 +47,5 @@ def get_version(
4647
local=None,
4748
workspace=workspace_name,
4849
project=project_name,
50+
public=public
4951
)

0 commit comments

Comments
 (0)