Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/extraction/evaluation/benchmark/ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, path: Path) -> None:
Args:
path (Path): the path to the Ground truth file
"""
self.path = path
self.ground_truth = defaultdict(lambda: defaultdict(dict))

# Load the ground truth data
Expand Down
54 changes: 44 additions & 10 deletions src/extraction/evaluation/benchmark/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from core.mlflow_tracking import mlflow
from extraction.evaluation.benchmark.ground_truth import GroundTruth
from extraction.features.predictions.file_predictions import FilePredictions
from extraction.features.predictions.overall_file_predictions import OverallFilePredictions
from swissgeol_doc_processing.utils.file_utils import get_data_path

Expand Down Expand Up @@ -42,26 +43,52 @@ def key(category: str, metric: str) -> str:
return geology_dict | metadata_dict


def evaluate_single_prediction(
prediction: FilePredictions,
ground_truth: GroundTruth | None = None,
) -> FilePredictions:
"""Computes metrics for a given file.

Note that the implementation of `evaluate_geology` and `evaluate_metadata_extraction` mutates
the attributes of `prediction`.

Args:
prediction (FilePredictions): The predictions object.
ground_truth (GroundTruth | None): The ground truth object.

Returns:
FilePredictions: Evaluated prediction.
"""
if ground_truth is None:
return prediction

# Create dummy overall file prediction and append prediction
matched_with_ground_truth = OverallFilePredictions([prediction]).match_with_ground_truth(ground_truth)

# Run evaluation for file (! mutates prediction !)
matched_with_ground_truth.evaluate_geology(verbose=False)
matched_with_ground_truth.evaluate_metadata_extraction()

return prediction


def evaluate_all_predictions(
predictions: OverallFilePredictions,
ground_truth_path: Path,
ground_truth: GroundTruth | None = None,
) -> None | ExtractionBenchmarkSummary:
"""Computes all the metrics, logs them, and creates corresponding MLFlow artifacts (when enabled).

Args:
predictions (OverallFilePredictions): The predictions objects.
ground_truth_path (Path | None): The path to the ground truth file.
ground_truth (GroundTruth | None): The ground truth object.

Returns:
ExtractionBenchmarkSummary | None: A JSON-serializable ExtractionBenchmarkSummary
that can be used by multi-benchmark runners.
"""
if not (ground_truth_path and ground_truth_path.exists()): # for inference no ground truth is available
logger.warning("Ground truth file not found. Skipping evaluation.")
if ground_truth is None:
return None

ground_truth = GroundTruth(ground_truth_path)

#############################
# Evaluate the borehole extraction
#############################
Expand Down Expand Up @@ -101,7 +128,7 @@ def evaluate_all_predictions(
mlflow.log_artifact(Path(temp_directory) / "document_level_metadata_metrics.csv")

return ExtractionBenchmarkSummary(
ground_truth_path=str(ground_truth_path),
ground_truth_path=str(ground_truth.path),
n_documents=len(predictions.file_predictions_list),
geology=metrics_dict,
metadata=metadata_metrics.to_json(),
Expand All @@ -116,20 +143,27 @@ def main():
try:
with open(args.predictions_path, encoding="utf8") as file:
predictions = json.load(file)
predictions = OverallFilePredictions.from_json(predictions)
except FileNotFoundError:
logger.error("Predictions file not found: %s", args.predictions_path)
return
except json.JSONDecodeError as e:
logger.error("Error decoding JSON from predictions file: %s", e)
return

predictions = OverallFilePredictions.from_json(predictions)
# Load ground truth
try:
ground_truth = GroundTruth(args.ground_truth_path)
except FileNotFoundError:
logger.error("Ground truth file not found: %s", args.ground_truth_path)
return

if mlflow:
mlflow.set_experiment("Boreholes Stratigraphy")
with mlflow.start_run():
evaluate_all_predictions(predictions, args.ground_truth_path)
evaluate_all_predictions(predictions, ground_truth)
else:
evaluate_all_predictions(predictions, args.ground_truth_path)
evaluate_all_predictions(predictions, ground_truth)


def parse_cli() -> argparse.Namespace:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Classes for predictions per PDF file."""

import dataclasses

from extraction.evaluation.benchmark.ground_truth import GroundTruth
from extraction.evaluation.layer_evaluator import LayerEvaluator
from extraction.features.metadata.metadata import FileMetadata
Expand All @@ -11,12 +13,11 @@
from extraction.features.predictions.predictions import AllBoreholePredictionsWithGroundTruth


