Skip to content

Commit 6e52ae8

Browse files
committed
upload rfdetr
1 parent f40a653 commit 6e52ae8

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

roboflow/util/model_processor.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def _get_processor_function(model_type: str) -> Callable:
2727
"paligemma",
2828
"paligemma2",
2929
"florence-2",
30+
"rfdetr",
3031
]
3132

3233
if not any(supported_model in model_type for supported_model in supported_models):
@@ -60,6 +61,9 @@ def _get_processor_function(model_type: str) -> Callable:
6061
if "yolov12" in model_type:
6162
return _process_yolov12
6263

64+
if "rfdetr" in model_type:
65+
return _process_rfdetr
66+
6367
return _process_yolo
6468

6569

@@ -236,6 +240,36 @@ def _process_yolov12(model_type: str, model_path: str, filename: str) -> str:
236240
return zip_file_name
237241

238242

243+
def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
244+
if not os.path.exists(model_path):
245+
raise FileNotFoundError(f"Model path {model_path} does not exist.")
246+
247+
model_files = os.listdir(model_path)
248+
pt_file = next((f for f in model_files if f.endswith(".pt")), None)
249+
250+
if pt_file is None:
251+
raise RuntimeError("No .pt model file found in the provided path")
252+
253+
# Copy the .pt file to weights.pt if not already named weights.pt
254+
if pt_file != "weights.pt":
255+
shutil.copy(os.path.join(model_path, pt_file), os.path.join(model_path, "weights.pt"))
256+
257+
required_files = ["weights.pt"]
258+
259+
optional_files = ["results.csv", "results.png", "model_artifacts.json"]
260+
261+
zip_file_name = "roboflow_deploy.zip"
262+
with zipfile.ZipFile(os.path.join(model_path, zip_file_name), "w") as zipMe:
263+
for file in required_files:
264+
zipMe.write(os.path.join(model_path, file), arcname=file, compress_type=zipfile.ZIP_DEFLATED)
265+
266+
for file in optional_files:
267+
if os.path.exists(os.path.join(model_path, file)):
268+
zipMe.write(os.path.join(model_path, file), arcname=file, compress_type=zipfile.ZIP_DEFLATED)
269+
270+
return zip_file_name
271+
272+
239273
def _process_huggingface(
240274
model_type: str, model_path: str, filename: str = "fine-tuned-paligemma-3b-pt-224.f16.npz"
241275
) -> str:

0 commit comments

Comments
 (0)