66
77import pandas as pd
88from dotenv import load_dotenv
9+ from extraction .evaluation .benchmark .metrics import OverallMetrics
10+ from extraction .evaluation .evaluation_dataclasses import Metrics
11+ from extraction .evaluation .utility import evaluate_single
912from pydantic import TypeAdapter
1013
1114from src .page_classes import PageClasses
1215from src .page_structure import ProcessorDocumentEntities
13- from src .schemas import DocumentGroundTruth
16+ from src .schemas import DocumentGroundTruth , DocumentPage
1417
load_dotenv()
# Default to "false" so a missing MLFLOW_TRACKING variable disables tracking
# instead of crashing with AttributeError on None.lower().
mlflow_tracking = os.getenv("MLFLOW_TRACKING", "false").lower() == "true"
@@ -35,35 +38,37 @@ def load_ground_truth(ground_truth_path: Path) -> list[DocumentGroundTruth] | No
3538 return None
3639
3740
38- def compute_confusion_stats (
39- predictions : list [DocumentGroundTruth ], ground_truth : list [DocumentGroundTruth ]
40- ) -> tuple [dict , int , int ]:
41- """Computes confusion matrix entries, total pages and files processed for evaluating classification results."""
42- stats = {label : {"true_positives" : 0 , "false_negatives" : 0 , "false_positives" : 0 } for label in LABELS }
41+ def groundtruth_doc_to_pages (documents : list [DocumentGroundTruth ]) -> dict [str , DocumentPage ]:
42+ """Convert list of documents to list of keyed pages.
4343
44- pred_keys = set ([ pred . filename for pred in predictions ])
45- gt_keys = set ([ pred . filename for pred in ground_truth ])
44+ Args:
45+ documents (list[DocumentGroundTruth]): Documents with pages to flatten
4646
47- # Evaluate on the intersection so we don't crash when pages are missing
48- common_keys = pred_keys & gt_keys
47+ Returns:
48+ dict[str, DocumentPage]: Keyed pages.
49+ """
50+ return {f"{ doc .filename } -{ page .page } " : page for doc in documents for page in doc .pages }
4951
50- missing_in_pred = gt_keys - pred_keys
51- missing_in_gt = pred_keys - gt_keys
52- if missing_in_pred :
53- logger .info (f"{ len (missing_in_pred )} GT pages have no prediction (e.g., { next (iter (missing_in_pred ))} )." )
54- if missing_in_gt :
55- logger .info (f"{ len (missing_in_gt )} predicted pages missing in GT (e.g., { next (iter (missing_in_gt ))} )." )
5652
57- # TODO from here - finish evaluation
58- total_pages = len (common_keys )
59- total_files = len ({fname for (fname , _page ) in common_keys })
53+ # def compute_title_metric(title_gt: str, title_pred: str) -> bool | None:
54+ # if title_gt is None:
55+ # return None
56+ # else:
57+ # return title_gt.lower().strip() == title_pred.lower().strip()
58+
59+
def compute_classification_stats(
    predictions: dict[str, DocumentPage], ground_truth: dict[str, DocumentPage]
) -> dict:
    """Compute per-label confusion counts over pages present in both inputs.

    Args:
        predictions (dict[str, DocumentPage]): Predicted pages keyed by "<filename>-<page>".
        ground_truth (dict[str, DocumentPage]): Ground-truth pages keyed the same way.

    Returns:
        dict: Per label, a dict with "true_positives", "false_negatives" and
            "false_positives" counts.
    """
    stats = {label: {"true_positives": 0, "false_negatives": 0, "false_positives": 0} for label in LABELS}
    common_keys = predictions.keys() & ground_truth.keys()

    for key in common_keys:
        # Keys come from the intersection, so direct indexing cannot fail
        # (the previous .get(key, {}) fallback would have crashed on
        # {}.classification anyway).
        pred_page = predictions[key]
        gt_page = ground_truth[key]
        for label in LABELS:
            pred = int(pred_page.classification.get(label, 0))
            gt = int(gt_page.classification.get(label, 0))

            if gt == 1 and pred == 1:
                stats[label]["true_positives"] += 1
            elif gt == 1 and pred == 0:
                stats[label]["false_negatives"] += 1
            elif gt == 0 and pred == 1:
                stats[label]["false_positives"] += 1

    return stats
