@@ -250,6 +250,10 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
250
250
if pt_file is None :
251
251
raise RuntimeError ("No .pt model file found in the provided path" )
252
252
253
+
254
+ class_names = get_classnames_txt_for_rfdetr (model_path , pt_file )
255
+
256
+
253
257
# Copy the .pt file to weights.pt if not already named weights.pt
254
258
if pt_file != "weights.pt" :
255
259
shutil .copy (os .path .join (model_path , pt_file ), os .path .join (model_path , "weights.pt" ))
@@ -269,6 +273,26 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
269
273
270
274
return zip_file_name
271
275
276
+ def get_classnames_txt_for_rfdetr (model_path : str , pt_file : str ) -> list [str ]:
277
+ class_names_path = os .path .join (model_path , "class_names.txt" )
278
+ if os .path .exists (class_names_path ):
279
+ return class_names_path
280
+
281
+ model = torch .load (os .path .join (model_path , pt_file ), map_location = "cpu" , weights_only = False )
282
+ args = vars (model ["args" ])
283
+ if "class_names" in args :
284
+ with open (class_names_path , "w" ) as f :
285
+ for class_name in args ["class_names" ]:
286
+ f .write (class_name + "\n " )
287
+ return class_names_path
288
+
289
+ raise FileNotFoundError (
290
+ f"No class_names.txt file found in model path { model_path } .\n "
291
+ f"This should only happen on rfdetr models trained before version 1.1.0.\n "
292
+ f"Please re-train your model with the latest version of the rfdetr library, or\n "
293
+ f"please create a class_names.txt file in the model path with the class names in new lines in the order of the classes in the model.\n "
294
+ )
295
+
272
296
273
297
def _process_huggingface (
274
298
model_type : str , model_path : str , filename : str = "fine-tuned-paligemma-3b-pt-224.f16.npz"
0 commit comments