diff --git a/src/extraction/evaluation/benchmark/ground_truth.py b/src/extraction/evaluation/benchmark/ground_truth.py index 30b199e7..78610603 100644 --- a/src/extraction/evaluation/benchmark/ground_truth.py +++ b/src/extraction/evaluation/benchmark/ground_truth.py @@ -19,6 +19,7 @@ def __init__(self, path: Path) -> None: Args: path (Path): the path to the Ground truth file """ + self.path = path self.ground_truth = defaultdict(lambda: defaultdict(dict)) # Load the ground truth data diff --git a/src/extraction/evaluation/benchmark/score.py b/src/extraction/evaluation/benchmark/score.py index 57772847..5489deab 100644 --- a/src/extraction/evaluation/benchmark/score.py +++ b/src/extraction/evaluation/benchmark/score.py @@ -14,6 +14,7 @@ from core.mlflow_tracking import mlflow from extraction.evaluation.benchmark.ground_truth import GroundTruth +from extraction.features.predictions.file_predictions import FilePredictions from extraction.features.predictions.overall_file_predictions import OverallFilePredictions from swissgeol_doc_processing.utils.file_utils import get_data_path @@ -42,26 +43,52 @@ def key(category: str, metric: str) -> str: return geology_dict | metadata_dict +def evaluate_single_prediction( + prediction: FilePredictions, + ground_truth: GroundTruth | None = None, +) -> FilePredictions: + """Computes metrics for a given file. + + Note that the implementation of `evaluate_geology` and `evaluate_metadata_extraction` mutates + the attributes of `prediction`. + + Args: + prediction (FilePredictions): The predictions object. + ground_truth (GroundTruth | None): The ground truth object. + + Returns: + FilePredictions: Evaluated prediction. + """ + if ground_truth is None: + return prediction + + # Create dummy overall file prediction and append prediction + matched_with_ground_truth = OverallFilePredictions([prediction]).match_with_ground_truth(ground_truth) + + # Run evaluation for file (! mutates prediction !) 
+ matched_with_ground_truth.evaluate_geology(verbose=False) + matched_with_ground_truth.evaluate_metadata_extraction() + + return prediction + + def evaluate_all_predictions( predictions: OverallFilePredictions, - ground_truth_path: Path, + ground_truth: GroundTruth | None = None, ) -> None | ExtractionBenchmarkSummary: """Computes all the metrics, logs them, and creates corresponding MLFlow artifacts (when enabled). Args: predictions (OverallFilePredictions): The predictions objects. - ground_truth_path (Path | None): The path to the ground truth file. + ground_truth (GroundTruth | None): The ground truth object. Returns: ExtractionBenchmarkSummary | None: A JSON-serializable ExtractionBenchmarkSummary that can be used by multi-benchmark runners. """ - if not (ground_truth_path and ground_truth_path.exists()): # for inference no ground truth is available - logger.warning("Ground truth file not found. Skipping evaluation.") + if ground_truth is None: return None - ground_truth = GroundTruth(ground_truth_path) - ############################# # Evaluate the borehole extraction ############################# @@ -101,7 +128,7 @@ def evaluate_all_predictions( mlflow.log_artifact(Path(temp_directory) / "document_level_metadata_metrics.csv") return ExtractionBenchmarkSummary( - ground_truth_path=str(ground_truth_path), + ground_truth_path=str(ground_truth.path), n_documents=len(predictions.file_predictions_list), geology=metrics_dict, metadata=metadata_metrics.to_json(), @@ -116,6 +143,7 @@ def main(): try: with open(args.predictions_path, encoding="utf8") as file: predictions = json.load(file) + predictions = OverallFilePredictions.from_json(predictions) except FileNotFoundError: logger.error("Predictions file not found: %s", args.predictions_path) return @@ -123,13 +151,19 @@ def main(): logger.error("Error decoding JSON from predictions file: %s", e) return - predictions = OverallFilePredictions.from_json(predictions) + # Load ground truth + try: + ground_truth = 
GroundTruth(args.ground_truth_path) + except FileNotFoundError: + logger.error("Ground truth file not found: %s", args.ground_truth_path) + return + if mlflow: mlflow.set_experiment("Boreholes Stratigraphy") with mlflow.start_run(): - evaluate_all_predictions(predictions, args.ground_truth_path) + evaluate_all_predictions(predictions, ground_truth) else: - evaluate_all_predictions(predictions, args.ground_truth_path) + evaluate_all_predictions(predictions, ground_truth) def parse_cli() -> argparse.Namespace: diff --git a/src/extraction/features/predictions/overall_file_predictions.py b/src/extraction/features/predictions/overall_file_predictions.py index 51bab22b..7a62325f 100644 --- a/src/extraction/features/predictions/overall_file_predictions.py +++ b/src/extraction/features/predictions/overall_file_predictions.py @@ -1,5 +1,7 @@ """Classes for predictions per PDF file.""" +import dataclasses + from extraction.evaluation.benchmark.ground_truth import GroundTruth from extraction.evaluation.layer_evaluator import LayerEvaluator from extraction.features.metadata.metadata import FileMetadata @@ -11,12 +13,11 @@ from extraction.features.predictions.predictions import AllBoreholePredictionsWithGroundTruth +@dataclasses.dataclass class OverallFilePredictions: """A class to represent predictions for all files.""" - def __init__(self) -> None: - """Initializes the OverallFilePredictions object.""" - self.file_predictions_list: list[FilePredictions] = [] + file_predictions_list: list[FilePredictions] = dataclasses.field(default_factory=list) def contains(self, filename: str) -> bool: """Check if `file_predictions_list` contains `filename`. 
diff --git a/src/extraction/features/predictions/predictions.py b/src/extraction/features/predictions/predictions.py index b492f0db..50a5b8b7 100644 --- a/src/extraction/features/predictions/predictions.py +++ b/src/extraction/features/predictions/predictions.py @@ -274,9 +274,12 @@ def evaluate_metadata_extraction(self) -> OverallBoreholeMetadataMetrics: return MetadataEvaluator(metadata_list).evaluate() - def evaluate_geology(self) -> OverallMetricsCatalog: + def evaluate_geology(self, verbose: bool = True) -> OverallMetricsCatalog: """Evaluate the borehole extraction predictions. + Args: + verbose (bool): If True, log the macro-average F1 scores of the evaluation. + Returns: OverallMetricsCatalog: A OverallMetricsCatalog that maps a metrics name to the corresponding OverallMetrics object. If no ground truth is available, None is returned. @@ -316,13 +319,14 @@ def evaluate_geology(self) -> OverallMetricsCatalog: all_metrics, f"{language}_material_description_metrics", evaluator.get_material_description_metrics() ) - logger.info("Macro avg:") - logger.info( - "layer f1: %.1f%%, depth interval f1: %.1f%%, material description f1: %.1f%%", - all_metrics.layer_metrics.macro_f1() * 100, - all_metrics.depth_interval_metrics.macro_f1() * 100, - all_metrics.material_description_metrics.macro_f1() * 100, - ) + if verbose: + logger.info("Macro avg:") + logger.info( + "layer f1: %.1f%%, depth interval f1: %.1f%%, material description f1: %.1f%%", + all_metrics.layer_metrics.macro_f1() * 100, + all_metrics.depth_interval_metrics.macro_f1() * 100, + all_metrics.material_description_metrics.macro_f1() * 100, + ) # TODO groundwater should not be in evaluate_geology(), it should be handle by a higher-level function call groundwater_list = [ diff --git a/src/extraction/runner.py b/src/extraction/runner.py index 560deb0b..93e2c1bc 100644 --- a/src/extraction/runner.py +++ b/src/extraction/runner.py @@ -19,7 +19,12 @@ ) from core.mlflow_tracking import mlflow from extraction.core.extract 
import ExtractionResult, extract, open_pdf -from extraction.evaluation.benchmark.score import ExtractionBenchmarkSummary, evaluate_all_predictions +from extraction.evaluation.benchmark.ground_truth import GroundTruth +from extraction.evaluation.benchmark.score import ( + ExtractionBenchmarkSummary, + evaluate_all_predictions, + evaluate_single_prediction, +) from extraction.evaluation.benchmark.spec import BenchmarkSpec from extraction.features.predictions.file_predictions import FilePredictions from extraction.features.predictions.overall_file_predictions import OverallFilePredictions @@ -96,7 +101,7 @@ def write_json_predictions(filename: str, predictions: OverallFilePredictions) - predictions (OverallFilePredictions): Prediction to dump in JSON file. """ with open(filename, "w", encoding="utf8") as file: - json.dump(predictions.to_json(), file, ensure_ascii=False) + json.dump(predictions.to_json(), file, ensure_ascii=False, indent=2) def read_json_predictions(filename: str) -> OverallFilePredictions: @@ -298,6 +303,7 @@ def run_predictions( input_directory: Path, out_directory: Path, predictions_path_tmp: Path, + ground_truth: GroundTruth | None = None, skip_draw_predictions: bool = False, draw_lines: bool = False, draw_tables: bool = False, @@ -318,6 +324,7 @@ def run_predictions( out_directory (Path): Directory where per-file output (visualizations, CSV) is written. predictions_path_tmp (Path): Path to the incremental tmp predictions file. Existing content is used to resume; the file is updated after each successfully processed file. + ground_truth (GroundTruth | None): Ground truth for evaluation. skip_draw_predictions (bool, optional): Skip drawing predictions on PDF pages. Defaults to False. draw_lines (bool, optional): Draw detected lines on PDF pages. Defaults to False. draw_tables (bool, optional): Draw detected table structures on PDF pages. Defaults to False. 
@@ -333,54 +340,48 @@ def run_predictions( the total number of PDF files discovered (including any already-predicted files from a resumed run), and all CSV file paths written during this run. """ - if input_directory.is_file(): - root = input_directory.parent - pdf_files = [input_directory.name] if input_directory.suffix.lower() == ".pdf" else [] - else: - root = input_directory - pdf_files = [f.name for f in input_directory.glob("*.pdf") if f.is_file()] - + # Look for files to process + pdf_files = [input_directory] if input_directory.is_file() else list(input_directory.glob("*.pdf")) n_documents = len(pdf_files) # Load any partially-completed predictions for resume support - predictions = read_json_predictions(predictions_path_tmp) + predictions = read_json_predictions(str(predictions_path_tmp)) any_draw = not skip_draw_predictions or draw_lines or draw_tables or draw_strip_logs all_csv_paths: list[Path] = [] - - for filename in tqdm(pdf_files, desc="Processing files", unit="file"): - if predictions.contains(filename): - logger.info(f"{filename} already predicted.") + for pdf_file in tqdm(pdf_files, desc="Processing files", unit="file"): + # Check if file is already computed in previous run + if predictions.contains(pdf_file.name): + logger.info(f"{pdf_file.name} already predicted.") continue - in_path = root / filename - logger.info(f"Processing file: {in_path}") - - try: - result = extract(file=in_path, filename=in_path.name, part=part, analytics=analytics) - predictions.add_file_predictions(result.predictions) - - if csv: - all_csv_paths.extend(write_csv_for_file(result.predictions, out_directory)) - - if any_draw: - draw_file_predictions( - result=result, - file=in_path, - filename=in_path.name, - out_directory=out_directory, - skip_draw_predictions=skip_draw_predictions, - draw_lines=draw_lines, - draw_tables=draw_tables, - draw_strip_logs=draw_strip_logs, - ) - - logger.info(f"Writing predictions to tmp JSON file {predictions_path_tmp}") - 
write_json_predictions(filename=predictions_path_tmp, predictions=predictions) - - except Exception as e: - logger.error(f"Unexpected error in file {filename}. Trace: {e}") + logger.info(f"Processing file: {pdf_file.name}") + + # Run extraction and append to predictions + result = extract(file=pdf_file, filename=pdf_file.name, part=part, analytics=analytics) + predictions.add_file_predictions(result.predictions) + + if csv: + all_csv_paths.extend(write_csv_for_file(result.predictions, out_directory)) + + if any_draw: + # Evaluate the current file before drawing its predictions + result.predictions = evaluate_single_prediction(result.predictions, ground_truth) + # Draw predictions for file + draw_file_predictions( + result=result, + file=pdf_file, + filename=pdf_file.name, + out_directory=out_directory, + skip_draw_predictions=skip_draw_predictions, + draw_lines=draw_lines, + draw_tables=draw_tables, + draw_strip_logs=draw_strip_logs, + ) + + logger.info(f"Writing predictions to tmp JSON file {predictions_path_tmp}") + write_json_predictions(filename=str(predictions_path_tmp), predictions=predictions) return predictions, n_documents, all_csv_paths @@ -444,6 +445,11 @@ def start_pipeline( delete_temporary(predictions_path_tmp) delete_temporary(mlflow_runid_tmp) + # Build ground truth + ground_truth: GroundTruth | None = None + if ground_truth_path and ground_truth_path.exists(): # for inference no ground truth is available + ground_truth = GroundTruth(ground_truth_path) + metadata_path.parent.mkdir(exist_ok=True) # Initialize analytics if enabled @@ -469,6 +475,7 @@ def start_pipeline( input_directory=input_directory, out_directory=out_directory, predictions_path_tmp=predictions_path_tmp, + ground_truth=ground_truth, skip_draw_predictions=skip_draw_predictions, draw_lines=draw_lines, draw_tables=draw_tables, @@ -483,10 +490,10 @@ def start_pipeline( for csv_path in csv_paths: mlflow.log_artifact(str(csv_path), "csv") - # Evaluate final predictions + # Evaluate final predictions with 
all data eval_summary = evaluate_all_predictions( predictions=predictions, - ground_truth_path=ground_truth_path, + ground_truth=ground_truth, ) if eval_summary is not None: eval_summary.n_documents = n_documents