@dataclasses.dataclass
class OverallFilePredictions:
"""A class to represent predictions for all files."""

def __init__(self) -> None:
"""Initializes the OverallFilePredictions object."""
self.file_predictions_list: list[FilePredictions] = []
file_predictions_list: list[FilePredictions] = dataclasses.field(default_factory=list)

def contains(self, filename: str) -> bool:
"""Check if `file_predictions_list` contains `filename`.
Expand Down
20 changes: 12 additions & 8 deletions src/extraction/features/predictions/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,12 @@ def evaluate_metadata_extraction(self) -> OverallBoreholeMetadataMetrics:

return MetadataEvaluator(metadata_list).evaluate()

def evaluate_geology(self) -> OverallMetricsCatalog:
def evaluate_geology(self, verbose: bool = True) -> OverallMetricsCatalog:
"""Evaluate the borehole extraction predictions.

Args:
verbose (bool): If True, log all metrics for the current evaluation.

Returns:
OverallMetricsCatalog: A OverallMetricsCatalog that maps a metrics name to the corresponding
OverallMetrics object. If no ground truth is available, None is returned.
Expand Down Expand Up @@ -316,13 +319,14 @@ def evaluate_geology(self) -> OverallMetricsCatalog:
all_metrics, f"{language}_material_description_metrics", evaluator.get_material_description_metrics()
)

logger.info("Macro avg:")
logger.info(
"layer f1: %.1f%%, depth interval f1: %.1f%%, material description f1: %.1f%%",
all_metrics.layer_metrics.macro_f1() * 100,
all_metrics.depth_interval_metrics.macro_f1() * 100,
all_metrics.material_description_metrics.macro_f1() * 100,
)
if verbose:
logger.info("Macro avg:")
logger.info(
"layer f1: %.1f%%, depth interval f1: %.1f%%, material description f1: %.1f%%",
all_metrics.layer_metrics.macro_f1() * 100,
all_metrics.depth_interval_metrics.macro_f1() * 100,
all_metrics.material_description_metrics.macro_f1() * 100,
)

# TODO groundwater should not be in evaluate_geology(), it should be handle by a higher-level function call
groundwater_list = [
Expand Down
93 changes: 50 additions & 43 deletions src/extraction/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
)
from core.mlflow_tracking import mlflow
from extraction.core.extract import ExtractionResult, extract, open_pdf
from extraction.evaluation.benchmark.score import ExtractionBenchmarkSummary, evaluate_all_predictions
from extraction.evaluation.benchmark.ground_truth import GroundTruth
from extraction.evaluation.benchmark.score import (
ExtractionBenchmarkSummary,
evaluate_all_predictions,
evaluate_single_prediction,
)
from extraction.evaluation.benchmark.spec import BenchmarkSpec
from extraction.features.predictions.file_predictions import FilePredictions
from extraction.features.predictions.overall_file_predictions import OverallFilePredictions
Expand Down Expand Up @@ -96,7 +101,7 @@ def write_json_predictions(filename: str, predictions: OverallFilePredictions) -
predictions (OverallFilePredictions): Prediction to dump in JSON file.
"""
with open(filename, "w", encoding="utf8") as file:
json.dump(predictions.to_json(), file, ensure_ascii=False)
json.dump(predictions.to_json(), file, ensure_ascii=False, indent=2)


def read_json_predictions(filename: str) -> OverallFilePredictions:
Expand Down Expand Up @@ -298,6 +303,7 @@ def run_predictions(
input_directory: Path,
out_directory: Path,
predictions_path_tmp: Path,
ground_truth: GroundTruth | None = None,
skip_draw_predictions: bool = False,
draw_lines: bool = False,
draw_tables: bool = False,
Expand All @@ -318,6 +324,7 @@ def run_predictions(
out_directory (Path): Directory where per-file output (visualizations, CSV) is written.
predictions_path_tmp (Path): Path to the incremental tmp predictions file. Existing content
is used to resume; the file is updated after each successfully processed file.
ground_truth (GroundTruth | None): Ground truth for evaluation.
skip_draw_predictions (bool, optional): Skip drawing predictions on PDF pages. Defaults to False.
draw_lines (bool, optional): Draw detected lines on PDF pages. Defaults to False.
draw_tables (bool, optional): Draw detected table structures on PDF pages. Defaults to False.
Expand All @@ -333,54 +340,48 @@ def run_predictions(
the total number of PDF files discovered (including any already-predicted files from a
resumed run), and all CSV file paths written during this run.
"""
if input_directory.is_file():
root = input_directory.parent
pdf_files = [input_directory.name] if input_directory.suffix.lower() == ".pdf" else []
else:
root = input_directory
pdf_files = [f.name for f in input_directory.glob("*.pdf") if f.is_file()]

