Skip to content

Commit 5b46369

Browse files
committed
deploy model in workspace
1 parent f220764 commit 5b46369

File tree

1 file changed

+62
-1
lines changed

1 file changed

+62
-1
lines changed

roboflow/core/workspace.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import concurrent.futures
24
import glob
35
import json
@@ -10,11 +12,12 @@
1012

1113
from roboflow.adapters import rfapi
1214
from roboflow.adapters.rfapi import AnnotationSaveError, ImageUploadError, RoboflowError
13-
from roboflow.config import API_URL, CLIP_FEATURIZE_URL, DEMO_KEYS
15+
from roboflow.config import API_URL, APP_URL, CLIP_FEATURIZE_URL, DEMO_KEYS
1416
from roboflow.core.project import Project
1517
from roboflow.util import folderparser
1618
from roboflow.util.active_learning_utils import check_box_size, clip_encode, count_comparisons
1719
from roboflow.util.image_utils import load_labelmap
20+
from roboflow.util.model_processor import process
1821
from roboflow.util.two_stage_utils import ocr_infer
1922

2023

@@ -566,6 +569,64 @@ def active_learning(
566569
prediction_results if type(raw_data_location) is not np.ndarray else prediction_results[-1]["predictions"]
567570
)
568571

572+
def deploy_model(
573+
self,
574+
model_type: str,
575+
model_path: str,
576+
project_urls: list[str],
577+
filename: str = "weights/best.pt",
578+
):
579+
"""Uploads provided weights file to Roboflow.
580+
581+
Args:
582+
model_type (str): The type of the model to be deployed.
583+
model_path (str): File path to the model weights to be uploaded.
584+
project_urls (list[str]): List of project URLs to deploy the model to.
585+
filename (str, optional): The name of the weights file. Defaults to "weights/best.pt".
586+
"""
587+
588+
if not project_urls:
589+
raise ValueError("At least one project URL must be provided")
590+
591+
# Validate if provided project URLs belong to user's projects
592+
user_projects = set(self.projects())
593+
for project_url in project_urls:
594+
project_id = project_url.split("/")[-1]
595+
if project_id not in user_projects:
596+
raise ValueError(f"Project {project_url} is not accessible in this workspace")
597+
598+
zip_file_name = process(model_type, model_path, filename)
599+
600+
if zip_file_name is None:
601+
raise RuntimeError("Failed to process model")
602+
603+
self._upload_zip(model_type, model_path, zip_file_name, project_urls)
604+
605+
def _upload_zip(self, model_type: str, model_path: str, project_urls: list[str], model_file_name: str):
606+
# TODO: Need to create this endpoint
607+
res = requests.post(
608+
f"{API_URL}/{self.workspace}/uploadModel?api_key={self.__api_key}&modelType={model_type}&project_urls={','.join(project_urls)}&nocache=true"
609+
)
610+
try:
611+
res.raise_for_status()
612+
except Exception as e:
613+
print(f"An error occured when getting the model upload URL: {e}")
614+
return
615+
616+
# TODO: Need to check why we use that
617+
res = requests.put(
618+
res.json()["url"],
619+
data=open(os.path.join(model_path, model_file_name), "rb"),
620+
)
621+
try:
622+
res.raise_for_status()
623+
624+
# TODO: Need to check this URL
625+
print("View the status of your deployment at:" f" {APP_URL}/{self.workspace}/models")
626+
627+
except Exception as e:
628+
print(f"An error occured when uploading the model: {e}")
629+
569630
def __str__(self):
570631
projects = self.projects()
571632
json_value = {"name": self.name, "url": self.url, "projects": projects}

0 commit comments

Comments
 (0)