Skip to content

Commit f003761

Browse files
Merge pull request #424 from swisstopo/feat/issue-398/red-green-markers-evaluation
Red / green marker for evaluation
2 parents 5315701 + 102c758 commit f003761

File tree

5 files changed

+110
-63
lines changed

5 files changed

+110
-63
lines changed

src/extraction/evaluation/benchmark/ground_truth.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(self, path: Path) -> None:
1919
Args:
2020
path (Path): the path to the Ground truth file
2121
"""
22+
self.path = path
2223
self.ground_truth = defaultdict(lambda: defaultdict(dict))
2324

2425
# Load the ground truth data

src/extraction/evaluation/benchmark/score.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from core.mlflow_tracking import mlflow
1616
from extraction.evaluation.benchmark.ground_truth import GroundTruth
17+
from extraction.features.predictions.file_predictions import FilePredictions
1718
from extraction.features.predictions.overall_file_predictions import OverallFilePredictions
1819
from swissgeol_doc_processing.utils.file_utils import get_data_path
1920

@@ -42,26 +43,52 @@ def key(category: str, metric: str) -> str:
4243
return geology_dict | metadata_dict
4344

4445

46+
def evaluate_single_prediction(
47+
prediction: FilePredictions,
48+
ground_truth: GroundTruth | None = None,
49+
) -> FilePredictions:
50+
"""Computes metrics for a given file.
51+
52+
Note that the implementation of `evaluate_geology` and `evaluate_metadata_extraction` mutates
53+
the attributes of `prediction`.
54+
55+
Args:
56+
prediction (FilePredictions): The predictions object.
57+
ground_truth (GroundTruth | None): The ground truth object.
58+
59+
Returns:
60+
FilePredictions: Evaluated prediction.
61+
"""
62+
if ground_truth is None:
63+
return prediction
64+
65+
# Create dummy overall file prediction and append prediction
66+
matched_with_ground_truth = OverallFilePredictions([prediction]).match_with_ground_truth(ground_truth)
67+
68+
# Run evaluation for file (! mutates prediction !)
69+
matched_with_ground_truth.evaluate_geology(verbose=False)
70+
matched_with_ground_truth.evaluate_metadata_extraction()
71+
72+
return prediction
73+
74+
4575
def evaluate_all_predictions(
4676
predictions: OverallFilePredictions,
47-
ground_truth_path: Path,
77+
ground_truth: GroundTruth | None = None,
4878
) -> None | ExtractionBenchmarkSummary:
4979
"""Computes all the metrics, logs them, and creates corresponding MLFlow artifacts (when enabled).
5080
5181
Args:
5282
predictions (OverallFilePredictions): The predictions objects.
53-
ground_truth_path (Path | None): The path to the ground truth file.
83+
ground_truth (GroundTruth | None): The ground truth object.
5484
5585
Returns:
5686
ExtractionBenchmarkSummary | None: A JSON-serializable ExtractionBenchmarkSummary
5787
that can be used by multi-benchmark runners.
5888
"""
59-
if not (ground_truth_path and ground_truth_path.exists()): # for inference no ground truth is available
60-
logger.warning("Ground truth file not found. Skipping evaluation.")
89+
if ground_truth is None:
6190
return None
6291

63-
ground_truth = GroundTruth(ground_truth_path)
64-
6592
#############################
6693
# Evaluate the borehole extraction
6794
#############################
@@ -101,7 +128,7 @@ def evaluate_all_predictions(
101128
mlflow.log_artifact(Path(temp_directory) / "document_level_metadata_metrics.csv")
102129

103130
return ExtractionBenchmarkSummary(
104-
ground_truth_path=str(ground_truth_path),
131+
ground_truth_path=str(ground_truth.path),
105132
n_documents=len(predictions.file_predictions_list),
106133
geology=metrics_dict,
107134
metadata=metadata_metrics.to_json(),
@@ -116,20 +143,27 @@ def main():
116143
try:
117144
with open(args.predictions_path, encoding="utf8") as file:
118145
predictions = json.load(file)
146+
predictions = OverallFilePredictions.from_json(predictions)
119147
except FileNotFoundError:
120148
logger.error("Predictions file not found: %s", args.predictions_path)
121149
return
122150
except json.JSONDecodeError as e:
123151
logger.error("Error decoding JSON from predictions file: %s", e)
124152
return
125153

126-
predictions = OverallFilePredictions.from_json(predictions)
154+
# Load ground truth
155+
try:
156+
ground_truth = GroundTruth(args.ground_truth_path)
157+
except FileNotFoundError:
158+
logger.error("Ground truth file not found: %s", args.ground_truth_path)
159+
return
160+
127161
if mlflow:
128162
mlflow.set_experiment("Boreholes Stratigraphy")
129163
with mlflow.start_run():
130-
evaluate_all_predictions(predictions, args.ground_truth_path)
164+
evaluate_all_predictions(predictions, ground_truth)
131165
else:
132-
evaluate_all_predictions(predictions, args.ground_truth_path)
166+
evaluate_all_predictions(predictions, ground_truth)
133167

134168

135169
def parse_cli() -> argparse.Namespace:

src/extraction/features/predictions/overall_file_predictions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Classes for predictions per PDF file."""
22

