Skip to content

Commit 996da6d

Browse files
committed
dry it up between v5 v8 parse
1 parent d1ee455 commit 996da6d

File tree

1 file changed

+74
-125
lines changed

1 file changed

+74
-125
lines changed

roboflow/core/version.py

Lines changed: 74 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
299299
f"Model type {model_type} not supported. Supported models are {supported_models}"
300300
)
301301
)
302-
302+
303303
if model_type == "yolov8":
304304
try:
305305
import torch
@@ -308,115 +308,56 @@ def deploy(self, model_type: str, model_path: str) -> None:
308308
raise (
309309
"The ultralytics python package is required to deploy yolov8 models. Please install it with `pip install ultralytics`"
310310
)
311-
312-
# add logic to save torch state dict safely
313-
if model_type == "yolov8":
314-
model = torch.load(model_path + "weights/best.pt")
315-
316-
class_names = []
317-
for i, val in enumerate(model["model"].names):
318-
class_names.append((val, model["model"].names[val]))
319-
class_names.sort(key=lambda x: x[0])
320-
class_names = [x[1] for x in class_names]
321-
322-
try:
323-
model_artifacts = {
324-
"names": class_names,
325-
"yaml": model["model"].yaml,
326-
"nc": model["model"].nc,
327-
"args": {
328-
k: val
329-
for k, val in model["model"].args.items()
330-
if ((k == "model") or (k == "imgsz") or (k == "batch"))
331-
},
332-
"ultralytics_version": ultralytics.__version__,
333-
"model_type": model_type,
334-
}
335-
except:
336-
model_artifacts = {
337-
"names": class_names,
338-
"yaml": model["model"].yaml,
339-
"nc": model["model"].nc,
340-
"args": {
341-
k: val
342-
for k, val in model["model"].args.__dict__.items()
343-
if ((k == "model") or (k == "imgsz") or (k == "batch"))
344-
},
345-
"ultralytics_version": ultralytics.__version__,
346-
"model_type": model_type,
347-
}
348-
349-
with open(model_path + "model_artifacts.json", "w") as fp:
350-
json.dump(model_artifacts, fp)
351-
352-
torch.save(model["model"].state_dict(), model_path + "state_dict.pt")
353-
354-
lista_files = [
355-
"results.csv",
356-
"results.png",
357-
"model_artifacts.json",
358-
"state_dict.pt",
359-
]
360-
with zipfile.ZipFile(model_path + "roboflow_deploy.zip", "w") as zipMe:
361-
for file in lista_files:
362-
zipMe.write(
363-
model_path + file,
364-
arcname=file,
365-
compress_type=zipfile.ZIP_DEFLATED,
366-
)
367-
368-
res = requests.get(
369-
f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}"
370-
)
371-
try:
372-
if res.status_code == 429:
373-
raise RuntimeError(
374-
f"This version already has a trained model. Please generate and train a new version in order to upload model to Roboflow."
375-
)
376-
else:
377-
res.raise_for_status()
378-
except Exception as e:
379-
print(f"An error occured when getting the model upload URL: {e}")
380-
return
381-
382-
res = requests.put(
383-
res.json()["url"], data=open(model_path + "roboflow_deploy.zip", "rb")
384-
)
385-
try:
386-
res.raise_for_status()
387-
388-
if self.public:
389-
print(
390-
f"View the status of your deployment at: {APP_URL}/{self.workspace}/{self.project}/deploy/{self.version}"
391-
)
392-
print(
393-
f"Share your model with the world at: {UNIVERSE_URL}/{self.workspace}/{self.project}/model/{self.version}"
394-
)
395-
else:
396-
print(
397-
f"View the status of your deployment at: {APP_URL}/{self.workspace}/{self.project}/deploy/{self.version}"
398-
)
399-
400-
except Exception as e:
401-
print(f"An error occured when uploading the model: {e}")
402311
elif model_type == "yolov5":
403312
try:
404313
import torch
405-
import yaml
406314
except ImportError as e:
407-
raise ("PyTorch must be installed to use this feature.")
315+
raise (
316+
"The torch python package is required to deploy yolov5 models. Please install it with `pip install torch`"
317+
)
408318

