Skip to content

Commit 6ded732

Browse files
baseline for page evaluation and cleaning of metrics section
1 parent d95fc4a commit 6ded732

File tree

2 files changed

+102
-127
lines changed

2 files changed

+102
-127
lines changed

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def main(
302302

303303
# Start MLFlow tracking
304304
if mlflow_tracking:
305-
setup_mlflow(input_path, ground_truth_path, model_path, matching_params, classifier_name)
305+
setup_mlflow(input_path, matching_params, ground_truth_path, model_path, classifier_name)
306306

307307
# Process pages
308308
pdf_files = get_pdf_files(input_path)

src/evaluation.py

Lines changed: 101 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66

77
import pandas as pd
88
from dotenv import load_dotenv
9+
from extraction.evaluation.benchmark.metrics import OverallMetrics
10+
from extraction.evaluation.evaluation_dataclasses import Metrics
11+
from extraction.evaluation.utility import evaluate_single
912
from pydantic import TypeAdapter
1013

1114
from src.page_classes import PageClasses
1215
from src.page_structure import ProcessorDocumentEntities
13-
from src.schemas import DocumentGroundTruth
16+
from src.schemas import DocumentGroundTruth, DocumentPage
1417

1518
load_dotenv()
1619
mlflow_tracking = os.getenv("MLFLOW_TRACKING").lower() == "true"
@@ -35,35 +38,37 @@ def load_ground_truth(ground_truth_path: Path) -> list[DocumentGroundTruth] | No
3538
return None
3639

3740

38-
def compute_confusion_stats(
39-
predictions: list[DocumentGroundTruth], ground_truth: list[DocumentGroundTruth]
40-
) -> tuple[dict, int, int]:
41-
"""Computes confusion matrix entries, total pages and files processed for evaluating classification results."""
42-
stats = {label: {"true_positives": 0, "false_negatives": 0, "false_positives": 0} for label in LABELS}
41+
def groundtruth_doc_to_pages(documents: "list[DocumentGroundTruth]") -> "dict[str, DocumentPage]":
    """Flatten documents into a dict of pages keyed by "<filename>-<page number>".

    Args:
        documents (list[DocumentGroundTruth]): Documents whose pages are flattened.

    Returns:
        dict[str, DocumentPage]: Pages keyed by "<filename>-<page>".
    """
    # Keys combine filename and page number so pages from different files never collide.
    return {f"{doc.filename}-{page.page}": page for doc in documents for page in doc.pages}
4951

50-
missing_in_pred = gt_keys - pred_keys
51-
missing_in_gt = pred_keys - gt_keys
52-
if missing_in_pred:
53-
logger.info(f"{len(missing_in_pred)} GT pages have no prediction (e.g., {next(iter(missing_in_pred))}).")
54-
if missing_in_gt:
55-
logger.info(f"{len(missing_in_gt)} predicted pages missing in GT (e.g., {next(iter(missing_in_gt))}).")
5652

57-
# TODO from here - finish evaluation
58-
total_pages = len(common_keys)
59-
total_files = len({fname for (fname, _page) in common_keys})
53+
# def compute_title_metric(title_gt: str, title_pred: str) -> bool | None:
54+
# if title_gt is None:
55+
# return None
56+
# else:
57+
# return title_gt.lower().strip() == title_pred.lower().strip()
58+
59+
60+
def compute_classification_stats(
61+
predictions: dict[str, DocumentGroundTruth], ground_truth: dict[str, DocumentGroundTruth]
62+
) -> dict:
63+
stats = {label: {"true_positives": 0, "false_negatives": 0, "false_positives": 0} for label in LABELS}
64+
common_keys = predictions.keys() & ground_truth.keys()
6065

6166
for key in common_keys:
6267
pred_page = predictions.get(key, {})
6368
gt_page = ground_truth.get(key, {})
6469
for label in LABELS:
65-
pred = int(pred_page.get(label, 0))
66-
gt = int(gt_page.get(label, 0))
70+
pred = int(pred_page.classification.get(label, 0))
71+
gt = int(gt_page.classification.get(label, 0))
6772

6873
if gt == 1 and pred == 1:
6974
stats[label]["true_positives"] += 1
@@ -72,24 +77,64 @@ def compute_confusion_stats(
7277
elif gt == 0 and pred == 1:
7378
stats[label]["false_positives"] += 1
7479

75-
return stats, total_files, total_pages
80+
return stats
7681

7782

78-
def save_confusion_stats(stats: dict, output_dir: Path) -> Path:
79-
"""Saves confusion matrix to output directory."""
80-
csv_path = output_dir / "evaluation_metrics.csv"
83+
def compute_title_stats(
    predictions: "dict[str, DocumentPage]", ground_truth: "dict[str, DocumentPage]"
) -> dict:
    """Compute confusion counts for title extraction over pages present in both mappings.

    Pages where neither ground truth nor prediction carries a title are true
    negatives and are skipped, so they no longer inflate both error counts.

    Args:
        predictions: Predicted pages keyed by "<filename>-<page>".
        ground_truth: Ground-truth pages under the same keys.

    Returns:
        dict: {"title": {"true_positives": int, "false_negatives": int, "false_positives": int}}
    """
    stats = {"true_positives": 0, "false_negatives": 0, "false_positives": 0}
    common_keys = predictions.keys() & ground_truth.keys()

    for key in common_keys:
        pred_title = predictions[key].title
        gt_title = ground_truth[key].title
        # Per-page trace is debug-level to avoid flooding the log on large runs.
        logger.debug(f"{key}: {gt_title} == {pred_title}")
        if pred_title is None and gt_title is None:
            continue  # true negative: no title expected, none predicted
        # NOTE(review): comparison is exact; the commented-out compute_title_metric
        # suggests a case/whitespace-insensitive match may be intended — confirm.
        if pred_title and gt_title and pred_title == gt_title:
            stats["true_positives"] += 1
        elif gt_title and not pred_title:
            stats["false_negatives"] += 1  # missed a title that exists
        elif pred_title and not gt_title:
            stats["false_positives"] += 1  # predicted a title where GT has none
        else:
            # Both present but different: the true title was missed AND a wrong one emitted.
            stats["false_negatives"] += 1
            stats["false_positives"] += 1

    return {"title": stats}
100+
101+
102+
def compute_stats(
    predictions: "list[DocumentGroundTruth]", ground_truths: "list[DocumentGroundTruth]"
) -> tuple[dict, dict]:
    """Compute confusion-matrix stats for page classification and title extraction.

    Documents are flattened to pages keyed by "<filename>-<page>" and both
    metrics are evaluated on the key intersection; pages missing on either
    side are logged instead of crashing the evaluation.

    Args:
        predictions: Predicted documents (converted to the ground-truth schema).
        ground_truths: Ground-truth documents.

    Returns:
        tuple[dict, dict]: (per-label classification stats, title stats).
    """
    pred_keyed = groundtruth_doc_to_pages(predictions)
    gt_keyed = groundtruth_doc_to_pages(ground_truths)

    # Evaluate on the intersection so we don't crash when pages are missing
    pred_keys, gt_keys = set(pred_keyed.keys()), set(gt_keyed.keys())

    missing_in_pred = gt_keys - pred_keys
    missing_in_gt = pred_keys - gt_keys
    if missing_in_pred:
        logger.info(f"{len(missing_in_pred)} GT pages have no prediction (e.g., {next(iter(missing_in_pred))}).")
    if missing_in_gt:
        logger.info(f"{len(missing_in_gt)} predicted pages missing in GT (e.g., {next(iter(missing_in_gt))}).")

    classification_stats = compute_classification_stats(pred_keyed, gt_keyed)
    title_stats = compute_title_stats(pred_keyed, gt_keyed)

    return classification_stats, title_stats
123+
124+
125+
def save_stats(stats_classification: dict, csv_path: Path) -> Path:
126+
"""Saves confusion matrix to output directory."""
82127
with open(csv_path, "w", newline="") as f:
83128
writer = csv.writer(f)
84129
writer.writerow(
85130
[
86-
"Class",
131+
"Label",
87132
"True_Positives",
88133
"False_Negatives",
89134
"False_Positives",
90135
]
91136
)
92-
for label, s in stats.items():
137+
for label, s in stats_classification.items():
93138
writer.writerow(
94139
[
95140
label,
@@ -101,15 +146,27 @@ def save_confusion_stats(stats: dict, output_dir: Path) -> Path:
101146
return csv_path
102147

103148

104-
def log_metrics_to_mlflow(stats: dict, total_files: int, total_pages: int) -> None:
149+
def log_metrics_to_mlflow(stats_classification: dict, stats_title: dict) -> None:
105150
"""Calculates and logs F1, precision and recall to MLflow."""
106151
if not mlflow_tracking:
107152
return None
108153

154+
# Log metrics for title extraction
155+
tp, fp, fn = [stats_title["title"][label] for label in ["true_positives", "false_negatives", "false_positives"]]
156+
precision = tp / (tp + fp) if (tp + fp) else 0.0
157+
recall = tp / (tp + fn) if (tp + fn) else 0.0
158+
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) else 0.0
159+
mlflow.log_metric("title/f1", f1)
160+
mlflow.log_metric("title/precision", precision)
161+
mlflow.log_metric("title/recall", recall)
162+
163+
logger.info(f"Title: F1={f1:.2%}, Precision={precision:.2%}, Recall={recall:.2%}")
164+
165+
# Log metrics for classification output
109166
precisions = []
110167
recalls = []
111168
f1_scores = []
112-
for label, s in stats.items():
169+
for label, s in stats_classification.items():
113170
tp, fn, fp = s["true_positives"], s["false_negatives"], s["false_positives"]
114171
precision = tp / (tp + fp) if (tp + fp) else 0.0
115172
recall = tp / (tp + fn) if (tp + fn) else 0.0
@@ -119,121 +176,39 @@ def log_metrics_to_mlflow(stats: dict, total_files: int, total_pages: int) -> No
119176
recalls.append(recall)
120177
f1_scores.append(f1)
121178

122-
mlflow.log_metric(f"F1 {label}", f1)
123-
mlflow.log_metric(f"{label}_precision", precision)
124-
mlflow.log_metric(f"{label}_recall", recall)
179+
mlflow.log_metric(f"classification/{label}_f1", f1)
180+
mlflow.log_metric(f"classification/{label}_precision", precision)
181+
mlflow.log_metric(f"classification/{label}_recall", recall)
125182

126183
logger.info(f"{label}: F1={f1:.2%}, Precision={precision:.2%}, Recall={recall:.2%}")
127184

128185
macro_precision = sum(precisions) / len(precisions) if precisions else 0.0
129186
macro_recall = sum(recalls) / len(recalls) if recalls else 0.0
130187
macro_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0.0
131188

132-
mlflow.log_metric("Macro Avg Precision", macro_precision)
133-
mlflow.log_metric("Macro Avg Recall", macro_recall)
134-
mlflow.log_metric("Macro Avg F1", macro_f1)
135-
136-
logger.info(f"Macro Avg: F1={macro_f1:.2%}, Precision={macro_precision:.2%}, Recall={macro_recall:.2%}")
137-
138-
mlflow.log_metric("total_pages", total_pages)
139-
mlflow.log_metric("total_files", total_files)
140-
141-
142-
def create_page_comparison(pred_dict: dict, gt_dict: dict, output_dir: Path) -> pd.DataFrame:
143-
"""Create a per-page comparison CSV/DF for pages present in both predictions and ground truth (intersection)."""
144-
output_dir.mkdir(parents=True, exist_ok=True)
145-
report_path = output_dir / "per_page_comparison.csv"
146-
147-
columns = (
148-
["Filename", "Page"]
149-
+ [f"{label}_pred" for label in LABELS]
150-
+ [f"{label}_gt" for label in LABELS]
151-
+ [f"{label}_match" for label in LABELS]
152-
+ ["All_labels_match", "Status"]
153-
)
154-
155-
pred_keys = set(pred_dict.keys())
156-
gt_keys = set(gt_dict.keys())
157-
158-
# Only evaluate files/pages that are in predictions.
159-
keys = pred_keys & gt_keys
160-
161-
rows = []
162-
for filename, page_num in sorted(keys, key=lambda k: (k[0], k[1])):
163-
pred_page = pred_dict[(filename, page_num)]
164-
gt_page = gt_dict[(filename, page_num)]
165-
166-
preds = [int(pred_page.get(label, 0)) for label in LABELS]
167-
gts = [int(gt_page.get(label, 0)) for label in LABELS]
168-
matches = [int(p == g) for p, g in zip(preds, gts, strict=True)]
169-
all_match = int(all(matches))
170-
171-
# Only keep misclassifications
172-
if not all_match:
173-
status = "mismatch"
174-
row = [filename, page_num] + preds + gts + matches + [all_match, status]
175-
rows.append(row)
176-
177-
df = pd.DataFrame(rows, columns=columns)
178-
179-
# Write to csv file
180-
with open(report_path, "w", newline="", encoding="utf-8") as f:
181-
writer = csv.writer(f)
182-
writer.writerow(columns)
183-
writer.writerows(rows)
184-
185-
if mlflow_tracking:
186-
mlflow.log_artifact(str(report_path))
187-
logger.info(f"Logged misclassifications to {report_path}")
188-
189-
return df
190-
191-
192-
def save_misclassifications(df: pd.DataFrame, output_dir: Path) -> None:
193-
"""Save misclassified pages and per-class CSVs."""
194-
195-
def get_active_labels(row, suffix):
196-
return [label for label in LABELS if row[f"{label}_{suffix}"] == 1]
189+
mlflow.log_metric("classification/macro_precision", macro_precision)
190+
mlflow.log_metric("classification/macro_recall", macro_recall)
191+
mlflow.log_metric("classification/macro_f1", macro_f1)
197192

198-
df["Predicted_labels"] = df.apply(lambda row: get_active_labels(row, suffix="pred"), axis=1)
199-
df["Ground_truth_labels"] = df.apply(lambda row: get_active_labels(row, suffix="gt"), axis=1)
200-
201-
misclassified = df[df["All_labels_match"] == 0][["Filename", "Page", "Ground_truth_labels", "Predicted_labels"]]
202-
203-
mis_path = output_dir / "misclassifications.csv"
204-
misclassified.to_csv(mis_path, index=False)
205-
if mlflow_tracking:
206-
mlflow.log_artifact(str(mis_path))
207-
208-
for true_class in LABELS:
209-
class_mis = misclassified[
210-
misclassified["Ground_truth_labels"].apply(lambda labels, cls=true_class: cls in labels)
211-
]
212-
if not class_mis.empty:
213-
path = output_dir / f"misclassified_{true_class}.csv"
214-
class_mis.to_csv(path, index=False)
215-
if mlflow_tracking:
216-
mlflow.log_artifact(str(path))
193+
logger.info(f"Classification Macro: F1={macro_f1:.2%}, Precision={macro_precision:.2%}, Recall={macro_recall:.2%}")
217194

218195

219196
def evaluate_results(
220197
predictions: list[ProcessorDocumentEntities], ground_truth_path: Path, output_dir: Path = Path("evaluation")
221-
) -> dict | None:
198+
) -> tuple[Path, Path]:
222199
"""Evaluate classification predictions against ground truth."""
223200
output_dir.mkdir(parents=True, exist_ok=True)
224201

225202
gt_list = load_ground_truth(ground_truth_path)
226203
pred_list = [pred.to_ground_truth() for pred in predictions]
227204

228-
stats, total_files, total_pages = compute_confusion_stats(pred_list, gt_list)
229-
stats_path = save_confusion_stats(stats, output_dir)
205+
stats_classification, stats_title = compute_stats(pred_list, gt_list)
206+
stats_classification_path = save_stats(stats_classification, output_dir / "evaluation_metrics_classification.csv")
207+
stats_title_path = save_stats(stats_title, output_dir / "evaluation_metrics_title.csv")
230208

231209
if mlflow_tracking:
232-
log_metrics_to_mlflow(stats, total_files, total_pages)
233-
mlflow.log_artifact(str(stats_path))
234-
235-
# TODO
236-
# comparison_data = create_page_comparison(pred_dict, gt_dict, output_dir)
237-
# save_misclassifications(comparison_data, output_dir)
210+
log_metrics_to_mlflow(stats_classification, stats_title)
211+
mlflow.log_artifact(str(stats_classification_path))
212+
mlflow.log_artifact(str(stats_title_path))
238213

239-
return stats
214+
return stats_classification_path, stats_title_path

0 commit comments

Comments
 (0)