@@ -312,16 +312,18 @@ def deploy(self, model_type: str, model_path: str) -> None:
             model_path (str): File path to model weights to be uploaded
         """
 
-        supported_models = ["yolov8", "yolov5", "yolov7-seg"]
+        supported_models = ["yolov5", "yolov7", "yolov8"]
 
-        if model_type not in supported_models:
+        if not any(
+            supported_model in model_type for supported_model in supported_models
+        ):
             raise (
                 ValueError(
                     f"Model type {model_type} not supported. Supported models are {supported_models}"
                 )
             )
 
-        if model_type == "yolov8":
+        if "yolov8" in model_type:
             try:
                 import torch
                 import ultralytics
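
A minimal sketch of how the relaxed check in the hunk above behaves; the model_type values used here (e.g. "yolov8n-cls", "yolov5s") are illustrative and not taken from this commit:

supported_models = ["yolov5", "yolov7", "yolov8"]

# Substring matching accepts variant names such as classification or
# segmentation checkpoints, while unrelated names still fail the check.
for model_type in ["yolov8n-cls", "yolov5s", "yolov7-seg", "resnet50"]:
    supported = any(
        supported_model in model_type for supported_model in supported_models
    )
    print(model_type, "supported" if supported else "rejected")
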
@@ -335,7 +337,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
335337 [("ultralytics" , "<=" , "8.0.20" )]
336338 )
337339
338- elif model_type in [ "yolov5" , "yolov7-seg" ] :
340+ elif "yolov5" in model_type or "yolov7" in model_type :
339341 try :
340342 import torch
341343 except ImportError as e :
@@ -345,22 +347,31 @@ def deploy(self, model_type: str, model_path: str) -> None:
 
         model = torch.load(os.path.join(model_path, "weights/best.pt"))
 
-        class_names = []
-        for i, val in enumerate(model["model"].names):
-            class_names.append((val, model["model"].names[val]))
+        if isinstance(model["model"].names, list):
+            class_names = model["model"].names
+        else:
+            class_names = []
+            for i, val in enumerate(model["model"].names):
+                class_names.append((val, model["model"].names[val]))
         class_names.sort(key=lambda x: x[0])
         class_names = [x[1] for x in class_names]
 
-        if model_type == "yolov8":
+        if "yolov8" in model_type:
             # try except for backwards compatibility with older versions of ultralytics
+            if "-cls" in model_type:
+                nc = model["model"].yaml["nc"]
+                args = model["train_args"]
+            else:
+                nc = model["model"].nc
+                args = model["model"].args
             try:
                 model_artifacts = {
                     "names": class_names,
                     "yaml": model["model"].yaml,
-                    "nc": model["model"].nc,
+                    "nc": nc,
                     "args": {
                         k: val
-                        for k, val in model["model"].args.items()
+                        for k, val in args.items()
                         if ((k == "model") or (k == "imgsz") or (k == "batch"))
                     },
                     "ultralytics_version": ultralytics.__version__,
@@ -370,33 +381,39 @@ def deploy(self, model_type: str, model_path: str) -> None:
                 model_artifacts = {
                     "names": class_names,
                     "yaml": model["model"].yaml,
-                    "nc": model["model"].nc,
+                    "nc": nc,
                     "args": {
                         k: val
-                        for k, val in model["model"].args.__dict__.items()
+                        for k, val in args.__dict__.items()
                         if ((k == "model") or (k == "imgsz") or (k == "batch"))
                     },
                     "ultralytics_version": ultralytics.__version__,
                     "model_type": model_type,
                 }
-        elif model_type in ["yolov5", "yolov7-seg"]:
+        elif "yolov5" in model_type or "yolov7" in model_type:
             # parse from yaml for yolov5
 
             with open(os.path.join(model_path, "opt.yaml"), "r") as stream:
                 opts = yaml.safe_load(stream)
 
             model_artifacts = {
                 "names": class_names,
-                "yaml": model["model"].yaml,
                 "nc": model["model"].nc,
-                "args": {"imgsz": opts["imgsz"], "batch": opts["batch_size"]},
+                "args": {
+                    "imgsz": opts["imgsz"] if "imgsz" in opts else opts["img_size"],
+                    "batch": opts["batch_size"],
+                },
                 "model_type": model_type,
             }
+            if hasattr(model["model"], "yaml"):
+                model_artifacts["yaml"] = model["model"].yaml
 
-        with open(model_path + "model_artifacts.json", "w") as fp:
+        with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
             json.dump(model_artifacts, fp)
 
-        torch.save(model["model"].state_dict(), model_path + "state_dict.pt")
+        torch.save(
+            model["model"].state_dict(), os.path.join(model_path, "state_dict.pt")
+        )
 
         lista_files = [
             "results.csv",
@@ -405,11 +422,13 @@ def deploy(self, model_type: str, model_path: str) -> None:
405422 "state_dict.pt" ,
406423 ]
407424
408- with zipfile .ZipFile (model_path + "roboflow_deploy.zip" , "w" ) as zipMe :
425+ with zipfile .ZipFile (
426+ os .path .join (model_path , "roboflow_deploy.zip" ), "w"
427+ ) as zipMe :
409428 for file in lista_files :
410- if os .path .exists (model_path + file ):
429+ if os .path .exists (os . path . join ( model_path , file ) ):
411430 zipMe .write (
412- model_path + file ,
431+ os . path . join ( model_path , file ) ,
413432 arcname = file ,
414433 compress_type = zipfile .ZIP_DEFLATED ,
415434 )
@@ -422,7 +441,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
                         )
 
         res = requests.get(
-            f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}"
+            f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}&modelType={model_type}"
         )
         try:
             if res.status_code == 429:
@@ -437,7 +456,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
 
         res = requests.put(
             res.json()["url"],
-            data=open(os.path.join(model_path + "roboflow_deploy.zip"), "rb"),
+            data=open(os.path.join(model_path, "roboflow_deploy.zip"), "rb"),
         )
         try:
             res.raise_for_status()
@@ -590,7 +609,8 @@ def callback(content: dict) -> dict:
                 pass
             return content
 
-        amend_data_yaml(path=data_path, callback=callback)
+        if os.path.exists(data_path):
+            amend_data_yaml(path=data_path, callback=callback)
 
     def __str__(self):
         """string representation of version object."""