@@ -279,6 +279,7 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
279
279
def get_classnames_txt_for_rfdetr (model_path : str , pt_file : str ):
280
280
class_names_path = os .path .join (model_path , "class_names.txt" )
281
281
if os .path .exists (class_names_path ):
282
+ maybe_prepend_dummy_class (class_names_path )
282
283
return class_names_path
283
284
284
285
import torch
@@ -289,6 +290,7 @@ def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str):
289
290
with open (class_names_path , "w" ) as f :
290
291
for class_name in args ["class_names" ]:
291
292
f .write (class_name + "\n " )
293
+ maybe_prepend_dummy_class (class_names_path )
292
294
return class_names_path
293
295
294
296
raise FileNotFoundError (
@@ -298,6 +300,16 @@ def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str):
298
300
f"please create a class_names.txt file in the model path with the class names in new lines in the order of the classes in the model.\n "
299
301
)
300
302
303
+ def maybe_prepend_dummy_class (class_name_file : str ):
304
+ with open (class_name_file , "r" ) as f :
305
+ class_names = f .readlines ()
306
+
307
+ dummy_class = "background_class83422\n "
308
+ if dummy_class not in class_names :
309
+ class_names .insert (0 , dummy_class )
310
+ with open (class_name_file , "w" ) as f :
311
+ f .writelines (class_names )
312
+
301
313
302
314
def _process_huggingface (
303
315
model_type : str , model_path : str , filename : str = "fine-tuned-paligemma-3b-pt-224.f16.npz"
0 commit comments