33import logging
44import os
55from pathlib import Path
6- from typing import Any
76
87import pandas as pd
98from dotenv import load_dotenv
9+ from pydantic import TypeAdapter
1010
1111from src .page_classes import PageClasses
12+ from src .page_structure import ProcessorDocumentEntities
13+ from src .schemas import DocumentGroundTruth
1214
# Load .env so MLFLOW_TRACKING can be configured per-environment.
load_dotenv()
# Default to "" so a missing MLFLOW_TRACKING variable disables tracking
# instead of raising AttributeError on None.lower().
mlflow_tracking = os.getenv("MLFLOW_TRACKING", "").lower() == "true"
# Flat list of label strings, taken from the PageClasses enum values.
LABELS = [page_class.value for page_class in PageClasses]
2224
2325
24- def load_predictions (predictions : list [dict [str , Any ]]) -> dict [tuple [str , int ], dict [str , int ]]:
25- """Normalizes predictions list.
26-
27- { (filename, page_number): classification_dict }
28- Works for both model predictions and ground-truth lists.
29- """
30- pred_dict : dict [tuple [str , int ], dict [str , int ]] = {}
31-
32- for entry in predictions :
33- filename = entry .get ("filename" )
34- pages = entry .get ("pages" , [])
35-
36- for page_entry in pages :
37- page_number = page_entry .get ("page" )
38- classification = page_entry .get ("classification" )
39-
40- key = (filename , page_number )
41- if key in pred_dict :
42- logger .warning (f"Duplicate entry for { key } ; overwriting previous value." )
43- pred_dict [key ] = classification
44- return pred_dict
45-
46-
def load_ground_truth(ground_truth_path: Path) -> list[DocumentGroundTruth] | None:
    """Load and schema-validate ground-truth records from a JSON file.

    Args:
        ground_truth_path: Path to a JSON file containing a list of
            ground-truth documents.

    Returns:
        The validated list of DocumentGroundTruth entries, or None when the
        file cannot be read, parsed, or validated (the error is logged rather
        than raised).
    """
    try:
        with open(ground_truth_path) as f:
            raw_records = json.load(f)
        # Validate the whole payload against the ground-truth schema in one pass.
        return TypeAdapter(list[DocumentGroundTruth]).validate_python(raw_records)
    except Exception as e:  # best-effort loader: any failure means "no ground truth"
        logger.error(f"Invalid ground truth path or JSON: {e}")
        return None
5636
5737
58- def compute_confusion_stats (predictions : dict , ground_truth : dict ) -> tuple [dict , int , int ]:
38+ def compute_confusion_stats (
39+ predictions : list [DocumentGroundTruth ], ground_truth : list [DocumentGroundTruth ]
40+ ) -> tuple [dict , int , int ]:
5941 """Computes confusion matrix entries, total pages and files processed for evaluating classification results."""
6042 stats = {label : {"true_positives" : 0 , "false_negatives" : 0 , "false_positives" : 0 } for label in LABELS }
6143
62- pred_keys = set (predictions . keys () )
63- gt_keys = set (ground_truth . keys () )
44+ pred_keys = set ([ pred . filename for pred in predictions ] )
45+ gt_keys = set ([ pred . filename for pred in ground_truth ] )
6446
6547 # Evaluate on the intersection so we don't crash when pages are missing
6648 common_keys = pred_keys & gt_keys
@@ -72,6 +54,7 @@ def compute_confusion_stats(predictions: dict, ground_truth: dict) -> tuple[dict
7254 if missing_in_gt :
7355 logger .info (f"{ len (missing_in_gt )} predicted pages missing in GT (e.g., { next (iter (missing_in_gt ))} )." )
7456
57+ # TODO from here - finish evaluation
7558 total_pages = len (common_keys )
7659 total_files = len ({fname for (fname , _page ) in common_keys })
7760
@@ -234,24 +217,23 @@ def get_active_labels(row, suffix):
234217
235218
def evaluate_results(
    predictions: list[ProcessorDocumentEntities], ground_truth_path: Path, output_dir: Path = Path("evaluation")
) -> dict | None:
    """Evaluate classification predictions against ground truth.

    Args:
        predictions: Processor outputs; each is converted to the ground-truth
            schema via its ``to_ground_truth()`` method before comparison.
        ground_truth_path: Path to the ground-truth JSON file.
        output_dir: Directory where evaluation artifacts are written
            (created if missing).

    Returns:
        The per-label confusion-stats dict, or None when the ground truth
        could not be loaded.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    gt_list = load_ground_truth(ground_truth_path)
    # load_ground_truth returns None on a bad path / invalid JSON; bail out
    # early instead of handing None to compute_confusion_stats.
    if gt_list is None:
        return None

    pred_list = [pred.to_ground_truth() for pred in predictions]

    # NOTE(review): compute_confusion_stats's docstring describes
    # (stats, total pages, total files) — confirm this unpack order matches
    # its actual return statement.
    stats, total_files, total_pages = compute_confusion_stats(pred_list, gt_list)
    stats_path = save_confusion_stats(stats, output_dir)

    if mlflow_tracking:
        log_metrics_to_mlflow(stats, total_files, total_pages)
        mlflow.log_artifact(str(stats_path))

    # TODO: re-enable page-level comparison once it supports the list-based types.
    # comparison_data = create_page_comparison(pred_dict, gt_dict, output_dir)
    # save_misclassifications(comparison_data, output_dir)

    return stats
0 commit comments