Skip to content

Commit e99dd1d

Browse files
committed
few bugfixes on roboflow-python
1 parent e6bfab7 commit e99dd1d

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

roboflow/util/model_processor.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
114114

115115
print_warn_for_wrong_dependencies_versions([("ultralytics", ">=", "8.3.0")], ask_to_continue=True)
116116

117-
model = torch.load(os.path.join(model_path, filename))
117+
model = torch.load(os.path.join(model_path, filename), map_location="cpu")
118118

119119
if isinstance(model["model"].names, list):
120120
class_names = model["model"].names
@@ -241,14 +241,18 @@ def _process_yolov12(model_type: str, model_path: str, filename: str) -> str:
241241

242242

243243
def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
244+
_supported_types = ["rfdetr-base", "rfdetr-large"]
245+
if model_type not in _supported_types:
246+
raise ValueError(f"Model type {model_type} not supported. Supported types are {_supported_types}")
247+
244248
if not os.path.exists(model_path):
245249
raise FileNotFoundError(f"Model path {model_path} does not exist.")
246250

247251
model_files = os.listdir(model_path)
248-
pt_file = next((f for f in model_files if f.endswith(".pt")), None)
252+
pt_file = next((f for f in model_files if f.endswith(".pt") or f.endswith(".pth")), None)
249253

250254
if pt_file is None:
251-
raise RuntimeError("No .pt model file found in the provided path")
255+
raise RuntimeError("No .pt or .pth model file found in the provided path")
252256

253257
class_names = get_classnames_txt_for_rfdetr(model_path, pt_file)
254258

@@ -258,7 +262,7 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
258262

259263
required_files = ["weights.pt"]
260264

261-
optional_files = ["results.csv", "results.png", "model_artifacts.json"]
265+
optional_files = ["results.csv", "results.png", "model_artifacts.json", "class_names.txt"]
262266

263267
zip_file_name = "roboflow_deploy.zip"
264268
with zipfile.ZipFile(os.path.join(model_path, zip_file_name), "w") as zipMe:
@@ -272,11 +276,13 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
272276
return zip_file_name
273277

274278

275-
def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str) -> list[str]:
279+
def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str):
276280
class_names_path = os.path.join(model_path, "class_names.txt")
277281
if os.path.exists(class_names_path):
278282
return class_names_path
279283

284+
import torch
285+
280286
model = torch.load(os.path.join(model_path, pt_file), map_location="cpu", weights_only=False)
281287
args = vars(model["args"])
282288
if "class_names" in args:

0 commit comments

Comments
 (0)