3+
import dataclasses
4+
35
from extraction.evaluation.benchmark.ground_truth import GroundTruth
46
from extraction.evaluation.layer_evaluator import LayerEvaluator
57
from extraction.features.metadata.metadata import FileMetadata
@@ -11,12 +13,11 @@
1113
from extraction.features.predictions.predictions import AllBoreholePredictionsWithGroundTruth
1214

1315

16+
@dataclasses.dataclass
1417
class OverallFilePredictions:
1518
"""A class to represent predictions for all files."""
1619

17-
def __init__(self) -> None:
18-
"""Initializes the OverallFilePredictions object."""
19-
self.file_predictions_list: list[FilePredictions] = []
20+
file_predictions_list: list[FilePredictions] = dataclasses.field(default_factory=list)
2021

2122
def contains(self, filename: str) -> bool:
2223
"""Check if `file_predictions_list` contains `filename`.

src/extraction/features/predictions/predictions.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,12 @@ def evaluate_metadata_extraction(self) -> OverallBoreholeMetadataMetrics:
287287

288288
return MetadataEvaluator(metadata_list).evaluate()
289289

290-
def evaluate_geology(self) -> OverallMetricsCatalog:
290+
def evaluate_geology(self, verbose: bool = True) -> OverallMetricsCatalog:
291291
"""Evaluate the borehole extraction predictions.
292292
293+
Args:
294+
verbose (bool): If True, log all metrics for the current evaluation.
295+
293296
Returns:
294297
OverallMetricsCatalog: A OverallMetricsCatalog that maps a metrics name to the corresponding
295298
OverallMetrics object. If no ground truth is available, None is returned.
@@ -329,13 +332,14 @@ def evaluate_geology(self) -> OverallMetricsCatalog:
329332
all_metrics, f"{language}_material_description_metrics", evaluator.get_material_description_metrics()
330333
)
331334

332-
logger.info("Macro avg:")
333-
logger.info(
334-
"layer f1: %.1f%%, depth interval f1: %.1f%%, material description f1: %.1f%%",
335-
all_metrics.layer_metrics.macro_f1() * 100,
336-
all_metrics.depth_interval_metrics.macro_f1() * 100,
337-
all_metrics.material_description_metrics.macro_f1() * 100,
338-
)
335+
if verbose:
336+
logger.info("Macro avg:")
337+
logger.info(
338+
"layer f1: %.1f%%, depth interval f1: %.1f%%, material description f1: %.1f%%",
339+
all_metrics.layer_metrics.macro_f1() * 100,
340+
all_metrics.depth_interval_metrics.macro_f1() * 100,
341+
all_metrics.material_description_metrics.macro_f1() * 100,
342+
)
339343

