|
2 | 2 | import json |
3 | 3 | import logging |
4 | 4 | import os |
| 5 | +import re |
| 6 | +import unicodedata |
5 | 7 | from pathlib import Path |
6 | 8 |
|
7 | 9 | from dotenv import load_dotenv |
| 10 | +from Levenshtein import distance |
8 | 11 | from pydantic import TypeAdapter |
9 | 12 |
|
10 | 13 | from src.page_classes import PageClasses |
@@ -53,6 +56,43 @@ def groundtruth_doc_to_pages(documents: list[DocumentGroundTruth]) -> dict[str, |
53 | 56 | return {f"{doc.filename}-{page.page}": page for doc in documents for page in doc.pages} |
54 | 57 |
|
55 | 58 |
|
def standardize_text(text: str) -> str:
    """Standardize text for fuzzy comparison.

    Collapses every whitespace run (newlines included) into a single space,
    strips leading/trailing whitespace, removes diacritics (e.g. "ü" -> "u"),
    and lowercases the result.

    Args:
        text (str): Text to standardize.

    Returns:
        str: Standardized text.
    """
    # Collapse all whitespace runs into single spaces; this also covers
    # newlines, so no separate "\n" replacement is needed.
    text = re.sub(r"\s+", " ", text).strip()
    # Remove accents: decompose (NFD), then drop combining marks ("ü" -> "u").
    text = "".join(c for c in unicodedata.normalize("NFD", text) if unicodedata.category(c) != "Mn")
    # Enforce lowercase.
    return text.lower()
| 77 | + |
def are_texts_close(text_gt: str, text_pred: str, r_error: float = 0.25) -> bool:
    """Check if two texts are similar based on the Levenshtein distance.

    Both texts are standardized (whitespace collapsed, accents removed,
    lowercased) before matching. The edit distance is normalized by the
    ground-truth length and compared against ``r_error``.

    Args:
        text_gt (str): Ground truth text.
        text_pred (str): Predicted text.
        r_error (float, optional): Accepted relative error. Defaults to 0.25.

    Returns:
        bool: True if both texts are considered close to each other.
    """
    text_gt = standardize_text(text_gt)
    text_pred = standardize_text(text_pred)
    # max(1, len) guards against division by zero when the ground truth
    # standardizes to an empty string.
    return distance(text_gt, text_pred) / max(1, len(text_gt)) < r_error
| 95 | + |
56 | 96 | def compute_classification_stats(predictions: dict[str, DocumentPage], ground_truth: dict[str, DocumentPage]) -> dict: |
57 | 97 | """Compute per-label classification confusion statistics over matched page keys. |
58 | 98 |
|
@@ -101,15 +141,16 @@ def compute_title_stats(predictions: dict[str, DocumentPage], ground_truth: dict |
101 | 141 | for key in common_keys: |
102 | 142 | pred_title = predictions[key].title |
103 | 143 | gt_title = ground_truth[key].title |
104 | | - logger.info(f"{key}: {gt_title} == {pred_title}") |
105 | 144 | # Check if GT exists |
106 | 145 | if not gt_title: |
107 | 146 | continue |
108 | 147 |
|
109 | 148 | # Measure |
110 | | - if pred_title == gt_title: |
| 149 | + if pred_title and are_texts_close(gt_title, pred_title): |
111 | 150 | stats["true_positives"] += 1 |
112 | 151 | else: |
| 152 | + # TODO: remove before final PR |
| 153 | + logger.info(f"{key}: {gt_title} == {pred_title}") |
113 | 154 | stats["false_positives"] += 1 |
114 | 155 | stats["false_negatives"] += 1 |
115 | 156 |
|
|
0 commit comments