Skip to content

Commit 9a19ac2

Browse files
committed
add background class
1 parent e99dd1d commit 9a19ac2

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

roboflow/util/model_processor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
279279
def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str):
280280
class_names_path = os.path.join(model_path, "class_names.txt")
281281
if os.path.exists(class_names_path):
282+
maybe_prepend_dummy_class(class_names_path)
282283
return class_names_path
283284

284285
import torch
@@ -289,6 +290,7 @@ def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str):
289290
with open(class_names_path, "w") as f:
290291
for class_name in args["class_names"]:
291292
f.write(class_name + "\n")
293+
maybe_prepend_dummy_class(class_names_path)
292294
return class_names_path
293295

294296
raise FileNotFoundError(
@@ -298,6 +300,16 @@ def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str):
298300
f"please create a class_names.txt file in the model path with the class names in new lines in the order of the classes in the model.\n"
299301
)
300302

303+
def maybe_prepend_dummy_class(class_name_file: str):
304+
with open(class_name_file, "r") as f:
305+
class_names = f.readlines()
306+
307+
dummy_class = "background_class83422\n"
308+
if dummy_class not in class_names:
309+
class_names.insert(0, dummy_class)
310+
with open(class_name_file, "w") as f:
311+
f.writelines(class_names)
312+
301313

302314
def _process_huggingface(
303315
model_type: str, model_path: str, filename: str = "fine-tuned-paligemma-3b-pt-224.f16.npz"

0 commit comments

Comments
 (0)