340344
# TODO groundwater should not be in evaluate_geology(), it should be handle by a higher-level function call
341345
groundwater_list = [

src/extraction/runner.py

Lines changed: 49 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919
)
2020
from core.mlflow_tracking import mlflow
2121
from extraction.core.extract import ExtractionResult, extract, open_pdf
22-
from extraction.evaluation.benchmark.score import ExtractionBenchmarkSummary, evaluate_all_predictions
22+
from extraction.evaluation.benchmark.ground_truth import GroundTruth
23+
from extraction.evaluation.benchmark.score import (
24+
ExtractionBenchmarkSummary,
25+
evaluate_all_predictions,
26+
evaluate_single_prediction,
27+
)
2328
from extraction.evaluation.benchmark.spec import BenchmarkSpec
2429
from extraction.features.predictions.file_predictions import FilePredictions
2530
from extraction.features.predictions.overall_file_predictions import OverallFilePredictions
@@ -298,6 +303,7 @@ def run_predictions(
298303
input_directory: Path,
299304
out_directory: Path,
300305
predictions_path_tmp: Path,
306+
ground_truth: GroundTruth | None = None,
301307
skip_draw_predictions: bool = False,
302308
draw_lines: bool = False,
303309
draw_tables: bool = False,
@@ -318,6 +324,7 @@ def run_predictions(
318324
out_directory (Path): Directory where per-file output (visualizations, CSV) is written.
319325
predictions_path_tmp (Path): Path to the incremental tmp predictions file. Existing content
320326
is used to resume; the file is updated after each successfully processed file.
327+
ground_truth (GroundTruth | None): Ground truth for evaluation.
321328
skip_draw_predictions (bool, optional): Skip drawing predictions on PDF pages. Defaults to False.
322329
draw_lines (bool, optional): Draw detected lines on PDF pages. Defaults to False.
323330
draw_tables (bool, optional): Draw detected table structures on PDF pages. Defaults to False.
@@ -333,54 +340,48 @@ def run_predictions(
333340
the total number of PDF files discovered (including any already-predicted files from a
334341
resumed run), and all CSV file paths written during this run.
335342
"""
336-
if input_directory.is_file():
337-
root = input_directory.parent
338-
pdf_files = [input_directory.name] if input_directory.suffix.lower() == ".pdf" else []
339-
else:
340-
root = input_directory
341-
pdf_files = [f.name for f in input_directory.glob("*.pdf") if f.is_file()]
342-
343+
# Look for files to process
344+
pdf_files = [input_directory] if input_directory.is_file() else list(input_directory.glob("*.pdf"))
343345
n_documents = len(pdf_files)
344346

345347
# Load any partially-completed predictions for resume support
346-
predictions = read_json_predictions(predictions_path_tmp)
348+
predictions = read_json_predictions(str(predictions_path_tmp))
347349

348350
any_draw = not skip_draw_predictions or draw_lines or draw_tables or draw_strip_logs
349351

350352
all_csv_paths: list[Path] = []
351-
352-
for filename in tqdm(pdf_files, desc="Processing files", unit="file"):
353-
if predictions.contains(filename):
354-
logger.info(f"{filename} already predicted.")
353+
for pdf_file in tqdm(pdf_files, desc="Processing files", unit="file"):
354+
# Check if file is already computed in previous run
355+
if predictions.contains(pdf_file.name):
356+
logger.info(f"{pdf_file.name} already predicted.")
355357
continue
356358

357-
in_path = root / filename
358-
logger.info(f"Processing file: {in_path}")
359-
360-
try:
361-
result = extract(file=in_path, filename=in_path.name, part=part, analytics=analytics)
362-
predictions.add_file_predictions(result.predictions)
363-
364-
if csv:
365-
all_csv_paths.extend(write_csv_for_file(result.predictions, out_directory))
366-
367-
if any_draw:
368-
draw_file_predictions(
369-
result=result,
370-
file=in_path,
371-
filename=in_path.name,
372-
out_directory=out_directory,
373-
skip_draw_predictions=skip_draw_predictions,
374-
draw_lines=draw_lines,
375-
draw_tables=draw_tables,
376-
draw_strip_logs=draw_strip_logs,
377-
)
378-
379-
logger.info(f"Writing predictions to tmp JSON file {predictions_path_tmp}")
380-
write_json_predictions(filename=predictions_path_tmp, predictions=predictions)
381-
382-
except Exception as e:
383-
logger.error(f"Unexpected error in file {filename}. Trace: {e}")
359+
logger.info(f"Processing file: {pdf_file.name}")
360+
361+
# Run extraction and append to predictions
362+
result = extract(file=pdf_file, filename=pdf_file.name, part=part, analytics=analytics)
363+
predictions.add_file_predictions(result.predictions)
364+
365+
if csv:
366+
all_csv_paths.extend(write_csv_for_file(result.predictions, out_directory))
367+
368+
if any_draw:
369+
# Run evaluation for current file drawing
370+
result.predictions = evaluate_single_prediction(result.predictions, ground_truth)
371+
# Draw predictions for file
372+
draw_file_predictions(
373+
result=result,
374+
file=pdf_file,
375+
filename=pdf_file.name,
376+
out_directory=out_directory,
377+
skip_draw_predictions=skip_draw_predictions,
378+
draw_lines=draw_lines,
379+
draw_tables=draw_tables,
380+
draw_strip_logs=draw_strip_logs,
381+
)
382+
383+
logger.info(f"Writing predictions to tmp JSON file {predictions_path_tmp}")
384+
write_json_predictions(filename=str(predictions_path_tmp), predictions=predictions)
384385

385386
return predictions, n_documents, all_csv_paths
386387

@@ -444,6 +445,11 @@ def start_pipeline(
444445
delete_temporary(predictions_path_tmp)
445446
delete_temporary(mlflow_runid_tmp)
446447

448+
# Build ground truth
449+
ground_truth: GroundTruth | None = None
450+
if ground_truth_path and ground_truth_path.exists(): # for inference no ground truth is available
451+
ground_truth = GroundTruth(ground_truth_path)
452+
447453
metadata_path.parent.mkdir(exist_ok=True)
448454

449455
# Initialize analytics if enabled
@@ -469,6 +475,7 @@ def start_pipeline(
469475
input_directory=input_directory,
470476
out_directory=out_directory,
471477
predictions_path_tmp=predictions_path_tmp,
478+
ground_truth=ground_truth,
472479
skip_draw_predictions=skip_draw_predictions,
473480
draw_lines=draw_lines,
474481
draw_tables=draw_tables,
@@ -483,10 +490,10 @@ def start_pipeline(
483490
for csv_path in csv_paths:
484491
mlflow.log_artifact(str(csv_path), "csv")
485492

486-
# Evaluate final predictions
493+
# Evaluate final predictions with all data
487494
eval_summary = evaluate_all_predictions(
488495
predictions=predictions,
489-
ground_truth_path=ground_truth_path,
496+
ground_truth=ground_truth,
490497
)
491498
if eval_summary is not None:
492499
eval_summary.n_documents = n_documents

0 commit comments

Comments
 (0)