Skip to content

Commit 57b7aa7

Browse files
authored
Merge pull request #111 from roboflow/feature/yolov7-seg-upload
YOLOv7-Seg Deploy
2 parents bf49a9b + 6d202cb commit 57b7aa7

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

roboflow/core/version.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
312312
model_path (str): File path to model weights to be uploaded
313313
"""
314314

315-
supported_models = ["yolov8", "yolov5"]
315+
supported_models = ["yolov8", "yolov5", "yolov7-seg"]
316316

317317
if model_type not in supported_models:
318318
raise (
@@ -335,7 +335,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
335335
[("ultralytics", "<=", "8.0.20")]
336336
)
337337

338-
elif model_type == "yolov5":
338+
elif model_type in ["yolov5", "yolov7-seg"]:
339339
try:
340340
import torch
341341
except ImportError as e:
@@ -379,7 +379,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
379379
"ultralytics_version": ultralytics.__version__,
380380
"model_type": model_type,
381381
}
382-
elif model_type == "yolov5":
382+
elif model_type in ["yolov5", "yolov7-seg"]:
383383
# parse from yaml for yolov5
384384

385385
with open(os.path.join(model_path, "opt.yaml"), "r") as stream:
@@ -407,11 +407,19 @@ def deploy(self, model_type: str, model_path: str) -> None:
407407

408408
with zipfile.ZipFile(model_path + "roboflow_deploy.zip", "w") as zipMe:
409409
for file in lista_files:
410-
zipMe.write(
411-
model_path + file,
412-
arcname=file,
413-
compress_type=zipfile.ZIP_DEFLATED,
414-
)
410+
if os.path.exists(model_path + file):
411+
zipMe.write(
412+
model_path + file,
413+
arcname=file,
414+
compress_type=zipfile.ZIP_DEFLATED,
415+
)
416+
else:
417+
if file in ["model_artifacts.json", "state_dict.pt"]:
418+
raise (
419+
ValueError(
420+
f"File {file} not found. Please make sure to provide a valid model path."
421+
)
422+
)
415423

416424
res = requests.get(
417425
f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}"

0 commit comments

Comments
 (0)