7681
7782
78- def save_confusion_stats (stats : dict , output_dir : Path ) -> Path :
79- """Saves confusion matrix to output directory."""
80- csv_path = output_dir / "evaluation_metrics.csv"
def compute_title_stats(
    predictions: dict[str, DocumentPage], ground_truth: dict[str, DocumentPage]
) -> dict:
    """Compute confusion counts for title extraction on pages present in both inputs.

    A page is a true positive when both titles are present and exactly equal.
    Pages where neither side has a title are true negatives and are not counted.
    A mismatch with both titles present counts as both a false negative and a
    false positive.

    Args:
        predictions (dict[str, DocumentPage]): Predicted pages keyed by "<filename>-<page>".
        ground_truth (dict[str, DocumentPage]): Ground-truth pages keyed the same way.

    Returns:
        dict: {"title": {"true_positives": ..., "false_negatives": ..., "false_positives": ...}}.
    """
    stats = {"true_positives": 0, "false_negatives": 0, "false_positives": 0}
    common_keys = predictions.keys() & ground_truth.keys()

    for key in common_keys:
        pred_title = predictions[key].title
        gt_title = ground_truth[key].title
        logger.info(f"{key}: {gt_title} == {pred_title}")
        # NOTE(review): comparison is exact; consider lower().strip()
        # normalization (see the commented-out compute_title_metric above).
        if pred_title and gt_title and pred_title == gt_title:
            stats["true_positives"] += 1
        else:
            # Only penalize the side that actually has a title; previously a
            # page with no title on either side was counted as FN *and* FP.
            if gt_title:
                stats["false_negatives"] += 1
            if pred_title:
                stats["false_positives"] += 1

    return {"title": stats}
100+
101+
def compute_stats(
    predictions: list[DocumentGroundTruth], ground_truths: list[DocumentGroundTruth]
) -> tuple[dict, dict]:
    """Compute classification and title confusion statistics.

    Pages are matched by "<filename>-<page>" key; only the intersection of
    predicted and ground-truth pages is evaluated, and missing pages on either
    side are logged.

    Args:
        predictions (list[DocumentGroundTruth]): Predicted documents.
        ground_truths (list[DocumentGroundTruth]): Ground-truth documents.

    Returns:
        tuple[dict, dict]: Per-label classification stats and title stats.
    """
    pred_keyed = groundtruth_doc_to_pages(predictions)
    gt_keyed = groundtruth_doc_to_pages(ground_truths)

    # Evaluate on the intersection so we don't crash when pages are missing
    pred_keys, gt_keys = set(pred_keyed.keys()), set(gt_keyed.keys())

    missing_in_pred = gt_keys - pred_keys
    missing_in_gt = pred_keys - gt_keys
    if missing_in_pred:
        logger.info(f"{len(missing_in_pred)} GT pages have no prediction (e.g., {next(iter(missing_in_pred))}).")
    if missing_in_gt:
        logger.info(f"{len(missing_in_gt)} predicted pages missing in GT (e.g., {next(iter(missing_in_gt))}).")

    classification_stats = compute_classification_stats(pred_keyed, gt_keyed)
    title_stats = compute_title_stats(pred_keyed, gt_keyed)

    return classification_stats, title_stats
123+
124+
def save_stats(stats_classification: dict, csv_path: Path) -> Path:
    """Write confusion counts to a CSV file.

    Args:
        stats_classification (dict): Per-label confusion counts, e.g.
            {"label": {"true_positives": 1, "false_negatives": 0, "false_positives": 2}}.
            (Annotation fixed: this is a dict, not a list — it is iterated with .items().)
        csv_path (Path): Destination CSV file.

    Returns:
        Path: The path the CSV was written to.
    """
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(
            [
                "Label",
                "True_Positives",
                "False_Negatives",
                "False_Positives",
            ]
        )
        for label, s in stats_classification.items():
            writer.writerow(
                [
                    label,
                    s["true_positives"],
                    s["false_negatives"],
                    s["false_positives"],
                ]
            )
    return csv_path
