From 44519a8546ae7750a41c564c6ecd34fe8dfd418a Mon Sep 17 00:00:00 2001 From: christianabbet Date: Wed, 18 Mar 2026 13:59:40 +0100 Subject: [PATCH 1/8] first iteration fix --- src/extraction/evaluation/benchmark/score.py | 28 ++++++++ src/extraction/runner.py | 75 +++++++++++--------- 2 files changed, 69 insertions(+), 34 deletions(-) diff --git a/src/extraction/evaluation/benchmark/score.py b/src/extraction/evaluation/benchmark/score.py index 57772847..35b03cdf 100644 --- a/src/extraction/evaluation/benchmark/score.py +++ b/src/extraction/evaluation/benchmark/score.py @@ -14,6 +14,7 @@ from core.mlflow_tracking import mlflow from extraction.evaluation.benchmark.ground_truth import GroundTruth +from extraction.features.predictions.file_predictions import FilePredictions from extraction.features.predictions.overall_file_predictions import OverallFilePredictions from swissgeol_doc_processing.utils.file_utils import get_data_path @@ -42,6 +43,33 @@ def key(category: str, metric: str) -> str: return geology_dict | metadata_dict +def evaluate_single_prediction( + prediction: FilePredictions, + ground_truth: GroundTruth, +) -> FilePredictions: + """Computes metrics for a given file. + + Args: + prediction (FilePredictions): The predictions object. + ground_truth (GroundTruth): The ground truth object. + + Returns: + FilePredictions: Evaluated prediction. + """ + # Create dummy overall file prediction and append prediction + predictions = OverallFilePredictions() + predictions.file_predictions_list.append(prediction) + + # Match file with ground truth + matched_with_ground_truth = predictions.match_with_ground_truth(ground_truth) + + # Run evaluation for file + matched_with_ground_truth.evaluate_geology() + matched_with_ground_truth.evaluate_metadata_extraction() + + return prediction + + def evaluate_all_predictions( predictions: OverallFilePredictions, ground_truth_path: Path, diff --git a/src/extraction/runner.py b/src/extraction/runner.py index 560deb0b..3b2d151a 100644 --- a/src/extraction/runner.py +++ b/src/extraction/runner.py @@ -19,7 +19,12 @@ ) from core.mlflow_tracking import mlflow from extraction.core.extract import ExtractionResult, extract, open_pdf -from extraction.evaluation.benchmark.score import ExtractionBenchmarkSummary, evaluate_all_predictions +from extraction.evaluation.benchmark.ground_truth import GroundTruth +from extraction.evaluation.benchmark.score import ( + ExtractionBenchmarkSummary, + evaluate_all_predictions, + evaluate_single_prediction, +) from extraction.evaluation.benchmark.spec import BenchmarkSpec from extraction.features.predictions.file_predictions import FilePredictions from extraction.features.predictions.overall_file_predictions import OverallFilePredictions @@ -298,6 +303,7 @@ def run_predictions( input_directory: Path, out_directory: Path, predictions_path_tmp: Path, + ground_truth_path: Path, skip_draw_predictions: bool = False, draw_lines: bool = False, draw_tables: bool = False, @@ -318,6 +324,7 @@ def run_predictions( out_directory (Path): Directory where per-file output (visualizations, CSV) is written. predictions_path_tmp (Path): Path to the incremental tmp predictions file. Existing content is used to resume; the file is updated after each successfully processed file. + ground_truth_path (Path): Path to ground truth file for evaluation. skip_draw_predictions (bool, optional): Skip drawing predictions on PDF pages. Defaults to False. draw_lines (bool, optional): Draw detected lines on PDF pages. Defaults to False. draw_tables (bool, optional): Draw detected table structures on PDF pages. Defaults to False. @@ -333,54 +340,53 @@ def run_predictions( the total number of PDF files discovered (including any already-predicted files from a resumed run), and all CSV file paths written during this run. """ - if input_directory.is_file(): - root = input_directory.parent - pdf_files = [input_directory.name] if input_directory.suffix.lower() == ".pdf" else [] - else: - root = input_directory - pdf_files = [f.name for f in input_directory.glob("*.pdf") if f.is_file()] - + # Look for files to process + pdf_files = [input_directory] if input_directory.is_file() else list(input_directory.rglob("*.pdf")) n_documents = len(pdf_files) + # Build ground truth + ground_truth: GroundTruth = None + if ground_truth_path and ground_truth_path.exists(): # for inference no ground truth is available + ground_truth = GroundTruth(ground_truth_path) + # Load any partially-completed predictions for resume support predictions = read_json_predictions(predictions_path_tmp) any_draw = not skip_draw_predictions or draw_lines or draw_tables or draw_strip_logs all_csv_paths: list[Path] = [] - - for filename in tqdm(pdf_files, desc="Processing files", unit="file"): - if predictions.contains(filename): - logger.info(f"{filename} already predicted.") + for pdf_file in tqdm(pdf_files, desc="Processing files", unit="file"): + # Check if file is already computed in previous run + if predictions.contains(pdf_file.name): + logger.info(f"{pdf_file.name} already predicted.") continue - in_path = root / filename - logger.info(f"Processing file: {in_path}") + logger.info(f"Processing file: {pdf_file.name}") - try: - result = extract(file=in_path, filename=in_path.name, part=part, analytics=analytics) - predictions.add_file_predictions(result.predictions) + # Run prediction on file and evaluate it + result = extract(file=pdf_file, filename=pdf_file.name, part=part, analytics=analytics) + if ground_truth: + result.predictions = evaluate_single_prediction(result.predictions, ground_truth) - if csv: - all_csv_paths.extend(write_csv_for_file(result.predictions, out_directory)) + predictions.add_file_predictions(result.predictions) - if any_draw: - draw_file_predictions( - result=result, - file=in_path, - filename=in_path.name, - out_directory=out_directory, - skip_draw_predictions=skip_draw_predictions, - draw_lines=draw_lines, - draw_tables=draw_tables, - draw_strip_logs=draw_strip_logs, - ) + if csv: + all_csv_paths.extend(write_csv_for_file(result.predictions, out_directory)) - logger.info(f"Writing predictions to tmp JSON file {predictions_path_tmp}") - write_json_predictions(filename=predictions_path_tmp, predictions=predictions) + if any_draw: + draw_file_predictions( + result=result, + file=pdf_file, + filename=pdf_file.name, + out_directory=out_directory, + skip_draw_predictions=skip_draw_predictions, + draw_lines=draw_lines, + draw_tables=draw_tables, + draw_strip_logs=draw_strip_logs, + ) - except Exception as e: - logger.error(f"Unexpected error in file {filename}. Trace: {e}") + logger.info(f"Writing predictions to tmp JSON file {predictions_path_tmp}") + write_json_predictions(filename=predictions_path_tmp, predictions=predictions) return predictions, n_documents, all_csv_paths @@ -469,6 +475,7 @@ def start_pipeline( input_directory=input_directory, out_directory=out_directory, predictions_path_tmp=predictions_path_tmp, + ground_truth_path=ground_truth_path, skip_draw_predictions=skip_draw_predictions, draw_lines=draw_lines, draw_tables=draw_tables, From 20fe2cef197b09fd825a1b1193c03d9225b74d96 Mon Sep 17 00:00:00 2001 From: christianabbet Date: Wed, 18 Mar 2026 14:15:55 +0100 Subject: [PATCH 2/8] load ground truth once --- src/extraction/evaluation/benchmark/score.py | 18 +++++++-------- src/extraction/runner.py | 24 +++++++++----------- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/src/extraction/evaluation/benchmark/score.py b/src/extraction/evaluation/benchmark/score.py index 35b03cdf..5e1fbae9 100644 --- a/src/extraction/evaluation/benchmark/score.py +++ b/src/extraction/evaluation/benchmark/score.py @@ -26,7 +26,6 @@ class ExtractionBenchmarkSummary(BaseModel): """Helper class containing a summary of all the results of a single benchmark.""" - ground_truth_path: str n_documents: int geology: dict[str, float] metadata: dict[str, float] @@ -45,17 +44,20 @@ def key(category: str, metric: str) -> str: def evaluate_single_prediction( prediction: FilePredictions, - ground_truth: GroundTruth, + ground_truth: GroundTruth | None = None, ) -> FilePredictions: """Computes metrics for a given file. Args: prediction (FilePredictions): The predictions object. - ground_truth (GroundTruth): The ground truth object. + ground_truth (GroundTruth | None): The ground truth object. Returns: FilePredictions: Evaluated prediction. """ + if not ground_truth: + return prediction + # Create dummy overall file prediction and append prediction predictions = OverallFilePredictions() predictions.file_predictions_list.append(prediction) @@ -72,24 +74,21 @@ def evaluate_single_prediction( def evaluate_all_predictions( predictions: OverallFilePredictions, - ground_truth_path: Path, + ground_truth: GroundTruth | None = None, ) -> None | ExtractionBenchmarkSummary: """Computes all the metrics, logs them, and creates corresponding MLFlow artifacts (when enabled). Args: predictions (OverallFilePredictions): The predictions objects. - ground_truth_path (Path | None): The path to the ground truth file. + ground_truth (GroundTruth | None): The ground truth object. Returns: ExtractionBenchmarkSummary | None: A JSON-serializable ExtractionBenchmarkSummary that can be used by multi-benchmark runners. """ - if not (ground_truth_path and ground_truth_path.exists()): # for inference no ground truth is available - logger.warning("Ground truth file not found. Skipping evaluation.") + if not ground_truth: return None - ground_truth = GroundTruth(ground_truth_path) - ############################# # Evaluate the borehole extraction ############################# @@ -129,7 +128,6 @@ def evaluate_all_predictions( mlflow.log_artifact(Path(temp_directory) / "document_level_metadata_metrics.csv") return ExtractionBenchmarkSummary( - ground_truth_path=str(ground_truth_path), n_documents=len(predictions.file_predictions_list), geology=metrics_dict, metadata=metadata_metrics.to_json(), diff --git a/src/extraction/runner.py b/src/extraction/runner.py index 3b2d151a..a6e6f0c0 100644 --- a/src/extraction/runner.py +++ b/src/extraction/runner.py @@ -303,7 +303,7 @@ def run_predictions( input_directory: Path, out_directory: Path, predictions_path_tmp: Path, - ground_truth_path: Path, + ground_truth: GroundTruth | None = None, skip_draw_predictions: bool = False, draw_lines: bool = False, draw_tables: bool = False, @@ -324,7 +324,7 @@ def run_predictions( out_directory (Path): Directory where per-file output (visualizations, CSV) is written. predictions_path_tmp (Path): Path to the incremental tmp predictions file. Existing content is used to resume; the file is updated after each successfully processed file. - ground_truth_path (Path): Path to ground truth file for evaluation. + ground_truth (GroundTruth | None): Ground truth for evaluation. skip_draw_predictions (bool, optional): Skip drawing predictions on PDF pages. Defaults to False. draw_lines (bool, optional): Draw detected lines on PDF pages. Defaults to False. draw_tables (bool, optional): Draw detected table structures on PDF pages. Defaults to False. @@ -344,11 +344,6 @@ def run_predictions( pdf_files = [input_directory] if input_directory.is_file() else list(input_directory.rglob("*.pdf")) n_documents = len(pdf_files) - # Build ground truth - ground_truth: GroundTruth = None - if ground_truth_path and ground_truth_path.exists(): # for inference no ground truth is available - ground_truth = GroundTruth(ground_truth_path) - # Load any partially-completed predictions for resume support predictions = read_json_predictions(predictions_path_tmp) @@ -365,9 +360,7 @@ def run_predictions( # Run prediction on file and evaluate it result = extract(file=pdf_file, filename=pdf_file.name, part=part, analytics=analytics) - if ground_truth: - result.predictions = evaluate_single_prediction(result.predictions, ground_truth) - + result.predictions = evaluate_single_prediction(result.predictions, ground_truth) predictions.add_file_predictions(result.predictions) if csv: @@ -450,6 +443,11 @@ def start_pipeline( delete_temporary(predictions_path_tmp) delete_temporary(mlflow_runid_tmp) + # Build ground truth + ground_truth: GroundTruth = None + if ground_truth_path and ground_truth_path.exists(): # for inference no ground truth is available + ground_truth = GroundTruth(ground_truth_path) + metadata_path.parent.mkdir(exist_ok=True) # Initialize analytics if enabled @@ -475,7 +473,7 @@ def start_pipeline( input_directory=input_directory, out_directory=out_directory, predictions_path_tmp=predictions_path_tmp, - ground_truth_path=ground_truth_path, + ground_truth=ground_truth, skip_draw_predictions=skip_draw_predictions, draw_lines=draw_lines, draw_tables=draw_tables, @@ -490,10 +488,10 @@ def start_pipeline( for csv_path in csv_paths: mlflow.log_artifact(str(csv_path), "csv") - # Evaluate final predictions + # Evaluate final predictions with all data eval_summary = evaluate_all_predictions( predictions=predictions, - ground_truth_path=ground_truth_path, + ground_truth=ground_truth, ) if eval_summary is not None: eval_summary.n_documents = n_documents From 052ada27c950838d9912bc00de14045a1b3bc96e Mon Sep 17 00:00:00 2001 From: christianabbet Date: Wed, 18 Mar 2026 14:34:23 +0100 Subject: [PATCH 3/8] cleaning first iteration --- .../evaluation/benchmark/ground_truth.py | 1 + src/extraction/evaluation/benchmark/score.py | 18 +++++++++++++----- src/extraction/runner.py | 4 ++-- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/extraction/evaluation/benchmark/ground_truth.py b/src/extraction/evaluation/benchmark/ground_truth.py index 30b199e7..78610603 100644 --- a/src/extraction/evaluation/benchmark/ground_truth.py +++ b/src/extraction/evaluation/benchmark/ground_truth.py @@ -19,6 +19,7 @@ def __init__(self, path: Path) -> None: Args: path (Path): the path to the Ground truth file """ + self.path = path self.ground_truth = defaultdict(lambda: defaultdict(dict)) # Load the ground truth data diff --git a/src/extraction/evaluation/benchmark/score.py b/src/extraction/evaluation/benchmark/score.py index 5e1fbae9..f7eb12f5 100644 --- a/src/extraction/evaluation/benchmark/score.py +++ b/src/extraction/evaluation/benchmark/score.py @@ -55,7 +55,7 @@ def evaluate_single_prediction( Returns: FilePredictions: Evaluated prediction. """ - if not ground_truth: + if ground_truth is None: return prediction # Create dummy overall file prediction and append prediction @@ -86,7 +86,7 @@ def evaluate_all_predictions( ExtractionBenchmarkSummary | None: A JSON-serializable ExtractionBenchmarkSummary that can be used by multi-benchmark runners. """ - if not ground_truth: + if ground_truth is None: return None ############################# @@ -128,6 +128,7 @@ def evaluate_all_predictions( mlflow.log_artifact(Path(temp_directory) / "document_level_metadata_metrics.csv") return ExtractionBenchmarkSummary( + ground_truth_path=str(ground_truth.path), n_documents=len(predictions.file_predictions_list), geology=metrics_dict, metadata=metadata_metrics.to_json(), @@ -142,6 +143,7 @@ def main(): try: with open(args.predictions_path, encoding="utf8") as file: predictions = json.load(file) + predictions = OverallFilePredictions.from_json(predictions) except FileNotFoundError: logger.error("Predictions file not found: %s", args.predictions_path) return @@ -149,13 +151,19 @@ def main(): logger.error("Error decoding JSON from predictions file: %s", e) return - predictions = OverallFilePredictions.from_json(predictions) + # Load ground truth + try: + ground_truth = GroundTruth(args.ground_truth_path) + except FileNotFoundError: + logger.error("Ground truth file not found: %s", args.ground_truth_path) + return + if mlflow: mlflow.set_experiment("Boreholes Stratigraphy") with mlflow.start_run(): - evaluate_all_predictions(predictions, args.ground_truth_path) + evaluate_all_predictions(predictions, ground_truth) else: - evaluate_all_predictions(predictions, args.ground_truth_path) + evaluate_all_predictions(predictions, ground_truth) def parse_cli() -> argparse.Namespace: diff --git a/src/extraction/runner.py b/src/extraction/runner.py index a6e6f0c0..3dd61ac0 100644 --- a/src/extraction/runner.py +++ b/src/extraction/runner.py @@ -341,7 +341,7 @@ def run_predictions( resumed run), and all CSV file paths written during this run. """ # Look for files to process - pdf_files = [input_directory] if input_directory.is_file() else list(input_directory.rglob("*.pdf")) + pdf_files = [input_directory] if input_directory.is_file() else list(input_directory.glob("*.pdf")) n_documents = len(pdf_files) # Load any partially-completed predictions for resume support @@ -444,7 +444,7 @@ def start_pipeline( delete_temporary(mlflow_runid_tmp) # Build ground truth - ground_truth: GroundTruth = None + ground_truth: GroundTruth | None = None if ground_truth_path and ground_truth_path.exists(): # for inference no ground truth is available ground_truth = GroundTruth(ground_truth_path) From 93896f827dcc3e5f29251509c3a95cb843c00396 Mon Sep 17 00:00:00 2001 From: christianabbet Date: Wed, 18 Mar 2026 14:41:43 +0100 Subject: [PATCH 4/8] typo --- src/extraction/evaluation/benchmark/score.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/extraction/evaluation/benchmark/score.py b/src/extraction/evaluation/benchmark/score.py index f7eb12f5..63548459 100644 --- a/src/extraction/evaluation/benchmark/score.py +++ b/src/extraction/evaluation/benchmark/score.py @@ -26,6 +26,7 @@ class ExtractionBenchmarkSummary(BaseModel): """Helper class containing a summary of all the results of a single benchmark.""" + ground_truth_path: str n_documents: int geology: dict[str, float] metadata: dict[str, float] From b2401ead78f2789e0952200e4075255b2d77400c Mon Sep 17 00:00:00 2001 From: christianabbet Date: Wed, 18 Mar 2026 15:17:42 +0100 Subject: [PATCH 5/8] indent --- src/extraction/runner.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/extraction/runner.py b/src/extraction/runner.py index 3dd61ac0..474d985d 100644 --- a/src/extraction/runner.py +++ b/src/extraction/runner.py @@ -101,7 +101,7 @@ def write_json_predictions(filename: str, predictions: OverallFilePredictions) - predictions (OverallFilePredictions): Prediction to dump in JSON file. """ with open(filename, "w", encoding="utf8") as file: - json.dump(predictions.to_json(), file, ensure_ascii=False) + json.dump(predictions.to_json(), file, ensure_ascii=False, indent=2) def read_json_predictions(filename: str) -> OverallFilePredictions: @@ -358,15 +358,17 @@ def run_predictions( logger.info(f"Processing file: {pdf_file.name}") - # Run prediction on file and evaluate it + # Run extraction and happend to predictions result = extract(file=pdf_file, filename=pdf_file.name, part=part, analytics=analytics) - result.predictions = evaluate_single_prediction(result.predictions, ground_truth) predictions.add_file_predictions(result.predictions) if csv: all_csv_paths.extend(write_csv_for_file(result.predictions, out_directory)) if any_draw: + # Run evaluation for current file drawing + result.predictions = evaluate_single_prediction(result.predictions, ground_truth) + # Draw predictions for file draw_file_predictions( result=result, file=pdf_file, From 9a5f34b5f7073f90950a835b02af95d934f51cc9 Mon Sep 17 00:00:00 2001 From: christianabbet Date: Wed, 18 Mar 2026 15:54:28 +0100 Subject: [PATCH 6/8] update typos before final test --- src/extraction/evaluation/benchmark/score.py | 5 ++++- src/extraction/runner.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/extraction/evaluation/benchmark/score.py b/src/extraction/evaluation/benchmark/score.py index 63548459..86696e12 100644 --- a/src/extraction/evaluation/benchmark/score.py +++ b/src/extraction/evaluation/benchmark/score.py @@ -49,6 +49,9 @@ def evaluate_single_prediction( ) -> FilePredictions: """Computes metrics for a given file. + Note that the impolementation of `evaluate_geology` and `evaluate_metadata_extraction` mutates + the attributes of `prediction`. + Args: prediction (FilePredictions): The predictions object. ground_truth (GroundTruth | None): The ground truth object. @@ -66,7 +69,7 @@ def evaluate_single_prediction( # Match file with ground truth matched_with_ground_truth = predictions.match_with_ground_truth(ground_truth) - # Run evaluation for file + # Run evaluation for file (! mutates prediction !) matched_with_ground_truth.evaluate_geology() matched_with_ground_truth.evaluate_metadata_extraction() diff --git a/src/extraction/runner.py b/src/extraction/runner.py index 474d985d..93e2c1bc 100644 --- a/src/extraction/runner.py +++ b/src/extraction/runner.py @@ -345,7 +345,7 @@ def run_predictions( n_documents = len(pdf_files) # Load any partially-completed predictions for resume support - predictions = read_json_predictions(predictions_path_tmp) + predictions = read_json_predictions(str(predictions_path_tmp)) any_draw = not skip_draw_predictions or draw_lines or draw_tables or draw_strip_logs @@ -358,7 +358,7 @@ def run_predictions( logger.info(f"Processing file: {pdf_file.name}") - # Run extraction and happend to predictions + # Run extraction and append to predictions result = extract(file=pdf_file, filename=pdf_file.name, part=part, analytics=analytics) predictions.add_file_predictions(result.predictions) @@ -381,7 +381,7 @@ def run_predictions( ) logger.info(f"Writing predictions to tmp JSON file {predictions_path_tmp}") - write_json_predictions(filename=predictions_path_tmp, predictions=predictions) + write_json_predictions(filename=str(predictions_path_tmp), predictions=predictions) return predictions, n_documents, all_csv_paths From 4df2d617f7ff22913f3c931c5d8560e40adf3f0f Mon Sep 17 00:00:00 2001 From: christianabbet Date: Tue, 24 Mar 2026 11:33:55 +0100 Subject: [PATCH 7/8] typo fix --- src/extraction/evaluation/benchmark/score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/extraction/evaluation/benchmark/score.py b/src/extraction/evaluation/benchmark/score.py index 86696e12..c7fc955f 100644 --- a/src/extraction/evaluation/benchmark/score.py +++ b/src/extraction/evaluation/benchmark/score.py @@ -49,7 +49,7 @@ def evaluate_single_prediction( ) -> FilePredictions: """Computes metrics for a given file. - Note that the impolementation of `evaluate_geology` and `evaluate_metadata_extraction` mutates + Note that the implementation of `evaluate_geology` and `evaluate_metadata_extraction` mutates the attributes of `prediction`. Args: From 102c75888994266607551ce16fb27236396829ee Mon Sep 17 00:00:00 2001 From: christianabbet Date: Fri, 27 Mar 2026 11:20:39 +0100 Subject: [PATCH 8/8] addressing comments on single file evaluation and logging --- src/extraction/evaluation/benchmark/score.py | 8 ++------ .../predictions/overall_file_predictions.py | 7 ++++--- .../features/predictions/predictions.py | 20 +++++++++++-------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/extraction/evaluation/benchmark/score.py b/src/extraction/evaluation/benchmark/score.py index c7fc955f..5489deab 100644 --- a/src/extraction/evaluation/benchmark/score.py +++ b/src/extraction/evaluation/benchmark/score.py @@ -63,14 +63,10 @@ def evaluate_single_prediction( return prediction # Create dummy overall file prediction and append prediction - predictions = OverallFilePredictions() - predictions.file_predictions_list.append(prediction) - - # Match file with ground truth - matched_with_ground_truth = predictions.match_with_ground_truth(ground_truth) + matched_with_ground_truth = OverallFilePredictions([prediction]).match_with_ground_truth(ground_truth) # Run evaluation for file (! mutates prediction !) - matched_with_ground_truth.evaluate_geology() + matched_with_ground_truth.evaluate_geology(verbose=False) matched_with_ground_truth.evaluate_metadata_extraction() return prediction diff --git a/src/extraction/features/predictions/overall_file_predictions.py b/src/extraction/features/predictions/overall_file_predictions.py index 51bab22b..7a62325f 100644 --- a/src/extraction/features/predictions/overall_file_predictions.py +++ b/src/extraction/features/predictions/overall_file_predictions.py @@ -1,5 +1,7 @@ """Classes for predictions per PDF file.""" +import dataclasses + from extraction.evaluation.benchmark.ground_truth import GroundTruth from extraction.evaluation.layer_evaluator import LayerEvaluator from extraction.features.metadata.metadata import FileMetadata @@ -11,12 +13,11 @@ from extraction.features.predictions.predictions import AllBoreholePredictionsWithGroundTruth +@dataclasses.dataclass class OverallFilePredictions: """A class to represent predictions for all files.""" - def __init__(self) -> None: - """Initializes the OverallFilePredictions object.""" - self.file_predictions_list: list[FilePredictions] = [] + file_predictions_list: list[FilePredictions] = dataclasses.field(default_factory=list) def contains(self, filename: str) -> bool: """Check if `file_predictions_list` contains `filename`. diff --git a/src/extraction/features/predictions/predictions.py b/src/extraction/features/predictions/predictions.py index b492f0db..50a5b8b7 100644 --- a/src/extraction/features/predictions/predictions.py +++ b/src/extraction/features/predictions/predictions.py @@ -274,9 +274,12 @@ def evaluate_metadata_extraction(self) -> OverallBoreholeMetadataMetrics: return MetadataEvaluator(metadata_list).evaluate() - def evaluate_geology(self) -> OverallMetricsCatalog: + def evaluate_geology(self, verbose: bool = True) -> OverallMetricsCatalog: """Evaluate the borehole extraction predictions. + Args: + verbose (bool): If True, log all metrics for the current evaluation. + Returns: OverallMetricsCatalog: A OverallMetricsCatalog that maps a metrics name to the corresponding OverallMetrics object. If no ground truth is available, None is returned. @@ -316,13 +319,14 @@ def evaluate_geology(self) -> OverallMetricsCatalog: all_metrics, f"{language}_material_description_metrics", evaluator.get_material_description_metrics() ) - logger.info("Macro avg:") - logger.info( - "layer f1: %.1f%%, depth interval f1: %.1f%%, material description f1: %.1f%%", - all_metrics.layer_metrics.macro_f1() * 100, - all_metrics.depth_interval_metrics.macro_f1() * 100, - all_metrics.material_description_metrics.macro_f1() * 100, - ) + if verbose: + logger.info("Macro avg:") + logger.info( + "layer f1: %.1f%%, depth interval f1: %.1f%%, material description f1: %.1f%%", + all_metrics.layer_metrics.macro_f1() * 100, + all_metrics.depth_interval_metrics.macro_f1() * 100, + all_metrics.material_description_metrics.macro_f1() * 100, + ) # TODO groundwater should not be in evaluate_geology(), it should be handle by a higher-level function call groundwater_list = [