Skip to content

Commit 1946a96

Browse files
committed
Yolonas upload
1 parent bd5f01a commit 1946a96

File tree

1 file changed

+61
-1
lines changed

1 file changed

+61
-1
lines changed

roboflow/core/version.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import json
33
import os
4+
import shutil
45
import sys
56
import time
67
import zipfile
@@ -425,10 +426,14 @@ def deploy(self, model_type: str, model_path: str) -> None:
425426
model_path (str): File path to model weights to be uploaded
426427
"""
427428

428-
supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9"]
429+
supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9", "yolonas"]
429430

430431
if not any(supported_model in model_type for supported_model in supported_models):
431432
raise (ValueError(f"Model type {model_type} not supported. Supported models are" f" {supported_models}"))
433+
434+
if "yolonas" in model_type:
435+
self.deploy_yolonas(model_path)
436+
return
432437

433438
if "yolov8" in model_type:
434439
try:
@@ -535,6 +540,61 @@ def deploy(self, model_type: str, model_path: str) -> None:
535540
if file in ["model_artifacts.json", "state_dict.pt"]:
536541
raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))
537542

543+
self.upload_zip(model_type, model_path)
544+
545+
def deploy_yolonas(self, model_type: str, model_path: str) -> None:
546+
try:
547+
import torch
548+
except ImportError:
549+
raise (
550+
"The torch python package is required to deploy yolonas models."
551+
" Please install it with `pip install torch`"
552+
)
553+
554+
model = torch.load(os.path.join(model_path, "weights/best.pt"), map_location="cpu")
555+
class_names = model["processing_params"]["class_names"]
556+
557+
with open(os.path.join(model_path, "opt.yaml"), "r") as stream:
558+
opts = yaml.safe_load(stream)
559+
560+
model_artifacts = {
561+
"names": class_names,
562+
"nc": len(class_names),
563+
"args": {
564+
"imgsz": opts["imgsz"] if "imgsz" in opts else opts["img_size"],
565+
"batch": opts["batch_size"],
566+
},
567+
"model_type": model_type,
568+
}
569+
570+
with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
571+
json.dump(model_artifacts, fp)
572+
573+
shutil.copy(os.path.join(model_path, "weights/best.pt"),
574+
os.path.join(model_path, "state_dict.pt"))
575+
576+
lista_files = [
577+
"results.json",
578+
"results.png",
579+
"model_artifacts.json",
580+
"state_dict.pt",
581+
]
582+
583+
with zipfile.ZipFile(os.path.join(model_path, "roboflow_deploy.zip"), "w") as zipMe:
584+
for file in lista_files:
585+
if os.path.exists(os.path.join(model_path, file)):
586+
zipMe.write(
587+
os.path.join(model_path, file),
588+
arcname=file,
589+
compress_type=zipfile.ZIP_DEFLATED,
590+
)
591+
else:
592+
if file in ["model_artifacts.json", "best.pt"]:
593+
raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))
594+
595+
self.upload_zip(model_type, model_path)
596+
597+
def upload_zip(self, model_type: str, model_path: str):
538598
res = requests.get(
539599
f"{API_URL}/{self.workspace}/{self.project}/{self.version}"
540600
f"/uploadModel?api_key={self.__api_key}&modelType={model_type}&nocache=true"

0 commit comments

Comments
 (0)