409-
model = torch.load(os.path.join(model_path + "weights/best.pt"))
319+
model = torch.load(model_path + "weights/best.pt")
410320

411-
class_names = []
412-
for i, val in enumerate(model["model"].names):
413-
class_names.append((val, model["model"].names[val]))
414-
class_names.sort(key=lambda x: x[0])
415-
class_names = [x[1] for x in class_names]
321+
class_names = []
322+
for i, val in enumerate(model["model"].names):
323+
class_names.append((val, model["model"].names[val]))
324+
class_names.sort(key=lambda x: x[0])
325+
class_names = [x[1] for x in class_names]
416326

327+
if model_type == "yolov8":
328+
#try except for backwards compatibility with older versions of ultralytics
329+
try:
330+
model_artifacts = {
331+
"names": class_names,
332+
"yaml": model["model"].yaml,
333+
"nc": model["model"].nc,
334+
"args": {
335+
k: val
336+
for k, val in model["model"].args.items()
337+
if ((k == "model") or (k == "imgsz") or (k == "batch"))
338+
},
339+
"ultralytics_version": ultralytics.__version__,
340+
"model_type": model_type,
341+
}
342+
except:
343+
model_artifacts = {
344+
"names": class_names,
345+
"yaml": model["model"].yaml,
346+
"nc": model["model"].nc,
347+
"args": {
348+
k: val
349+
for k, val in model["model"].args.__dict__.items()
350+
if ((k == "model") or (k == "imgsz") or (k == "batch"))
351+
},
352+
"ultralytics_version": ultralytics.__version__,
353+
"model_type": model_type,
354+
}
355+
elif model_type == "yolov5":
356+
#parse from yaml for yolov5
357+
417358
with open(os.path.join(model_path + "opt.yaml"), "r") as stream:
418359
opts = yaml.safe_load(stream)
419-
360+
420361
model_artifacts = {
421362
"names": class_names,
422363
"yaml": model["model"].yaml,
@@ -425,31 +366,39 @@ def deploy(self, model_type: str, model_path: str) -> None:
425366
"model_type": model_type,
426367
}
427368

428-
with open(os.path.join(model_path + "model_artifacts.json"), "w") as fp:
429-
json.dump(model_artifacts, fp)
430-
431-
torch.save(
432-
model["model"].state_dict(), os.path.join(model_path + "state_dict.pt")
433-
)
434-
435-
lista_files = [
436-
"results.csv",
437-
"results.png",
438-
"model_artifacts.json",
439-
"state_dict.pt",
440-
]
441-
442-
with zipfile.ZipFile(
443-
os.path.join(model_path + "roboflow_deploy.zip"), "w"
444-
) as zipMe:
445-
for file in lista_files:
446-
zipMe.write(
447-
os.path.join(model_path + file),
448-
arcname=file,
449-
compress_type=zipfile.ZIP_DEFLATED,
450-
)
369+
with open(model_path + "model_artifacts.json", "w") as fp:
370+
json.dump(model_artifacts, fp)
371+
372+
torch.save(model["model"].state_dict(), model_path + "state_dict.pt")
373+
374+
lista_files = [
375+
"results.csv",
376+
"results.png",
377+
"model_artifacts.json",
378+
"state_dict.pt",
379+
]
380+
381+
with zipfile.ZipFile(model_path + "roboflow_deploy.zip", "w") as zipMe:
382+
for file in lista_files:
383+
zipMe.write(
384+
model_path + file,
385+
arcname=file,
386+
compress_type=zipfile.ZIP_DEFLATED,
387+
)
451388

452-
# torch.load("state_dict.pt", weights_only=True)
389+
res = requests.get(
390+
f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}"
391+
)
392+
try:
393+
if res.status_code == 429:
394+
raise RuntimeError(
395+
f"This version already has a trained model. Please generate and train a new version in order to upload model to Roboflow."
396+
)
397+
else:
398+
res.raise_for_status()
399+
except Exception as e:
400+
print(f"An error occured when getting the model upload URL: {e}")
401+
return
453402

454403
def __download_zip(self, link, location, format):
455404
"""

0 commit comments

Comments
 (0)