# Apache Software License 2.0
#
# Copyright (c) ZenML GmbH 2025. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OCR Evaluation Pipeline implementation for comparing models using existing results."""

from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from zenml import pipeline
from zenml.logger import get_logger

from steps import (
    evaluate_models,
    load_ground_truth_texts,
    load_ocr_results,
    save_visualization,
)

load_dotenv()

logger = get_logger(__name__)


@pipeline
def ocr_evaluation_pipeline(
    model_names: Optional[List[str]] = None,
    results_dir: str = "ocr_results",
    result_files: Optional[List[str]] = None,
    ground_truth_folder: Optional[str] = None,
    ground_truth_files: Optional[List[str]] = None,
    save_visualization_data: bool = False,
    visualization_output_dir: str = "visualizations",
) -> None:
    """Run the OCR evaluation pipeline comparing existing model results.

    Args:
        model_names: Names of at least two models whose stored OCR results are compared.
        results_dir: Directory containing previously generated OCR results.
        result_files: Optional explicit list of result files to load.
        ground_truth_folder: Folder containing ground truth texts.
        ground_truth_files: Explicit list of ground truth files (alternative to ground_truth_folder).
        save_visualization_data: Whether to save the evaluation visualization locally.
        visualization_output_dir: Directory where visualizations are written.
    """
    if not model_names or len(model_names) < 2:
        raise ValueError("At least two models are required for comparison")

    if not ground_truth_folder and not ground_truth_files:
        raise ValueError(
            "Either ground_truth_folder or ground_truth_files must be provided for evaluation"
        )

    # Load the stored OCR results for each selected model.
    model_results = load_ocr_results(
        model_names=model_names,
        results_dir=results_dir,
        result_files=result_files,
    )

    # Load the ground truth texts that the model outputs are evaluated against.
    ground_truth_df = load_ground_truth_texts(
        model_results=model_results,
        ground_truth_folder=ground_truth_folder,
        ground_truth_files=ground_truth_files,
    )

    # Evaluate the models against the ground truth and build the comparison visualization.
    visualization = evaluate_models(
        model_results=model_results,
        ground_truth_df=ground_truth_df,
    )

    # Optionally persist the visualization data locally.
    if save_visualization_data:
        save_visualization(
            visualization,
            output_dir=visualization_output_dir,
        )
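
# Example (illustrative only; the model names and paths below are hypothetical
# placeholders) of invoking the pipeline above directly rather than through the
# config-driven runner that follows:
#
#     ocr_evaluation_pipeline(
#         model_names=["model-a", "model-b"],
#         results_dir="ocr_results",
#         ground_truth_folder="ground_truth",
#     )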


def run_ocr_evaluation_pipeline(config: Dict[str, Any]) -> None:
    """Run the OCR evaluation pipeline from a configuration dictionary.

    Args:
        config: Configuration dictionary containing "parameters" (including
            "mode" and "selected_models"), "models_registry", and per-step
            settings under "steps".

    Returns:
        None
    """
    mode = config.get("parameters", {}).get("mode", "evaluation")
    if mode != "evaluation":
        logger.warning(f"Expected mode 'evaluation', but got '{mode}'. Proceeding anyway.")

    selected_models = config.get("parameters", {}).get("selected_models", [])
    if len(selected_models) < 2:
        raise ValueError("At least two models are required for evaluation")

    model_registry = config.get("models_registry", [])
    if not model_registry:
        raise ValueError("models_registry is required in the config")

    # Resolve the selected models (which may be shorthands or full names) to the
    # model names registered in models_registry.
    model_names = []
    shorthand_to_name = {
        m.get("shorthand"): m.get("name") for m in model_registry if "shorthand" in m
    }

    for model_id in selected_models:
        if model_id in shorthand_to_name:
            model_names.append(shorthand_to_name[model_id])
        elif any(m.get("name") == model_id for m in model_registry):
            model_names.append(model_id)
        else:
            logger.warning(f"Model '{model_id}' not found in registry, using as-is")
            model_names.append(model_id)

    # Set up pipeline options
    pipeline_instance = ocr_evaluation_pipeline.with_options(
        enable_cache=config.get("enable_cache", False),
        enable_artifact_metadata=config.get("enable_artifact_metadata", True),
        enable_artifact_visualization=config.get("enable_artifact_visualization", True),
    )

    evaluate_models_params = (
        config.get("steps", {}).get("evaluate_models", {}).get("parameters", {})
    )
    save_visualization_params = (
        config.get("steps", {}).get("save_visualization", {}).get("parameters", {})
    )

    pipeline_instance(
        model_names=model_names,
        results_dir=evaluate_models_params.get("results_dir", "ocr_results"),
        result_files=evaluate_models_params.get("result_files"),
        ground_truth_folder=evaluate_models_params.get("ground_truth_folder"),
        ground_truth_files=evaluate_models_params.get("ground_truth_files", []),
        save_visualization_data=save_visualization_params.get("save_locally", False),
        visualization_output_dir=save_visualization_params.get("output_dir", "visualizations"),
    )
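

# Minimal sketch of running this module directly. The config below only
# illustrates the keys read by run_ocr_evaluation_pipeline; the model names,
# shorthands, and paths are hypothetical placeholders, not shipped defaults.
if __name__ == "__main__":
    example_config = {
        "enable_cache": False,
        "parameters": {
            "mode": "evaluation",
            "selected_models": ["model-a", "model-b"],
        },
        "models_registry": [
            {"name": "model-a", "shorthand": "a"},
            {"name": "model-b", "shorthand": "b"},
        ],
        "steps": {
            "evaluate_models": {
                "parameters": {
                    "results_dir": "ocr_results",
                    "ground_truth_folder": "ground_truth",
                }
            },
            "save_visualization": {
                "parameters": {
                    "save_locally": True,
                    "output_dir": "visualizations",
                }
            },
        },
    }
    run_ocr_evaluation_pipeline(example_config)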