# Look for files to process
pdf_files = [input_directory] if input_directory.is_file() else list(input_directory.glob("*.pdf"))
n_documents = len(pdf_files)

# Load any partially-completed predictions for resume support
predictions = read_json_predictions(predictions_path_tmp)
predictions = read_json_predictions(str(predictions_path_tmp))

any_draw = not skip_draw_predictions or draw_lines or draw_tables or draw_strip_logs

all_csv_paths: list[Path] = []

for filename in tqdm(pdf_files, desc="Processing files", unit="file"):
if predictions.contains(filename):
logger.info(f"{filename} already predicted.")
for pdf_file in tqdm(pdf_files, desc="Processing files", unit="file"):
# Check if file is already computed in previous run
if predictions.contains(pdf_file.name):
logger.info(f"{pdf_file.name} already predicted.")
continue

in_path = root / filename
logger.info(f"Processing file: {in_path}")

try:
result = extract(file=in_path, filename=in_path.name, part=part, analytics=analytics)
predictions.add_file_predictions(result.predictions)

if csv:
all_csv_paths.extend(write_csv_for_file(result.predictions, out_directory))

if any_draw:
draw_file_predictions(
result=result,
file=in_path,
filename=in_path.name,
out_directory=out_directory,
skip_draw_predictions=skip_draw_predictions,
draw_lines=draw_lines,
draw_tables=draw_tables,
draw_strip_logs=draw_strip_logs,
)

logger.info(f"Writing predictions to tmp JSON file {predictions_path_tmp}")
write_json_predictions(filename=predictions_path_tmp, predictions=predictions)

except Exception as e:
logger.error(f"Unexpected error in file {filename}. Trace: {e}")
logger.info(f"Processing file: {pdf_file.name}")

# Run extraction and append to predictions
result = extract(file=pdf_file, filename=pdf_file.name, part=part, analytics=analytics)
predictions.add_file_predictions(result.predictions)

if csv:
all_csv_paths.extend(write_csv_for_file(result.predictions, out_directory))

if any_draw:
# Run evaluation for current file drawing
result.predictions = evaluate_single_prediction(result.predictions, ground_truth)
# Draw predictions for file
draw_file_predictions(
result=result,
file=pdf_file,
filename=pdf_file.name,
out_directory=out_directory,
skip_draw_predictions=skip_draw_predictions,
draw_lines=draw_lines,
draw_tables=draw_tables,
draw_strip_logs=draw_strip_logs,
)

logger.info(f"Writing predictions to tmp JSON file {predictions_path_tmp}")
write_json_predictions(filename=str(predictions_path_tmp), predictions=predictions)

return predictions, n_documents, all_csv_paths

Expand Down Expand Up @@ -444,6 +445,11 @@ def start_pipeline(
delete_temporary(predictions_path_tmp)
delete_temporary(mlflow_runid_tmp)

# Build ground truth
ground_truth: GroundTruth | None = None
if ground_truth_path and ground_truth_path.exists(): # for inference no ground truth is available
ground_truth = GroundTruth(ground_truth_path)

metadata_path.parent.mkdir(exist_ok=True)

# Initialize analytics if enabled
Expand All @@ -469,6 +475,7 @@ def start_pipeline(
input_directory=input_directory,
out_directory=out_directory,
predictions_path_tmp=predictions_path_tmp,
ground_truth=ground_truth,
skip_draw_predictions=skip_draw_predictions,
draw_lines=draw_lines,
draw_tables=draw_tables,
Expand All @@ -483,10 +490,10 @@ def start_pipeline(
for csv_path in csv_paths:
mlflow.log_artifact(str(csv_path), "csv")

# Evaluate final predictions
# Evaluate final predictions with all data
eval_summary = evaluate_all_predictions(
predictions=predictions,
ground_truth_path=ground_truth_path,
ground_truth=ground_truth,
)
if eval_summary is not None:
eval_summary.n_documents = n_documents
Expand Down
Loading