Skip to content

Commit d95fc4a

Browse files
add todo for tracking
1 parent 0bba3e7 commit d95fc4a

File tree

1 file changed

+20
-38
lines changed

1 file changed

+20
-38
lines changed

src/evaluation.py

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import logging
44
import os
55
from pathlib import Path
6-
from typing import Any
76

87
import pandas as pd
98
from dotenv import load_dotenv
9+
from pydantic import TypeAdapter
1010

1111
from src.page_classes import PageClasses
12+
from src.page_structure import ProcessorDocumentEntities
13+
from src.schemas import DocumentGroundTruth
1214

1315
load_dotenv()
1416
mlflow_tracking = os.getenv("MLFLOW_TRACKING").lower() == "true"
@@ -21,46 +23,26 @@
2123
LABELS = [cls.value for cls in PageClasses]
2224

2325

24-
def load_predictions(predictions: list[dict[str, Any]]) -> dict[tuple[str, int], dict[str, int]]:
25-
"""Normalizes predictions list.
26-
27-
{ (filename, page_number): classification_dict }
28-
Works for both model predictions and ground-truth lists.
29-
"""
30-
pred_dict: dict[tuple[str, int], dict[str, int]] = {}
31-
32-
for entry in predictions:
33-
filename = entry.get("filename")
34-
pages = entry.get("pages", [])
35-
36-
for page_entry in pages:
37-
page_number = page_entry.get("page")
38-
classification = page_entry.get("classification")
39-
40-
key = (filename, page_number)
41-
if key in pred_dict:
42-
logger.warning(f"Duplicate entry for {key}; overwriting previous value.")
43-
pred_dict[key] = classification
44-
return pred_dict
45-
46-
47-
def load_ground_truth(ground_truth_path: Path) -> dict | None:
26+
def load_ground_truth(ground_truth_path: Path) -> list[DocumentGroundTruth] | None:
4827
"""Loads ground truth data from a JSON file."""
4928
try:
5029
with open(ground_truth_path) as f:
5130
gt_list = json.load(f)
52-
return load_predictions(gt_list)
31+
gt_list = TypeAdapter(list[DocumentGroundTruth]).validate_python(gt_list)
32+
return gt_list
5333
except Exception as e:
5434
logger.error(f"Invalid ground truth path or JSON: {e}")
5535
return None
5636

5737

58-
def compute_confusion_stats(predictions: dict, ground_truth: dict) -> tuple[dict, int, int]:
38+
def compute_confusion_stats(
39+
predictions: list[DocumentGroundTruth], ground_truth: list[DocumentGroundTruth]
40+
) -> tuple[dict, int, int]:
5941
"""Computes confusion matrix entries, total pages and files processed for evaluating classification results."""
6042
stats = {label: {"true_positives": 0, "false_negatives": 0, "false_positives": 0} for label in LABELS}
6143

62-
pred_keys = set(predictions.keys())
63-
gt_keys = set(ground_truth.keys())
44+
pred_keys = set([pred.filename for pred in predictions])
45+
gt_keys = set([pred.filename for pred in ground_truth])
6446

6547
# Evaluate on the intersection so we don't crash when pages are missing
6648
common_keys = pred_keys & gt_keys
@@ -72,6 +54,7 @@ def compute_confusion_stats(predictions: dict, ground_truth: dict) -> tuple[dict
7254
if missing_in_gt:
7355
logger.info(f"{len(missing_in_gt)} predicted pages missing in GT (e.g., {next(iter(missing_in_gt))}).")
7456

57+
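# TODO: finish the evaluation from here on - common_keys now holds filenames only, but the code below still expects (filename, page) tuple keys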
7558
total_pages = len(common_keys)
7659
total_files = len({fname for (fname, _page) in common_keys})
7760

@@ -234,24 +217,23 @@ def get_active_labels(row, suffix):
234217

235218

236219
def evaluate_results(
237-
predictions: list[dict], ground_truth_path: Path, output_dir: Path = Path("evaluation")
220+
predictions: list[ProcessorDocumentEntities], ground_truth_path: Path, output_dir: Path = Path("evaluation")
238221
) -> dict | None:
239222
"""Evaluate classification predictions against ground truth."""
240223
output_dir.mkdir(parents=True, exist_ok=True)
241224

242-
gt_dict = load_ground_truth(ground_truth_path)
243-
if gt_dict is None:
244-
return None
245-
246-
pred_dict = load_predictions(predictions)
225+
gt_list = load_ground_truth(ground_truth_path)
226+
pred_list = [pred.to_ground_truth() for pred in predictions]
247227

248-
stats, total_files, total_pages = compute_confusion_stats(pred_dict, gt_dict)
228+
stats, total_files, total_pages = compute_confusion_stats(pred_list, gt_list)
249229
stats_path = save_confusion_stats(stats, output_dir)
250230

251231
if mlflow_tracking:
252232
log_metrics_to_mlflow(stats, total_files, total_pages)
253233
mlflow.log_artifact(str(stats_path))
254-
comparison_data = create_page_comparison(pred_dict, gt_dict, output_dir)
255-
save_misclassifications(comparison_data, output_dir)
234+
235+
# TODO: re-enable the page comparison below once it is ported from pred_dict/gt_dict to pred_list/gt_list
236+
# comparison_data = create_page_comparison(pred_dict, gt_dict, output_dir)
237+
# save_misclassifications(comparison_data, output_dir)
256238

257239
return stats

0 commit comments

Comments
 (0)