1919)
2020from core .mlflow_tracking import mlflow
2121from extraction .core .extract import ExtractionResult , extract , open_pdf
22- from extraction .evaluation .benchmark .score import ExtractionBenchmarkSummary , evaluate_all_predictions
22+ from extraction .evaluation .benchmark .ground_truth import GroundTruth
23+ from extraction .evaluation .benchmark .score import (
24+ ExtractionBenchmarkSummary ,
25+ evaluate_all_predictions ,
26+ evaluate_single_prediction ,
27+ )
2328from extraction .evaluation .benchmark .spec import BenchmarkSpec
2429from extraction .features .predictions .file_predictions import FilePredictions
2530from extraction .features .predictions .overall_file_predictions import OverallFilePredictions
@@ -298,6 +303,7 @@ def run_predictions(
298303 input_directory : Path ,
299304 out_directory : Path ,
300305 predictions_path_tmp : Path ,
306+ ground_truth : GroundTruth | None = None ,
301307 skip_draw_predictions : bool = False ,
302308 draw_lines : bool = False ,
303309 draw_tables : bool = False ,
@@ -318,6 +324,7 @@ def run_predictions(
318324 out_directory (Path): Directory where per-file output (visualizations, CSV) is written.
319325 predictions_path_tmp (Path): Path to the incremental tmp predictions file. Existing content
320326 is used to resume; the file is updated after each successfully processed file.
327+ ground_truth (GroundTruth | None): Ground truth for evaluation.
321328 skip_draw_predictions (bool, optional): Skip drawing predictions on PDF pages. Defaults to False.
322329 draw_lines (bool, optional): Draw detected lines on PDF pages. Defaults to False.
323330 draw_tables (bool, optional): Draw detected table structures on PDF pages. Defaults to False.
@@ -333,54 +340,48 @@ def run_predictions(
333340 the total number of PDF files discovered (including any already-predicted files from a
334341 resumed run), and all CSV file paths written during this run.
335342 """
336- if input_directory .is_file ():
337- root = input_directory .parent
338- pdf_files = [input_directory .name ] if input_directory .suffix .lower () == ".pdf" else []
339- else :
340- root = input_directory
341- pdf_files = [f .name for f in input_directory .glob ("*.pdf" ) if f .is_file ()]
342-
343+ # Look for files to process
344+ pdf_files = [input_directory ] if input_directory .is_file () else list (input_directory .glob ("*.pdf" ))
343345 n_documents = len (pdf_files )
344346
345347 # Load any partially-completed predictions for resume support
346- predictions = read_json_predictions (predictions_path_tmp )
348+ predictions = read_json_predictions (str ( predictions_path_tmp ) )
347349
348350 any_draw = not skip_draw_predictions or draw_lines or draw_tables or draw_strip_logs
349351
350352 all_csv_paths : list [Path ] = []
351-
352- for filename in tqdm ( pdf_files , desc = "Processing files" , unit = " file" ):
353- if predictions .contains (filename ):
354- logger .info (f"{ filename } already predicted." )
353+ for pdf_file in tqdm ( pdf_files , desc = "Processing files" , unit = "file" ):
354+ # Check if file is already computed in previous run
355+ if predictions .contains (pdf_file . name ):
356+ logger .info (f"{ pdf_file . name } already predicted." )
355357 continue
356358
357- in_path = root / filename
358- logger .info (f"Processing file: { in_path } " )
359-
360- try :
361- result = extract (file = in_path , filename = in_path .name , part = part , analytics = analytics )
362- predictions .add_file_predictions (result .predictions )
363-
364- if csv :
365- all_csv_paths .extend (write_csv_for_file (result .predictions , out_directory ))
366-
367- if any_draw :
368- draw_file_predictions (
369- result = result ,
370- file = in_path ,
371- filename = in_path .name ,
372- out_directory = out_directory ,
373- skip_draw_predictions = skip_draw_predictions ,
374- draw_lines = draw_lines ,
375- draw_tables = draw_tables ,
376- draw_strip_logs = draw_strip_logs ,
377- )
378-
379- logger .info (f"Writing predictions to tmp JSON file { predictions_path_tmp } " )
380- write_json_predictions (filename = predictions_path_tmp , predictions = predictions )
381-
382- except Exception as e :
383- logger .error (f"Unexpected error in file { filename } . Trace: { e } " )
359+ logger .info (f"Processing file: { pdf_file .name } " )
360+
361+ # Run extraction and append to predictions
362+ result = extract (file = pdf_file , filename = pdf_file .name , part = part , analytics = analytics )
363+ predictions .add_file_predictions (result .predictions )
364+
365+ if csv :
366+ all_csv_paths .extend (write_csv_for_file (result .predictions , out_directory ))
367+
368+ if any_draw :
369+ # Run evaluation for current file drawing
370+ result .predictions = evaluate_single_prediction (result .predictions , ground_truth )
371+ # Draw predictions for file
372+ draw_file_predictions (
373+ result = result ,
374+ file = pdf_file ,
375+ filename = pdf_file .name ,
376+ out_directory = out_directory ,
377+ skip_draw_predictions = skip_draw_predictions ,
378+ draw_lines = draw_lines ,
379+ draw_tables = draw_tables ,
380+ draw_strip_logs = draw_strip_logs ,
381+ )
382+
383+ logger .info (f"Writing predictions to tmp JSON file { predictions_path_tmp } " )
384+ write_json_predictions (filename = str (predictions_path_tmp ), predictions = predictions )
384385
385386 return predictions , n_documents , all_csv_paths
386387
@@ -444,6 +445,11 @@ def start_pipeline(
444445 delete_temporary (predictions_path_tmp )
445446 delete_temporary (mlflow_runid_tmp )
446447
448+ # Build ground truth
449+ ground_truth : GroundTruth | None = None
450+ if ground_truth_path and ground_truth_path .exists (): # for inference no ground truth is available
451+ ground_truth = GroundTruth (ground_truth_path )
452+
447453 metadata_path .parent .mkdir (exist_ok = True )
448454
449455 # Initialize analytics if enabled
@@ -469,6 +475,7 @@ def start_pipeline(
469475 input_directory = input_directory ,
470476 out_directory = out_directory ,
471477 predictions_path_tmp = predictions_path_tmp ,
478+ ground_truth = ground_truth ,
472479 skip_draw_predictions = skip_draw_predictions ,
473480 draw_lines = draw_lines ,
474481 draw_tables = draw_tables ,
@@ -483,10 +490,10 @@ def start_pipeline(
483490 for csv_path in csv_paths :
484491 mlflow .log_artifact (str (csv_path ), "csv" )
485492
486- # Evaluate final predictions
493+ # Evaluate final predictions with all data
487494 eval_summary = evaluate_all_predictions (
488495 predictions = predictions ,
489- ground_truth_path = ground_truth_path ,
496+ ground_truth = ground_truth ,
490497 )
491498 if eval_summary is not None :
492499 eval_summary .n_documents = n_documents
0 commit comments