@@ -114,7 +114,7 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
114
114
115
115
print_warn_for_wrong_dependencies_versions ([("ultralytics" , ">=" , "8.3.0" )], ask_to_continue = True )
116
116
117
- model = torch .load (os .path .join (model_path , filename ))
117
+ model = torch .load (os .path .join (model_path , filename ), map_location = "cpu" )
118
118
119
119
if isinstance (model ["model" ].names , list ):
120
120
class_names = model ["model" ].names
@@ -241,14 +241,18 @@ def _process_yolov12(model_type: str, model_path: str, filename: str) -> str:
241
241
242
242
243
243
def _process_rfdetr (model_type : str , model_path : str , filename : str ) -> str :
244
+ _supported_types = ["rfdetr-base" , "rfdetr-large" ]
245
+ if model_type not in _supported_types :
246
+ raise ValueError (f"Model type { model_type } not supported. Supported types are { _supported_types } " )
247
+
244
248
if not os .path .exists (model_path ):
245
249
raise FileNotFoundError (f"Model path { model_path } does not exist." )
246
250
247
251
model_files = os .listdir (model_path )
248
- pt_file = next ((f for f in model_files if f .endswith (".pt" )), None )
252
+ pt_file = next ((f for f in model_files if f .endswith (".pt" ) or f . endswith ( ".pth" ) ), None )
249
253
250
254
if pt_file is None :
251
- raise RuntimeError ("No .pt model file found in the provided path" )
255
+ raise RuntimeError ("No .pt or .pth model file found in the provided path" )
252
256
253
257
class_names = get_classnames_txt_for_rfdetr (model_path , pt_file )
254
258
@@ -258,7 +262,7 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
258
262
259
263
required_files = ["weights.pt" ]
260
264
261
- optional_files = ["results.csv" , "results.png" , "model_artifacts.json" ]
265
+ optional_files = ["results.csv" , "results.png" , "model_artifacts.json" , "class_names.txt" ]
262
266
263
267
zip_file_name = "roboflow_deploy.zip"
264
268
with zipfile .ZipFile (os .path .join (model_path , zip_file_name ), "w" ) as zipMe :
@@ -272,11 +276,13 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
272
276
return zip_file_name
273
277
274
278
275
- def get_classnames_txt_for_rfdetr (model_path : str , pt_file : str ) -> list [ str ] :
279
+ def get_classnames_txt_for_rfdetr (model_path : str , pt_file : str ):
276
280
class_names_path = os .path .join (model_path , "class_names.txt" )
277
281
if os .path .exists (class_names_path ):
278
282
return class_names_path
279
283
284
+ import torch
285
+
280
286
model = torch .load (os .path .join (model_path , pt_file ), map_location = "cpu" , weights_only = False )
281
287
args = vars (model ["args" ])
282
288
if "class_names" in args :
0 commit comments