@@ -114,7 +114,7 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
114114
115115 print_warn_for_wrong_dependencies_versions ([("ultralytics" , ">=" , "8.3.0" )], ask_to_continue = True )
116116
117- model = torch .load (os .path .join (model_path , filename ))
117+ model = torch .load (os .path .join (model_path , filename ), map_location = "cpu" )
118118
119119 if isinstance (model ["model" ].names , list ):
120120 class_names = model ["model" ].names
@@ -241,14 +241,18 @@ def _process_yolov12(model_type: str, model_path: str, filename: str) -> str:
241241
242242
243243def _process_rfdetr (model_type : str , model_path : str , filename : str ) -> str :
244+ _supported_types = ["rfdetr-base" , "rfdetr-large" ]
245+ if model_type not in _supported_types :
246+ raise ValueError (f"Model type { model_type } not supported. Supported types are { _supported_types } " )
247+
244248 if not os .path .exists (model_path ):
245249 raise FileNotFoundError (f"Model path { model_path } does not exist." )
246250
247251 model_files = os .listdir (model_path )
248- pt_file = next ((f for f in model_files if f .endswith (".pt" )), None )
252+ pt_file = next ((f for f in model_files if f .endswith (".pt" ) or f . endswith ( ".pth" ) ), None )
249253
250254 if pt_file is None :
251- raise RuntimeError ("No .pt model file found in the provided path" )
255+ raise RuntimeError ("No .pt or .pth model file found in the provided path" )
252256
253257 class_names = get_classnames_txt_for_rfdetr (model_path , pt_file )
254258
@@ -258,7 +262,7 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
258262
259263 required_files = ["weights.pt" ]
260264
261- optional_files = ["results.csv" , "results.png" , "model_artifacts.json" ]
265+ optional_files = ["results.csv" , "results.png" , "model_artifacts.json" , "class_names.txt" ]
262266
263267 zip_file_name = "roboflow_deploy.zip"
264268 with zipfile .ZipFile (os .path .join (model_path , zip_file_name ), "w" ) as zipMe :
@@ -272,11 +276,13 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
272276 return zip_file_name
273277
274278
275- def get_classnames_txt_for_rfdetr (model_path : str , pt_file : str ) -> list [ str ] :
279+ def get_classnames_txt_for_rfdetr (model_path : str , pt_file : str ):
276280 class_names_path = os .path .join (model_path , "class_names.txt" )
277281 if os .path .exists (class_names_path ):
278282 return class_names_path
279283
284+ import torch
285+
280286 model = torch .load (os .path .join (model_path , pt_file ), map_location = "cpu" , weights_only = False )
281287 args = vars (model ["args" ])
282288 if "class_names" in args :
0 commit comments