@@ -250,6 +250,10 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
250250 if pt_file is None :
251251 raise RuntimeError ("No .pt model file found in the provided path" )
252252
253+
254+ class_names = get_classnames_txt_for_rfdetr (model_path , pt_file )
255+
256+
253257 # Copy the .pt file to weights.pt if not already named weights.pt
254258 if pt_file != "weights.pt" :
255259 shutil .copy (os .path .join (model_path , pt_file ), os .path .join (model_path , "weights.pt" ))
@@ -269,6 +273,26 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
269273
270274 return zip_file_name
271275
276+ def get_classnames_txt_for_rfdetr (model_path : str , pt_file : str ) -> list [str ]:
277+ class_names_path = os .path .join (model_path , "class_names.txt" )
278+ if os .path .exists (class_names_path ):
279+ return class_names_path
280+
281+ model = torch .load (os .path .join (model_path , pt_file ), map_location = "cpu" , weights_only = False )
282+ args = vars (model ["args" ])
283+ if "class_names" in args :
284+ with open (class_names_path , "w" ) as f :
285+ for class_name in args ["class_names" ]:
286+ f .write (class_name + "\n " )
287+ return class_names_path
288+
289+ raise FileNotFoundError (
290+ f"No class_names.txt file found in model path { model_path } .\n "
291+ f"This should only happen on rfdetr models trained before version 1.1.0.\n "
292+ f"Please re-train your model with the latest version of the rfdetr library, or\n "
293+ f"please create a class_names.txt file in the model path with the class names in new lines in the order of the classes in the model.\n "
294+ )
295+
272296
273297def _process_huggingface (
274298 model_type : str , model_path : str , filename : str = "fine-tuned-paligemma-3b-pt-224.f16.npz"
0 commit comments