1212
1313from roboflow .config import (
1414 API_URL ,
15+ APP_URL ,
1516 DEMO_KEYS ,
1617 TYPE_CLASSICATION ,
1718 TYPE_INSTANCE_SEGMENTATION ,
1819 TYPE_OBJECT_DETECTION ,
1920 TYPE_SEMANTIC_SEGMENTATION ,
21+ UNIVERSE_URL ,
2022)
2123from roboflow .core .dataset import Dataset
2224from roboflow .models .classification import ClassificationModel
@@ -39,6 +41,7 @@ def __init__(
3941 local ,
4042 workspace ,
4143 project ,
44+ public ,
4245 ):
4346 if api_key in DEMO_KEYS :
4447 if api_key == "coco-128-sample" :
@@ -70,6 +73,7 @@ def __init__(
7073 self .model_format = model_format
7174 self .workspace = workspace
7275 self .project = project
76+ self .public = public
7377 if "exports" in version_dict .keys ():
7478 self .exports = version_dict ["exports" ]
7579 else :
@@ -136,12 +140,13 @@ def __wait_if_generating(self, recurse=False):
136140 sys .stdout .flush ()
137141 return
138142
139- def download (self , model_format = None , location = None ):
143+ def download (self , model_format = None , location = None , overwrite : bool = True ):
140144 """
141145 Download and extract a ZIP of a version's dataset in a given format
142146
143147 :param model_format: A format to use for downloading
144148 :param location: An optional path for saving the file
149+ :param overwrite: An optional flag to prevent dataset overwrite when dataset is already downloaded
145150
146151 :return: Dataset
147152 """
@@ -157,6 +162,10 @@ def download(self, model_format=None, location=None):
157162
158163 if location is None :
159164 location = self .__get_download_location ()
165+ if os .path .exists (location ) and not overwrite :
166+ return Dataset (
167+ self .name , self .version , model_format , os .path .abspath (location )
168+ )
160169
161170 if self .__api_key == "coco-128-sample" :
162171 link = "https://app.roboflow.com/ds/n9QwXwUK42?key=NnVCe2yMxP"
@@ -275,12 +284,70 @@ def train(self, speed=None, checkpoint=None) -> bool:
275284
276285 return True
277286
278- def upload_model (self , model_path : str ) -> None :
287+ def deploy (self , model_type : str , model_path : str ) -> None :
279288 """Uploads provided weights file to Roboflow
280289
281290 Args:
282291 model_path (str): File path to model weights to be uploaded
283292 """
293+
294+ supported_models = ["yolov8" ]
295+
296+ if model_type not in supported_models :
297+ raise (
298+ ValueError (
299+ f"Model type { model_type } not supported. Supported models are { supported_models } "
300+ )
301+ )
302+
303+ try :
304+ import torch
305+ import ultralytics
306+ except ImportError as e :
307+ raise (
308+ "The ultralytics python package is required to deploy yolov8 models. Please install it with `pip install ultralytics`"
309+ )
310+
311+ # add logic to save torch state dict safely
312+ if model_type == "yolov8" :
313+ model = torch .load (model_path + "weights/best.pt" )
314+
315+ class_names = []
316+ for i , val in enumerate (model ["model" ].names ):
317+ class_names .append ((val , model ["model" ].names [val ]))
318+ class_names .sort (key = lambda x : x [0 ])
319+ class_names = [x [1 ] for x in class_names ]
320+
321+ model_artifacts = {
322+ "names" : class_names ,
323+ "yaml" : model ["model" ].yaml ,
324+ "nc" : model ["model" ].nc ,
325+ "args" : {
326+ k : val for k , val in model ["model" ].args .items () if k != "hydra"
327+ },
328+ "ultralytics_version" : ultralytics .__version__ ,
329+ "model_type" : model_type ,
330+ }
331+
332+ with open (model_path + "model_artifacts.json" , "w" ) as fp :
333+ json .dump (model_artifacts , fp )
334+
335+ torch .save (model ["model" ].state_dict (), model_path + "state_dict.pt" )
336+
337+ lista_files = [
338+ "results.csv" ,
339+ "results.png" ,
340+ "model_artifacts.json" ,
341+ "state_dict.pt" ,
342+ ]
343+ with zipfile .ZipFile (model_path + "roboflow_deploy.zip" , "w" ) as zipMe :
344+ for file in lista_files :
345+ zipMe .write (
346+ model_path + file ,
347+ arcname = file ,
348+ compress_type = zipfile .ZIP_DEFLATED ,
349+ )
350+
284351 res = requests .get (
285352 f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } /uploadModel?api_key={ self .__api_key } "
286353 )
@@ -294,13 +361,30 @@ def upload_model(self, model_path: str) -> None:
294361 except Exception as e :
295362 print (f"An error occured when getting the model upload URL: { e } " )
296363 return
297- res = requests .put (res .json ()["url" ], data = open (model_path , "rb" ))
364+
365+ res = requests .put (
366+ res .json ()["url" ], data = open (model_path + "roboflow_deploy.zip" , "rb" )
367+ )
298368 try :
299369 res .raise_for_status ()
300- print ("Model uploaded" )
370+
371+ if self .public :
372+ print (
373+ f"View the status of your deployment at: { APP_URL } /{ self .workspace } /{ self .project } /deploy/{ self .version } "
374+ )
375+ print (
376+ f"Share your model with the world at: { UNIVERSE_URL } /{ self .workspace } /{ self .project } /model/{ self .version } "
377+ )
378+ else :
379+ print (
380+ f"View the status of your deployment at: { APP_URL } /{ self .workspace } /{ self .project } /deploy/{ self .version } "
381+ )
382+
301383 except Exception as e :
302384 print (f"An error occured when uploading the model: { e } " )
303385
386+ # torch.load("state_dict.pt", weights_only=True)
387+
304388 def __download_zip (self , link , location , format ):
305389 """
306390 Download a dataset's zip file from the given URL and save it in the desired location
@@ -408,7 +492,7 @@ def __reformat_yaml(self, location, format):
408492
409493 :return None:
410494 """
411- if format in ["yolov5pytorch" , "yolov7pytorch" ]:
495+ if format in ["yolov5pytorch" , "yolov7pytorch" , "yolov8" ]:
412496 with open (location + "/data.yaml" ) as file :
413497 new_yaml = yaml .safe_load (file )
414498 new_yaml ["train" ] = location + new_yaml ["train" ].lstrip (".." )
0 commit comments