Skip to content

Commit 3e85436

Browse files
authored
Merge pull request #238 from roboflow/feature/yolonas-upload
Feature/yolonas upload
2 parents 33ae3b0 + 4d2d512 commit 3e85436

File tree

1 file changed

+76
-3
lines changed

1 file changed

+76
-3
lines changed

roboflow/core/version.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import json
33
import os
4+
import shutil
45
import sys
56
import time
67
import zipfile
@@ -425,11 +426,15 @@ def deploy(self, model_type: str, model_path: str) -> None:
425426
model_path (str): File path to model weights to be uploaded
426427
"""
427428

428-
supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9"]
429+
supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9", "yolonas"]
429430

430431
if not any(supported_model in model_type for supported_model in supported_models):
431432
raise (ValueError(f"Model type {model_type} not supported. Supported models are" f" {supported_models}"))
432433

434+
if "yolonas" in model_type:
435+
self.deploy_yolonas(model_type, model_path)
436+
return
437+
433438
if "yolov8" in model_type:
434439
try:
435440
import torch
@@ -516,15 +521,15 @@ def deploy(self, model_type: str, model_path: str) -> None:
516521

517522
torch.save(model["model"].state_dict(), os.path.join(model_path, "state_dict.pt"))
518523

519-
lista_files = [
524+
list_files = [
520525
"results.csv",
521526
"results.png",
522527
"model_artifacts.json",
523528
"state_dict.pt",
524529
]
525530

526531
with zipfile.ZipFile(os.path.join(model_path, "roboflow_deploy.zip"), "w") as zipMe:
527-
for file in lista_files:
532+
for file in list_files:
528533
if os.path.exists(os.path.join(model_path, file)):
529534
zipMe.write(
530535
os.path.join(model_path, file),
@@ -535,6 +540,74 @@ def deploy(self, model_type: str, model_path: str) -> None:
535540
if file in ["model_artifacts.json", "state_dict.pt"]:
536541
raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))
537542

543+
self.upload_zip(model_type, model_path)
544+
545+
def deploy_yolonas(self, model_type: str, model_path: str) -> None:
546+
try:
547+
import torch
548+
except ImportError:
549+
raise (
550+
"The torch python package is required to deploy yolonas models."
551+
" Please install it with `pip install torch`"
552+
)
553+
554+
model = torch.load(os.path.join(model_path, "weights/best.pt"), map_location="cpu")
555+
class_names = model["processing_params"]["class_names"]
556+
557+
opt_path = os.path.join(model_path, "opt.yaml")
558+
if not os.path.exists(opt_path):
559+
raise RuntimeError(
560+
f"You must create an opt.yaml file at {os.path.join(model_path, '')} of the format:\n"
561+
f"imgsz: <resolution of model>\n"
562+
f"batch_size: <batch size of inference model>\n"
563+
f"architecture: <one of [yolo_nas_s, yolo_nas_m, yolo_nas_l]."
564+
f"s, m, l refer to small, medium, large architecture sizes, respectively>\n"
565+
)
566+
with open(os.path.join(model_path, "opt.yaml"), "r") as stream:
567+
opts = yaml.safe_load(stream)
568+
required_keys = ["imgsz", "batch_size", "architecture"]
569+
for key in required_keys:
570+
if key not in opts:
571+
raise RuntimeError(f"{opt_path} lacks required key {key}. Required keys: {required_keys}")
572+
573+
model_artifacts = {
574+
"names": class_names,
575+
"nc": len(class_names),
576+
"args": {
577+
"imgsz": opts["imgsz"] if "imgsz" in opts else opts["img_size"],
578+
"batch": opts["batch_size"],
579+
"architecture": opts["architecture"],
580+
},
581+
"model_type": model_type,
582+
}
583+
584+
with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
585+
json.dump(model_artifacts, fp)
586+
587+
shutil.copy(os.path.join(model_path, "weights/best.pt"), os.path.join(model_path, "state_dict.pt"))
588+
589+
list_files = [
590+
"results.json",
591+
"results.png",
592+
"model_artifacts.json",
593+
"state_dict.pt",
594+
]
595+
596+
with zipfile.ZipFile(os.path.join(model_path, "roboflow_deploy.zip"), "w") as zipMe:
597+
for file in list_files:
598+
if os.path.exists(os.path.join(model_path, file)):
599+
zipMe.write(
600+
os.path.join(model_path, file),
601+
arcname=file,
602+
compress_type=zipfile.ZIP_DEFLATED,
603+
)
604+
else:
605+
if file in ["model_artifacts.json", "best.pt"]:
606+
raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))
607+
608+
self.upload_zip(model_type, model_path)
609+
610+
def upload_zip(self, model_type: str, model_path: str):
538611
res = requests.get(
539612
f"{API_URL}/{self.workspace}/{self.project}/{self.version}"
540613
f"/uploadModel?api_key={self.__api_key}&modelType={model_type}&nocache=true"

0 commit comments

Comments
 (0)