Skip to content

Commit 0d3d11a

Browse files
committed
adds deploy_model functionality to workspace
1 parent b5ed55b commit 0d3d11a

File tree

1 file changed

+69
-1
lines changed

1 file changed

+69
-1
lines changed

roboflow/core/workspace.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010

1111
from roboflow.adapters import rfapi
1212
from roboflow.adapters.rfapi import AnnotationSaveError, ImageUploadError, RoboflowError
13-
from roboflow.config import API_URL, CLIP_FEATURIZE_URL, DEMO_KEYS
13+
from roboflow.config import API_URL, APP_URL, CLIP_FEATURIZE_URL, DEMO_KEYS
1414
from roboflow.core.project import Project
1515
from roboflow.util import folderparser
1616
from roboflow.util.active_learning_utils import check_box_size, clip_encode, count_comparisons
1717
from roboflow.util.image_utils import load_labelmap
18+
from roboflow.util.model_processor import process
1819
from roboflow.util.two_stage_utils import ocr_infer
1920

2021

@@ -565,6 +566,73 @@ def active_learning(
565566
return (
566567
prediction_results if type(raw_data_location) is not np.ndarray else prediction_results[-1]["predictions"]
567568
)
569+
570+
def deploy_model(
571+
self,
572+
model_type: str,
573+
model_path: str,
574+
project_ids: list[str],
575+
model_name: str,
576+
filename: str = "weights/best.pt",
577+
):
578+
"""Uploads provided weights file to Roboflow.
579+
Args:
580+
model_type (str): The type of the model to be deployed.
581+
model_path (str): File path to the model weights to be uploaded.
582+
project_ids (list[str]): List of project IDs to deploy the model to.
583+
filename (str, optional): The name of the weights file. Defaults to "weights/best.pt".
584+
"""
585+
586+
if not project_ids:
587+
raise ValueError("At least one project ID must be provided")
588+
589+
# Validate if provided project URLs belong to user's projects
590+
user_projects = set(project.split("/")[-1] for project in self.projects())
591+
for project_id in project_ids:
592+
if project_id not in user_projects:
593+
raise ValueError(f"Project {project_id} is not accessible in this workspace")
594+
595+
zip_file_name = process(model_type, model_path, filename)
596+
597+
if zip_file_name is None:
598+
raise RuntimeError("Failed to process model")
599+
600+
self._upload_zip(model_type, model_path, project_ids, model_name, zip_file_name)
601+
602+
def _upload_zip(
603+
self,
604+
model_type: str,
605+
model_path: str,
606+
project_ids: list[str],
607+
model_name: str,
608+
model_file_name: str,
609+
):
610+
# This endpoint returns a signed URL to upload the model
611+
res = requests.post(
612+
f"{API_URL}/{self.url}/models/prepareUpload?api_key={self.__api_key}&modelType={model_type}&modelName={model_name}&projectIds={','.join(project_ids)}&nocache=true"
613+
)
614+
try:
615+
res.raise_for_status()
616+
except Exception as e:
617+
print(f"An error occured when getting the model deployment URL: {e}")
618+
return
619+
620+
# Upload the model to the signed URL
621+
res = requests.put(
622+
res.json()["url"],
623+
data=open(os.path.join(model_path, model_file_name), "rb"),
624+
)
625+
try:
626+
res.raise_for_status()
627+
628+
for project_id in project_ids:
629+
print(
630+
f"View the status of your deployment for project {project_id} at:"
631+
f" {APP_URL}/{self.url}/{project_id}/models"
632+
)
633+
634+
except Exception as e:
635+
print(f"An error occured when uploading the model: {e}")
568636

569637
def __str__(self):
570638
projects = self.projects()

0 commit comments

Comments
 (0)