102147
103148
104- def log_metrics_to_mlflow (stats : dict , total_files : int , total_pages : int ) -> None :
def log_metrics_to_mlflow(stats_classification: dict, stats_title: dict) -> None:
    """Calculate and log F1, precision and recall to MLflow.

    Args:
        stats_classification (dict): Per-label confusion counts.
        stats_title (dict): {"title": confusion counts} for title extraction.

    Returns:
        None. No-op when MLflow tracking is disabled.
    """
    if not mlflow_tracking:
        return None

    # Log metrics for title extraction.
    # Unpack in the same order as the label list: tp, FN, FP — the previous
    # `tp, fp, fn = ...` swapped FN/FP and therefore precision and recall.
    tp, fn, fp = [stats_title["title"][label] for label in ["true_positives", "false_negatives", "false_positives"]]
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) else 0.0
    mlflow.log_metric("title/f1", f1)
    mlflow.log_metric("title/precision", precision)
    mlflow.log_metric("title/recall", recall)

    logger.info(f"Title: F1={f1:.2%}, Precision={precision:.2%}, Recall={recall:.2%}")

    # Log metrics for classification output
    precisions = []
    recalls = []
    f1_scores = []
    for label, s in stats_classification.items():
        tp, fn, fp = s["true_positives"], s["false_negatives"], s["false_positives"]
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) else 0.0

        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)

        mlflow.log_metric(f"classification/{label}_f1", f1)
        mlflow.log_metric(f"classification/{label}_precision", precision)
        mlflow.log_metric(f"classification/{label}_recall", recall)

        logger.info(f"{label}: F1={f1:.2%}, Precision={precision:.2%}, Recall={recall:.2%}")

    macro_precision = sum(precisions) / len(precisions) if precisions else 0.0
    macro_recall = sum(recalls) / len(recalls) if recalls else 0.0
    macro_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0.0

    mlflow.log_metric("classification/macro_precision", macro_precision)
    mlflow.log_metric("classification/macro_recall", macro_recall)
    # Metric name fixed: was misspelled "classification/marco_f1".
    mlflow.log_metric("classification/macro_f1", macro_f1)

    logger.info(f"Classification Macro: F1={macro_f1:.2%}, Precision={macro_precision:.2%}, Recall={macro_recall:.2%}")
217194
218195
def evaluate_results(
    predictions: list[ProcessorDocumentEntities], ground_truth_path: Path, output_dir: Path = Path("evaluation")
) -> tuple[Path, Path]:
    """Evaluate classification and title predictions against ground truth.

    Args:
        predictions (list[ProcessorDocumentEntities]): Processed documents to evaluate.
        ground_truth_path (Path): Path to the ground-truth file.
        output_dir (Path): Directory for the metric CSVs (created if missing).

    Returns:
        tuple[Path, Path]: Paths of the classification and title metric CSVs.

    Raises:
        ValueError: If the ground truth could not be loaded.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    gt_list = load_ground_truth(ground_truth_path)
    if gt_list is None:
        # load_ground_truth returns None on failure; fail loudly here instead
        # of crashing with a TypeError inside compute_stats.
        raise ValueError(f"Could not load ground truth from {ground_truth_path}.")
    pred_list = [pred.to_ground_truth() for pred in predictions]

    stats_classification, stats_title = compute_stats(pred_list, gt_list)
    stats_classification_path = save_stats(stats_classification, output_dir / "evaluation_metrics_classification.csv")
    stats_title_path = save_stats(stats_title, output_dir / "evaluation_metrics_title.csv")

    if mlflow_tracking:
        log_metrics_to_mlflow(stats_classification, stats_title)
        mlflow.log_artifact(str(stats_classification_path))
        mlflow.log_artifact(str(stats_title_path))

    # Return the two CSV paths, matching the declared tuple[Path, Path] return
    # type (previously the stats dict was returned as the first element).
    return stats_classification_path, stats_title_path