@@ -279,6 +279,7 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
279279def get_classnames_txt_for_rfdetr (model_path : str , pt_file : str ):
280280 class_names_path = os .path .join (model_path , "class_names.txt" )
281281 if os .path .exists (class_names_path ):
282+ maybe_prepend_dummy_class (class_names_path )
282283 return class_names_path
283284
284285 import torch
@@ -289,6 +290,7 @@ def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str):
289290 with open (class_names_path , "w" ) as f :
290291 for class_name in args ["class_names" ]:
291292 f .write (class_name + "\n " )
293+ maybe_prepend_dummy_class (class_names_path )
292294 return class_names_path
293295
294296 raise FileNotFoundError (
@@ -298,6 +300,16 @@ def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str):
298300 f"please create a class_names.txt file in the model path with the class names in new lines in the order of the classes in the model.\n "
299301 )
300302
303+ def maybe_prepend_dummy_class (class_name_file : str ):
304+ with open (class_name_file , "r" ) as f :
305+ class_names = f .readlines ()
306+
307+ dummy_class = "background_class83422\n "
308+ if dummy_class not in class_names :
309+ class_names .insert (0 , dummy_class )
310+ with open (class_name_file , "w" ) as f :
311+ f .writelines (class_names )
312+
301313
302314def _process_huggingface (
303315 model_type : str , model_path : str , filename : str = "fine-tuned-paligemma-3b-pt-224.f16.npz"
0 commit comments