diff --git a/.typos.toml b/.typos.toml index 2fa346b5..abd36cef 100644 --- a/.typos.toml +++ b/.typos.toml @@ -48,6 +48,8 @@ huggingface = "huggingface" answerdotai = "answerdotai" preprocessor = "preprocessor" logits = "logits" +analyse = "analyse" +Labour = "Labour" [default] -locale = "en-us" \ No newline at end of file +locale = "en-us" diff --git a/omni-reader/.dockerignore b/omni-reader/.dockerignore new file mode 100644 index 00000000..615ec5e0 --- /dev/null +++ b/omni-reader/.dockerignore @@ -0,0 +1,5 @@ +.venv* +.requirements* +__pycache__/ +*.py[cod] +*$py.class diff --git a/omni-reader/.env.example b/omni-reader/.env.example new file mode 100644 index 00000000..287559b1 --- /dev/null +++ b/omni-reader/.env.example @@ -0,0 +1,3 @@ +OPENAI_API_KEY=your_openai_api_key +MISTRAL_API_KEY=your_mistral_api_key +OLLAMA_HOST=base_url_for_ollama_host # defaults to "http://localhost:11434/api/generate" if not set diff --git a/omni-reader/.gitignore b/omni-reader/.gitignore new file mode 100644 index 00000000..8959be6f --- /dev/null +++ b/omni-reader/.gitignore @@ -0,0 +1,104 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual Environment +venv/ +env/ +.env +.venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +.project +.pydevproject +.settings/ +*.sublime-workspace +*.sublime-project + +# Jupyter Notebook +.ipynb_checkpoints +*.ipynb + +# Model checkpoints and evaluation results +eval_results/ +runs/ +wandb/ +*.pt +*.pth +*.ckpt +*.bin +model_outputs/ + +# Data +*.csv +*.tsv +!requirements.txt + +# Logs +*.log +logs/ +tensorboard_logs/ +lightning_logs/ + +# OS +.DS_Store +Thumbs.db +*.db +*.sqlite +*.sqlite3 + +# Environment variables +.env +.env.local +.env.*.local + +# Testing +.coverage +htmlcov/ +.pytest_cache/ +.tox/ +coverage.xml +*.cover +.hypothesis/ + +# Documentation +docs/_build/ +site/ + +# Misc +*.bak +*.tmp +*.temp +.cache +.zen/ + +# Extras +checkpoints/ +ruff.toml +models/ +cache/ +zenml_docs/ +CLAUDE.md diff --git a/omni-reader/LICENSE b/omni-reader/LICENSE new file mode 100644 index 00000000..8c0f9c55 --- /dev/null +++ b/omni-reader/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright 2025 ZenML GmbH + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/omni-reader/README.md b/omni-reader/README.md new file mode 100644 index 00000000..a9233ac0 --- /dev/null +++ b/omni-reader/README.md @@ -0,0 +1,304 @@ +# OmniReader + +A scalable multi-model text extraction solution for unstructured documents. + +
+![Pipeline DAG](assets/docs/pipeline_dags.png)
+ +✨ **Extract Structured Text from Any Document** +OmniReader is built for teams who routinely work with unstructured documents (e.g., PDFs, images, scanned forms) and want a scalable workflow for structured text extraction. It provides an end-to-end batch OCR pipeline with optional multi-model comparison to help ML engineers evaluate different OCR solutions before deployment. + +
+![HTML Visualization of OCR Results](assets/docs/visualization.png)
+*HTML visualization showing metrics and comparison results from the OCR pipeline*
+ +## 🔮 Use Cases + +- **Document Processing Automation**: Extract structured data from invoices, receipts, and forms +- **Content Digitization**: Convert scanned documents and books into searchable digital content +- **Regulatory Compliance**: Extract and validate information from compliance documents +- **Data Migration**: Convert legacy paper documents into structured digital formats +- **Research & Analysis**: Extract data from academic papers, reports, and publications + +## 🌟 Key Features + +- **End-to-end workflow management** from evaluation to production deployment +- **Multi-model comparison** to identify the best model for your specific document types +- **Scalable batch processing** that can handle enterprise document volumes +- **Quantitative evaluation metrics** to inform business and technical decisions +- **ZenML integration** providing reproducibility, cloud-agnostic deployment, and monitoring + +## 🎭 How It Works + +OmniReader provides two primary pipeline workflows that can be run separately: + +1. **Batch OCR Pipeline**: Run large batches of documents through a single model to extract structured text and metadata. +2. **Evaluation Pipeline**: Compare multiple OCR models side-by-side and generate evaluation reports using CER/WER and HTML visualizations against ground truth text files. + +Behind the scenes, OmniReader leverages state-of-the-art vision-language models and ZenML's MLOps framework to create a reproducible, scalable document processing system. + +## 📚 Supported Models + +OmniReader supports a wide range of OCR models, including: + +- **Mistral/pixtral-12b-2409**: Mistral AI's vision-language model specializing in document understanding with strong OCR capabilities for complex layouts. +- **GPT-4o-mini**: OpenAI's efficient vision model offering a good balance of accuracy and speed for general document processing tasks. +- **Gemma3:27b**: Google's open-source multimodal model supporting 140+ languages with a 128K context window, optimized for text extraction from diverse document types. +- **Llava:34b**: Large multilingual vision-language model with strong performance on document understanding tasks requiring contextual interpretation. +- **Llava-phi3**: Microsoft's efficient multimodal model combining phi-3 language capabilities with vision understanding, ideal for mixed text-image documents. +- **Granite3.2-vision**: Specialized for visual document understanding, offering excellent performance on tables, charts, and technical diagrams. + +> ⚠️ Note: For production deployments, we recommend using the non-GGUF hosted model versions via their respective APIs for better performance and accuracy. The Ollama models mentioned here are primarily for convenience. + +### 🔧 OCR Processor Configuration + +OmniReader supports multiple OCR processors to handle different models: + +1. **litellm**: For using LiteLLM-compatible models including those from Mistral and other providers. + +- Set API keys for your providers (e.g., `MISTRAL_API_KEY`) +- **Important**: When using `litellm` as the processor, you must specify the `provider` field in your model configuration. + +2. **ollama**: For running local models through Ollama. + + - Requires: [Ollama](https://ollama.com/) installed and running + - Set `OLLAMA_HOST` (defaults to "http://localhost:11434/api/generate") + - If using local models, they must be pulled before use with `ollama pull model_name` + +3. **openai**: For using OpenAI models like GPT-4o. 
+ - Set `OPENAI_API_KEY` environment variable + +Example model configurations in your `configs/batch_pipeline.yaml`: + +```yaml +models_registry: + - name: "gpt-4o-mini" + shorthand: "gpt4o" + ocr_processor: "openai" + # No provider needed for OpenAI + + - name: "gemma3:27b" + shorthand: "gemma3" + ocr_processor: "ollama" + # No provider needed for Ollama + + - name: "mistral/pixtral-12b-2409" + shorthand: "pixtral" + ocr_processor: "litellm" + provider: "mistral" # Provider field required for litellm processor +``` + +To add your own models, extend the `models_registry` with the appropriate processor and provider configurations based on the model source. + +## 🛠️ Project Structure + +``` +omni-reader/ +│ +├── app.py # Streamlit UI for interactive document processing +├── assets/ # Sample images for ocr +├── configs/ # YAML configuration files +├── ground_truth_texts/ # Text files containing ground truth for evaluation +├── pipelines/ # ZenML pipeline definitions +│ ├── batch_pipeline.py # Batch OCR pipeline (single or multiple models) +│ └── evaluation_pipeline.py # Evaluation pipeline (multiple models) +├── steps/ # Pipeline step implementations +│ ├── evaluate_models.py # Model comparison and metrics +│ ├── loaders.py # Loading images and ground truth texts +│ └── run_ocr.py # Running OCR with selected models +├── utils/ # Utility functions and helpers +│ ├── ocr_processing.py # OCR processing core logic +│ ├── metrics.py # Metrics for evaluation +│ ├── visualizations.py # Visualization utilities for the evaluation pipeline +│ ├── encode_image.py # Image encoding utilities for OCR processing +│ ├── prompt.py # Prompt template for vision models +│ ├── config.py # Utilities for loading and validating configs +│ └── model_configs.py # Model configuration and registry +├── run.py # Main entrypoint for running the pipeline +└── README.md # Project documentation +``` + +## 🚀 Getting Started + +### Prerequisites + +- Python 3.9+ +- Mistral API key (set as environment variable `MISTRAL_API_KEY`) +- OpenAI API key (set as environment variable `OPENAI_API_KEY`) +- ZenML >= 0.80.0 +- Ollama (required for running local models) + +### Quick Start + +```bash +# Install dependencies +pip install -r requirements.txt + +# Start Ollama (if using local models) +ollama serve +``` + +### Set Up Your Environment + +Configure your API keys: + +```bash +export OPENAI_API_KEY=your_openai_api_key +export MISTRAL_API_KEY=your_mistral_api_key +export OLLAMA_HOST=base_url_for_ollama_host # defaults to "http://localhost:11434/api/generate" if not set +``` + +### Run OmniReader + +```bash +# Run the batch pipeline (default) +python run.py + +# Run the evaluation pipeline +python run.py --eval + +# Run with a custom config file +python run.py --config my_custom_config.yaml + +# Run with custom input +python run.py --image-folder ./my_images + +# List ground truth files +python run.py --list-ground-truth-files +``` + +### Interactive UI + +The project also includes a Streamlit app that allows you to: + +- Upload documents for instant OCR processing +- Compare results from multiple models side-by-side +- Experiment with custom prompts to improve extraction quality + +```bash +# Launch the Streamlit interface +streamlit run app.py +``` + +
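+Both the Streamlit app and `run.py` call the shared `run_ocr` helper in `utils/ocr_processing.py`. If you want to script a one-off extraction without the UI or ZenML tracking, here is a minimal sketch (it assumes the relevant API keys are exported as shown above):
+
+```python
+from PIL import Image
+
+from utils.ocr_processing import run_ocr
+
+# Load one of the bundled sample images
+image = Image.open("assets/samples_for_ocr/easy_example.jpeg")
+
+# Pass a single model name, or a list of names/shorthands to compare several models
+result = run_ocr(
+    image_input=image,
+    model_ids="gpt-4o-mini",
+    custom_prompt=None,       # optional custom prompt string
+    track_metadata=False,     # skip ZenML artifact tracking for ad-hoc runs
+)
+
+print(result["raw_text"])
+```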
+![Model Comparison Results](assets/docs/streamlit.png)
+*Side-by-side comparison of OCR results across different models*
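+The evaluation pipeline scores each model against the files in `ground_truth_texts/` using character error rate (CER) and word error rate (WER). If you want to sanity-check a single result outside the pipeline, the same metrics can be computed with the `jiwer` package from `requirements.txt`; a minimal sketch, using a hypothetical OCR output for illustration:
+
+```python
+import jiwer
+
+# Ground truth transcription and a model's OCR output, both plain strings
+with open("ground_truth_texts/lexus_vin_number.txt") as f:
+    reference = f.read()
+hypothesis = "JTHBH5D2405012812"  # hypothetical model output for this example
+
+print(f"WER: {jiwer.wer(reference, hypothesis):.2%}")  # word error rate
+print(f"CER: {jiwer.cer(reference, hypothesis):.2%}")  # character error rate
+```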
+ +## ☁️ Cloud Deployment + +OmniReader supports storing artifacts remotely and executing pipelines on cloud infrastructure. For this example, we'll use AWS, but you can use any cloud provider you want. You can also refer to the [AWS Integration Guide](https://docs.zenml.io/how-to/popular-integrations/aws-guide) for detailed instructions. + +### AWS Setup + +1. **Install required integrations**: + + ```bash + zenml integration install aws s3 + ``` + +2. **Set up your AWS credentials**: + + - Create an IAM role with appropriate permissions (S3, ECR, SageMaker) + - Configure your role ARN and region + +3. **Register an AWS service connector**: + + ```bash + zenml service-connector register aws_connector \ + --type aws \ + --auth-method iam-role \ + --role_arn= \ + --region= \ + --aws_access_key_id= \ + --aws_secret_access_key= + ``` + +4. **Configure stack components**: + + a. **S3 Artifact Store**: + + ```bash + zenml artifact-store register s3_artifact_store \ + -f s3 \ + --path=s3:// \ + --connector aws_connector + ``` + + b. **SageMaker Orchestrator**: + + ```bash + zenml orchestrator register sagemaker_orchestrator \ + --flavor=sagemaker \ + --region= \ + --execution_role= + ``` + + c. **ECR Container Registry**: + + ```bash + zenml container-registry register ecr_registry \ + --flavor=aws \ + --uri=.dkr.ecr..amazonaws.com \ + --connector aws_connector + ``` + +5. **Register and activate your stack**: + ```bash + zenml stack register aws_stack \ + -a s3_artifact_store \ + -o sagemaker_orchestrator \ + -c ecr_registry \ + --set + ``` + +### Other Cloud Providers + +Similar setup processes can be followed for other cloud providers: + +- **Azure**: Install the Azure integration (`zenml integration install azure`) and set up Azure Blob Storage, AzureML, and Azure Container Registry +- **Google Cloud**: Install the GCP integration (`zenml integration install gcp gcs`) and set up GCS, Vertex AI, and GCR +- **Kubernetes**: Install the Kubernetes integration (`zenml integration install kubernetes`) and set up a Kubernetes cluster + +For detailed configuration options for these providers, refer to the ZenML documentation: + +- [GCP Integration Guide](https://docs.zenml.io/how-to/popular-integrations/gcp-guide) +- [Azure Integration Guide](https://docs.zenml.io/how-to/popular-integrations/azure-guide) +- [Kubernetes Integration Guide](https://docs.zenml.io/how-to/popular-integrations/kubernetes) + +### 🐳 Docker Settings for Cloud Deployment + +For cloud execution, you'll need to configure Docker settings in your pipeline: + +```python +from zenml.config import DockerSettings + +# Create Docker settings +docker_settings = DockerSettings( + required_integrations=["aws", "s3"], # Based on your cloud provider + requirements="requirements.txt", + python_package_installer="uv", # Optional, defaults to "pip" + environment={ + "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"), + "MISTRAL_API_KEY": os.getenv("MISTRAL_API_KEY"), + }, +) + +# Use in your pipeline definition +@pipeline(settings={"docker": docker_settings}) +def batch_ocr_pipeline(...): + ... +``` + +## 📚 Documentation + +For more information about ZenML and building MLOps pipelines, refer to the [ZenML documentation](https://docs.zenml.io/). 
+ +For model-specific documentation: + +- [Mistral AI Vision Documentation](https://docs.mistral.ai/capabilities/vision/) +- [LiteLLM Providers Documentation](https://docs.litellm.ai/docs/providers) +- [Gemma3 Documentation](https://ai.google.dev/gemma/docs/integrations/ollama) +- [Ollama Models Library](https://ollama.com/library) diff --git a/omni-reader/app.py b/omni-reader/app.py new file mode 100644 index 00000000..61a99cf7 --- /dev/null +++ b/omni-reader/app.py @@ -0,0 +1,562 @@ +"""This is the Streamlit UI for the OCR Extraction workflow.""" + +import base64 +import os +import time + +import streamlit as st +from PIL import Image + +from utils.model_configs import DEFAULT_MODEL, MODEL_CONFIGS +from utils.ocr_processing import run_ocr + + +def setup_page_config(): + """Configure Streamlit page settings.""" + st.set_page_config( + page_title="OCR Extraction", + page_icon="🔎", + layout="wide", + initial_sidebar_state="expanded", + ) + + +def load_model_logos(): + """Load model logos from assets folder.""" + logo_mapping = {} + logos_dir = "./assets/logos" + processed_displays = set() + + for _, model_config in MODEL_CONFIGS.items(): + if model_config.display in processed_displays: + continue + + processed_displays.add(model_config.display) + logo_filename = model_config.logo + + if os.path.exists(os.path.join(logos_dir, logo_filename)): + try: + logo_mapping[model_config.display] = base64.b64encode( + open(os.path.join(logos_dir, logo_filename), "rb").read() + ).decode() + except Exception as e: + print(f"Error loading logo for {model_config.display}: {e}") + provider = model_config.shorthand + if os.path.exists(os.path.join(logos_dir, f"{provider}.svg")): + logo_mapping[model_config.display] = base64.b64encode( + open(os.path.join(logos_dir, f"{provider}.svg"), "rb").read() + ).decode() + + return logo_mapping + + +def render_header(model_logos=None): + """Render the page header based on selected models.""" + if st.session_state.get("comparison_mode", False): + selected_models = st.session_state.get("comparison_models", [DEFAULT_MODEL.name]) + + if len(selected_models) > 1: + st.title(f"OCR Model Comparison ({len(selected_models)} Models)") + else: + model_display = MODEL_CONFIGS[selected_models[0]].display + st.title(f"{model_display} OCR") + else: + selected_model_id = st.session_state.get("selected_model_id", DEFAULT_MODEL.name) + selected_model_display = MODEL_CONFIGS[selected_model_id].display + + if model_logos and selected_model_display in model_logos: + logo_html = f'# {selected_model_display} OCR' + st.markdown(logo_html, unsafe_allow_html=True) + else: + st.title(f"{selected_model_display} OCR") + + +def render_sidebar(): + """Render the sidebar with model selection and settings.""" + with st.sidebar: + st.header("Settings") + + comparison_mode = st.radio( + "Mode", + ["Single Model", "Compare Models"], + horizontal=True, + ) + + # Get unique models (avoid duplicates from shorthand keys) + unique_models = { + model_id: model_config + for model_id, model_config in MODEL_CONFIGS.items() + if model_id == model_config.name + } + + if comparison_mode == "Single Model": + model_options = [ + model_config.display for model_id, model_config in unique_models.items() + ] + + # Find default model index + default_index = 0 + for i, option in enumerate(model_options): + if option == DEFAULT_MODEL.display: + default_index = i + break + + model_choice = st.selectbox( + "Select OCR Model", + options=model_options, + index=default_index, + key="model_selector_single", + ) + + # Map display 
name back to model ID + selected_model_id = None + for model_id, model_config in unique_models.items(): + if model_config.display == model_choice: + selected_model_id = model_id + break + + st.session_state["comparison_mode"] = False + st.session_state["selected_model_id"] = selected_model_id + else: + # Multiselect for model comparison + st.subheader("Models to Compare") + model_display_to_id = { + model_config.display: model_id for model_id, model_config in unique_models.items() + } + default_models = [DEFAULT_MODEL.display] + previously_selected = st.session_state.get("selected_model_displays", default_models) + multiselect_key = f"model_multiselect_{len(previously_selected)}" + + selected_model_displays = st.multiselect( + "Select models to compare", + options=list(model_display_to_id.keys()), + default=previously_selected, + key=multiselect_key, + ) + + if not selected_model_displays: + st.warning("At least one model must be selected. Using default model.") + selected_model_displays = default_models + + if selected_model_displays != previously_selected: + st.session_state["selected_model_displays"] = selected_model_displays + st.rerun() + + comparison_models = [ + model_display_to_id[display] for display in selected_model_displays + ] + + st.session_state["comparison_mode"] = True + st.session_state["comparison_models"] = comparison_models + + st.header("Custom Prompt") + use_custom_prompt = st.checkbox("Use custom prompt", value=False) + custom_prompt = ( + st.text_area( + "Enter your custom prompt", + "Extract and analyze all text from the image and identify any objects present.", + height=100, + ) + if use_custom_prompt + else "" + ) + + st.header("Upload Image") + uploaded_file = st.file_uploader( + "Choose an image...", + type=["png", "jpg", "jpeg", "webp", "tiff"], + key=st.session_state.get("file_uploader_key", "default_uploader"), + ) + + # Store uploaded file in session state + if uploaded_file is not None: + st.session_state["uploaded_file"] = uploaded_file + + st.markdown("---") + display_provider_info() + + return st.session_state.get( + "uploaded_file", uploaded_file + ), custom_prompt if use_custom_prompt else None + + +def display_provider_info(): + """Display information about which providers are being used.""" + providers = set() + for model_id, model_config in MODEL_CONFIGS.items(): + if model_id == model_config.name: # Skip shorthand duplicates + providers.add(model_config.ocr_processor) + + provider_names = {"ollama": "Ollama", "openai": "OpenAI", "mistral": "MistralAI"} + provider_text = " + ".join([provider_names.get(p, p.capitalize()) for p in providers]) + st.caption(f"Powered by {provider_text}") + + +def add_clear_button(): + """Add a clear button to reset the app state.""" + col1, col2 = st.columns([6, 1]) + with col2: + if st.button("Clear 🗑️"): + keys_to_clear = ["ocr_result", "uploaded_file", "processed_image"] + for model_id in MODEL_CONFIGS: + result_key = f"{model_id}_result" + keys_to_clear.append(result_key) + + for key in keys_to_clear: + if key in st.session_state: + del st.session_state[key] + + # Force file uploader to reset + st.session_state["file_uploader_key"] = str(time.time()) + st.rerun() + + +def process_uploaded_image(image): + """Process and display the uploaded image.""" + img_width, img_height = image.size + max_width, max_height = 600, 500 + + # Calculate scaling factor to maintain aspect ratio + width_ratio = max_width / img_width + height_ratio = max_height / img_height + scale_factor = min(width_ratio, height_ratio) + + # Only 
resize if needed + if img_width > max_width or img_height > max_height: + new_width = int(img_width * scale_factor) + new_height = int(img_height * scale_factor) + image = image.resize((new_width, new_height)) + + # Store the processed image in session state + st.session_state["processed_image"] = image + st.image(image, caption="Uploaded Image", use_container_width=False) + return image + + +def check_for_error(result): + """Check if the OCR model has an error.""" + if "error" in result: + return True + + raw_text = result.get("raw_text", "") + if not isinstance(raw_text, str): + raw_text = str(raw_text) + + return raw_text.startswith("Error:") or ("success" in result and result["success"] is False) + + +def has_no_text(result): + """Check if the OCR model has no text.""" + raw_text = result.get("raw_text", "") + if not isinstance(raw_text, str): + raw_text = str(raw_text) + return raw_text == "No text found" + + +def display_result(label, result, proc_time, model_id, model_logos): + """Display the results of the OCR model.""" + model_config = MODEL_CONFIGS[model_id] + model_display = model_config.display + + # Display header with logo if available + if model_display in model_logos: + header_html = f'
<div style="display: flex; align-items: center;"><img src="data:image/svg+xml;base64,{model_logos[model_display]}" width="28" style="margin-right: 10px;"/><h3>{label} Results</h3></div>
' + st.markdown(header_html, unsafe_allow_html=True) + else: + st.subheader(f"{label} Results") + + st.text(f"Processing time: {proc_time:.2f}s") + st.markdown("##### Extracted Text") + + text = result.get("raw_text", "") + if not isinstance(text, str): + text = str(text) + + if text.startswith("Error:"): + st.error(text) + elif text == "No text found": + st.warning("No text found in the image") + else: + st.write(text) + + +def run_single_model(image, model_id, custom_prompt): + """Run OCR with a single model and display results.""" + selected_model_display = MODEL_CONFIGS[model_id].display + model_logos = load_model_logos() + + with st.spinner(f"Processing image with {selected_model_display}..."): + try: + start = time.time() + + # Use the unified run_ocr function + result = run_ocr( + image_input=image, + model_ids=model_id, + custom_prompt=custom_prompt, + track_metadata=False, + ) + + proc_time = time.time() - start + + # Store result in session state + st.session_state[f"{model_id}_result"] = result + + # Display results + display_result(selected_model_display, result, proc_time, model_id, model_logos) + + # Show additional stats + st.subheader("Processing Stats") + st.text(f"Processing time: {proc_time:.2f}s") + st.text(f"Text length: {len(result['raw_text'])} characters") + + if "confidence" in result and result["confidence"] is not None: + st.text(f"Confidence: {result['confidence']:.2%}") + + st.text(f"Provider: {MODEL_CONFIGS[model_id].ocr_processor.capitalize()}") + + return result, proc_time + + except Exception as e: + st.error(f"Error processing image: {e}") + return None, 0 + + +def display_comparison_stats(model_ids, model_results, model_times, model_errors): + """Display comparison statistics for multiple models.""" + st.markdown("### Comparison Stats") + + # Determine fastest model + valid_times = [] + for model_id in model_ids: + if not model_errors[model_id]: + valid_times.append((MODEL_CONFIGS[model_id].display, model_times[model_id])) + + if valid_times: + fastest = min(valid_times, key=lambda x: x[1]) + st.write(f"🚀 Fastest model: **{fastest[0]}** ({fastest[1]:.2f}s)") + + # Show error status for each model + error_status = [] + for model_id in model_ids: + if model_errors[model_id]: + error_status.append(f"{MODEL_CONFIGS[model_id].display} failed") + + if error_status: + st.write("⚠️ " + ", ".join(error_status)) + + # Compare text lengths + text_lengths = [] + for model_id in model_ids: + if not model_errors[model_id] and not has_no_text(model_results[model_id]): + text_lengths.append( + f"{MODEL_CONFIGS[model_id].display}: {len(model_results[model_id]['raw_text'])} chars" + ) + + if text_lengths: + st.write("📝 Text lengths: " + ", ".join(text_lengths)) + else: + st.warning("No usable text extracted by any model") + + +def run_multiple_models(image, model_ids, custom_prompt): + """Run OCR with multiple models and display comparison results in parallel.""" + try: + model_logos = load_model_logos() + num_models = len(model_ids) + + # Create responsive layout + if num_models <= 2: + cols = st.columns(num_models) + col_rows = None + else: + cols = None + col_rows = [] + for i in range(0, num_models, 3): + remaining = min(3, num_models - i) + col_rows.append(st.columns(remaining)) + + # Create placeholders with headers and spinners + placeholders = {} + for i, model_id in enumerate(model_ids): + model_config = MODEL_CONFIGS[model_id] + model_name = model_config.display + + # Get appropriate column + if num_models <= 2: + column = cols[i] + else: + row_idx = i // 3 + 
col_idx = i % 3 + column = col_rows[row_idx][col_idx] + + with column: + # Add model header with logo + if model_name in model_logos: + header_html = f'
<div style="display: flex; align-items: center;"><img src="data:image/svg+xml;base64,{model_logos[model_name]}" width="28" style="margin-right: 10px;"/><h3>{model_name}</h3></div>
' + st.markdown(header_html, unsafe_allow_html=True) + else: + st.subheader(f"{model_name}") + + # Create spinner that will show during processing + spinner_placeholder = st.empty() + with spinner_placeholder: + st.info("Waiting to process...") + + # Create result area placeholder + result_area = st.empty() + + placeholders[model_id] = { + "column": column, + "spinner": spinner_placeholder, + "result_area": result_area, + } + + # Initialize result tracking + model_results = {} + model_times = {} + model_errors = {} + processed_count = 0 + + # Create a placeholder for comparison stats + stats_container = st.container() + + # Define callback function to update UI as each model completes + def process_callback(model_id, result): + nonlocal processed_count + proc_time = result.get("processing_time", 0) + + # Store results + model_results[model_id] = result + model_times[model_id] = proc_time + model_errors[model_id] = check_for_error(result) + + # Store in session state + st.session_state[f"{model_id}_result"] = result + + # Clear the spinner + placeholders[model_id]["spinner"].empty() + + # Update the result in the UI (without the header which is already shown) + with placeholders[model_id]["result_area"].container(): + text = result.get("raw_text", "") + if not isinstance(text, str): + text = str(text) + + st.text(f"Processing time: {proc_time:.2f}s") + st.markdown("##### Extracted Text") + + if text.startswith("Error:"): + st.error(text) + elif text == "No text found": + st.warning("No text found in the image") + else: + st.write(text) + + processed_count += 1 + + # When all models are done, show comparison stats + if processed_count == len(model_ids) and len(model_ids) > 1: + with stats_container: + display_comparison_stats(model_ids, model_results, model_times, model_errors) + + # Process each model in a separate thread + from concurrent.futures import ThreadPoolExecutor + + def process_model(model_id): + try: + # Don't update UI from the thread - just log it + start_time = time.time() + result = run_ocr( + image_input=image, + model_ids=model_id, + custom_prompt=custom_prompt, + track_metadata=False, + ) + # Ensure processing time is captured + if "processing_time" not in result: + result["processing_time"] = time.time() - start_time + return model_id, result + except Exception as e: + error_result = { + "raw_text": f"Error: {str(e)}", + "error": str(e), + "processing_time": 0, + "model": model_id, + } + return model_id, error_result + + # Use max_workers based on number of models + max_workers = min(len(model_ids), 5) + + # Create a dict to track which models are currently processing + processing_models = {model_id: False for model_id in model_ids} + + # Start processing in parallel + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Update all spinners to "Processing..." 
before submitting tasks + for model_id in model_ids: + with placeholders[model_id]["spinner"]: + st.info(f"Processing with {MODEL_CONFIGS[model_id].display}...") + processing_models[model_id] = True + + # Submit tasks + futures = {executor.submit(process_model, model_id): model_id for model_id in model_ids} + + # Process results as they complete + import concurrent.futures + + for future in concurrent.futures.as_completed(futures): + model_id, result = future.result() + process_callback(model_id, result) + + return model_results, model_times + + except Exception as e: + st.error(f"Error processing images: {e}") + return None, None + + +def main(): + """Main function to run the Streamlit app.""" + setup_page_config() + uploaded_file, custom_prompt = render_sidebar() + model_logos = load_model_logos() + render_header(model_logos) + add_clear_button() + + st.markdown("---") + st.markdown( + '
<div style="text-align: center;">Extract structured text from images using your chosen OCR model!</div>
', + unsafe_allow_html=True, + ) + + processed_image = None + + if uploaded_file is not None: + image = Image.open(uploaded_file) + processed_image = process_uploaded_image(image) + elif "processed_image" in st.session_state: + processed_image = st.session_state["processed_image"] + st.image(processed_image, caption="Uploaded Image", use_container_width=False) + + if processed_image is not None: + if st.button("Extract Text 🔍", type="primary"): + comparison_mode = st.session_state.get("comparison_mode", False) + + if comparison_mode: + comparison_models = st.session_state.get("comparison_models", [DEFAULT_MODEL.name]) + run_multiple_models(processed_image, comparison_models, custom_prompt) + else: + selected_model_id = st.session_state.get("selected_model_id", DEFAULT_MODEL.name) + run_single_model(processed_image, selected_model_id, custom_prompt) + else: + st.info("Upload an image and click 'Extract Text' to see the results here.") + + # Footer + st.markdown("---") + st.caption("ZenOCR - Comparing LLM OCR capabilities") + + +if __name__ == "__main__": + main() diff --git a/omni-reader/assets/docs/metrics.png b/omni-reader/assets/docs/metrics.png new file mode 100644 index 00000000..7f59b482 Binary files /dev/null and b/omni-reader/assets/docs/metrics.png differ diff --git a/omni-reader/assets/docs/pipeline_dags.png b/omni-reader/assets/docs/pipeline_dags.png new file mode 100644 index 00000000..90ca4cbc Binary files /dev/null and b/omni-reader/assets/docs/pipeline_dags.png differ diff --git a/omni-reader/assets/docs/streamlit.png b/omni-reader/assets/docs/streamlit.png new file mode 100644 index 00000000..ce18d03b Binary files /dev/null and b/omni-reader/assets/docs/streamlit.png differ diff --git a/omni-reader/assets/docs/visualization.png b/omni-reader/assets/docs/visualization.png new file mode 100644 index 00000000..d0390bfc Binary files /dev/null and b/omni-reader/assets/docs/visualization.png differ diff --git a/omni-reader/assets/logos/default.svg b/omni-reader/assets/logos/default.svg new file mode 100644 index 00000000..b709fab6 --- /dev/null +++ b/omni-reader/assets/logos/default.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/omni-reader/assets/logos/gemma.svg b/omni-reader/assets/logos/gemma.svg new file mode 100644 index 00000000..116ec367 --- /dev/null +++ b/omni-reader/assets/logos/gemma.svg @@ -0,0 +1 @@ +Gemma \ No newline at end of file diff --git a/omni-reader/assets/logos/microsoft.svg b/omni-reader/assets/logos/microsoft.svg new file mode 100644 index 00000000..5334aa7c --- /dev/null +++ b/omni-reader/assets/logos/microsoft.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/omni-reader/assets/logos/mistral.svg b/omni-reader/assets/logos/mistral.svg new file mode 100644 index 00000000..4e97da2a --- /dev/null +++ b/omni-reader/assets/logos/mistral.svg @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/omni-reader/assets/logos/ollama.svg b/omni-reader/assets/logos/ollama.svg new file mode 100644 index 00000000..cc887e3d --- /dev/null +++ b/omni-reader/assets/logos/ollama.svg @@ -0,0 +1 @@ +Ollama \ No newline at end of file diff --git a/omni-reader/assets/logos/openai.svg b/omni-reader/assets/logos/openai.svg new file mode 100644 index 00000000..3b4eff96 --- /dev/null +++ b/omni-reader/assets/logos/openai.svg @@ -0,0 +1,2 @@ + +OpenAI icon \ No newline at end of file diff --git a/omni-reader/assets/samples_for_ocr/easy_example.jpeg b/omni-reader/assets/samples_for_ocr/easy_example.jpeg new file mode 
100644 index 00000000..7a158f02 Binary files /dev/null and b/omni-reader/assets/samples_for_ocr/easy_example.jpeg differ diff --git a/omni-reader/assets/samples_for_ocr/education_article_excerpt.webp b/omni-reader/assets/samples_for_ocr/education_article_excerpt.webp new file mode 100644 index 00000000..f5e8b3eb Binary files /dev/null and b/omni-reader/assets/samples_for_ocr/education_article_excerpt.webp differ diff --git a/omni-reader/assets/samples_for_ocr/incomplete_sentence.png b/omni-reader/assets/samples_for_ocr/incomplete_sentence.png new file mode 100644 index 00000000..6510a1f4 Binary files /dev/null and b/omni-reader/assets/samples_for_ocr/incomplete_sentence.png differ diff --git a/omni-reader/assets/samples_for_ocr/lexus_vin_number.webp b/omni-reader/assets/samples_for_ocr/lexus_vin_number.webp new file mode 100644 index 00000000..cce02b1f Binary files /dev/null and b/omni-reader/assets/samples_for_ocr/lexus_vin_number.webp differ diff --git a/omni-reader/assets/samples_for_ocr/montreal_signs.jpg b/omni-reader/assets/samples_for_ocr/montreal_signs.jpg new file mode 100644 index 00000000..24f014d6 Binary files /dev/null and b/omni-reader/assets/samples_for_ocr/montreal_signs.jpg differ diff --git a/omni-reader/assets/samples_for_ocr/paris_signs.jpg b/omni-reader/assets/samples_for_ocr/paris_signs.jpg new file mode 100644 index 00000000..da8936ae Binary files /dev/null and b/omni-reader/assets/samples_for_ocr/paris_signs.jpg differ diff --git a/omni-reader/assets/samples_for_ocr/reporter_notes.png b/omni-reader/assets/samples_for_ocr/reporter_notes.png new file mode 100644 index 00000000..27d320b1 Binary files /dev/null and b/omni-reader/assets/samples_for_ocr/reporter_notes.png differ diff --git a/omni-reader/assets/samples_for_ocr/rx_prescription_clear.jpg b/omni-reader/assets/samples_for_ocr/rx_prescription_clear.jpg new file mode 100644 index 00000000..b7e8c077 Binary files /dev/null and b/omni-reader/assets/samples_for_ocr/rx_prescription_clear.jpg differ diff --git a/omni-reader/assets/samples_for_ocr/rx_prescription_unclear.png b/omni-reader/assets/samples_for_ocr/rx_prescription_unclear.png new file mode 100644 index 00000000..75440102 Binary files /dev/null and b/omni-reader/assets/samples_for_ocr/rx_prescription_unclear.png differ diff --git a/omni-reader/assets/samples_for_ocr/tire_serial_number.jpg b/omni-reader/assets/samples_for_ocr/tire_serial_number.jpg new file mode 100644 index 00000000..2f2a0e6a Binary files /dev/null and b/omni-reader/assets/samples_for_ocr/tire_serial_number.jpg differ diff --git a/omni-reader/configs/batch_pipeline.yaml b/omni-reader/configs/batch_pipeline.yaml new file mode 100644 index 00000000..3400d2c4 --- /dev/null +++ b/omni-reader/configs/batch_pipeline.yaml @@ -0,0 +1,84 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# pipeline configuration +build: batch-ocr-pipeline-20254007 +run_name: run_ocr + +# environment configuration +settings: + docker: + requirements: requirements.txt + required_integrations: + - aws + - s3 + python_package_installer: uv + environment: + OPENAI_API_KEY: ${{ env.OPENAI_API_KEY }} + MISTRAL_API_KEY: ${{ env.MISTRAL_API_KEY }} + +# enable flags +enable_artifact_metadata: True +enable_artifact_visualization: True +enable_cache: False +enable_step_logs: True + +# step configuration +steps: + load_images: + parameters: + image_folder: ./assets/samples_for_ocr + image_paths: [] + enable_cache: False + + run_ocr: + parameters: + custom_prompt: null + models: # can be model names or shorthands + - pixtral + - gemma3 + - llava-phi3 + - gpt4o + - granite + enable_cache: False + +# vision models configuration +models_registry: + - name: mistral/pixtral-12b-2409 + shorthand: pixtral + ocr_processor: litellm + provider: mistral + + - name: gpt-4o-mini + shorthand: gpt4o + ocr_processor: openai + + - name: gemma3:27b + shorthand: gemma3 + ocr_processor: ollama + + - name: llava:34b + shorthand: llava34b + ocr_processor: ollama + + - name: llava-phi3 + shorthand: llava-phi3 + ocr_processor: ollama + + - name: granite3.2-vision + shorthand: granite + ocr_processor: ollama diff --git a/omni-reader/configs/evaluation_pipeline.yaml b/omni-reader/configs/evaluation_pipeline.yaml new file mode 100644 index 00000000..f3f1f241 --- /dev/null +++ b/omni-reader/configs/evaluation_pipeline.yaml @@ -0,0 +1,55 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# pipeline configuration +build: ocr-evaluation-pipeline-20254007 +run_name: ocr_evaluation_run + +# environment configuration +settings: + docker: + requirements: requirements.txt + required_integrations: + - aws + - s3 + python_package_installer: uv + environment: + OPENAI_API_KEY: ${{ env.OPENAI_API_KEY }} + MISTRAL_API_KEY: ${{ env.MISTRAL_API_KEY }} + +# enable flags +enable_artifact_metadata: True +enable_artifact_visualization: True +enable_cache: False +enable_step_logs: True + +# steps configuration +steps: + load_ocr_results: # loads OCR results from batch pipeline runs + enable_cache: False + parameters: + artifact_name: ocr_results + version: null + + load_ground_truth_texts: + enable_cache: False + parameters: + ground_truth_folder: ground_truth_texts + ground_truth_files: [] + + evaluate_models: + enable_cache: False diff --git a/omni-reader/ground_truth_texts/education_article_excerpt.txt b/omni-reader/ground_truth_texts/education_article_excerpt.txt new file mode 100644 index 00000000..9ebe4c8f --- /dev/null +++ b/omni-reader/ground_truth_texts/education_article_excerpt.txt @@ -0,0 +1,13 @@ +Only a matter of style? + +For educational purposes we analyse the opening pages of an 11-page article that appeared in The American Mathematical Monthly, Volume 102 Number 2 / February 1995. We have added line numbers in the right margin. 
+ +line 4: Since in this article, squares don’t get alternating colours, it could be argued that the term “chessboard” is misplaced. + +line 4: The introduction of the name “B” seems unnecessary; it is used in the combination “the board B” in the text for Figure 1 and in line 7; in both cases just “the board” would have done fine. In line 77 occurs the last use of B, viz. in “X⊂B”, which is dubious since B was a board and not a set; in line 77, I would have preferred “Given a set X of cells”. + +line 7/8: The first move, being a move like any other, does not deserve a separate description. The term “step” is redundant. + +line 8: Why not “a move consists of”? + +line 10/11: At this stage the italics are puzzling, since a move is possible if, diff --git a/omni-reader/ground_truth_texts/incomplete_sentence.txt b/omni-reader/ground_truth_texts/incomplete_sentence.txt new file mode 100644 index 00000000..a5924119 --- /dev/null +++ b/omni-reader/ground_truth_texts/incomplete_sentence.txt @@ -0,0 +1 @@ +In mid-April Anglesey moved his family and entourage from Rome to Naples, there to await the arrival of diff --git a/omni-reader/ground_truth_texts/lexus_vin_number.txt b/omni-reader/ground_truth_texts/lexus_vin_number.txt new file mode 100644 index 00000000..767a51f4 --- /dev/null +++ b/omni-reader/ground_truth_texts/lexus_vin_number.txt @@ -0,0 +1 @@ +JTHBH5D2405012812 \ No newline at end of file diff --git a/omni-reader/ground_truth_texts/montreal_signs.txt b/omni-reader/ground_truth_texts/montreal_signs.txt new file mode 100644 index 00000000..c31b8cf6 --- /dev/null +++ b/omni-reader/ground_truth_texts/montreal_signs.txt @@ -0,0 +1,3 @@ +Basilique Notre-Dame +Place Royale +Place d’Armes \ No newline at end of file diff --git a/omni-reader/ground_truth_texts/paris_signs.txt b/omni-reader/ground_truth_texts/paris_signs.txt new file mode 100644 index 00000000..56b98b30 --- /dev/null +++ b/omni-reader/ground_truth_texts/paris_signs.txt @@ -0,0 +1,5 @@ +Palais Royal +Les Arts Décoratifs +Musée du LOUVRE +Église ST GERMAIN l’AUXERROIS +Musée Eugène DELACROIX \ No newline at end of file diff --git a/omni-reader/ground_truth_texts/reporter_notes.txt b/omni-reader/ground_truth_texts/reporter_notes.txt new file mode 100644 index 00000000..03596fac --- /dev/null +++ b/omni-reader/ground_truth_texts/reporter_notes.txt @@ -0,0 +1 @@ +An attempt to get more information about the Admiralty House meeting will be made in the House of Commons this afternoon. Labour M.P.s already have many questions to the Prime Minister asking for a statement. President Kennedy flew from London Airport last night to arrive in Washington this morning. He is to make a 30-minute nation-wide broadcast and television report on his talks with Mr. Khrushchev this evening. diff --git a/omni-reader/ground_truth_texts/rx_prescription_clear.txt b/omni-reader/ground_truth_texts/rx_prescription_clear.txt new file mode 100644 index 00000000..d0cc4230 --- /dev/null +++ b/omni-reader/ground_truth_texts/rx_prescription_clear.txt @@ -0,0 +1,6 @@ +Rx +Amoxicillin + Clavulanic acid (Co-Amoxiclav) 500/125 mg/tab #21 +Sig: Take one with food every 8 hours for 7 days + +Paracetamol 500 mg/tab #5 +Sig: Take one with food every 4 hours as needed for fever (temp. 
≥ 37.8°C) diff --git a/omni-reader/ground_truth_texts/rx_prescription_unclear.txt b/omni-reader/ground_truth_texts/rx_prescription_unclear.txt new file mode 100644 index 00000000..dc53da0a --- /dev/null +++ b/omni-reader/ground_truth_texts/rx_prescription_unclear.txt @@ -0,0 +1,10 @@ +Rx + +Tab. Pansec 20 +1 + 0 + 1 + +Tab. Apitez 160 +0 + 1 + 0 + +Tab. Linagliptin / Linita 5mg +2 + 0 + 0 \ No newline at end of file diff --git a/omni-reader/ground_truth_texts/tire_serial_number.txt b/omni-reader/ground_truth_texts/tire_serial_number.txt new file mode 100644 index 00000000..8fc900d6 --- /dev/null +++ b/omni-reader/ground_truth_texts/tire_serial_number.txt @@ -0,0 +1 @@ +3702692432 \ No newline at end of file diff --git a/omni-reader/pipelines/__init__.py b/omni-reader/pipelines/__init__.py new file mode 100644 index 00000000..f90f1c83 --- /dev/null +++ b/omni-reader/pipelines/__init__.py @@ -0,0 +1,19 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OCR pipelines""" + +from pipelines.batch_pipeline import batch_ocr_pipeline, run_batch_ocr_pipeline +from pipelines.evaluation_pipeline import ocr_evaluation_pipeline, run_ocr_evaluation_pipeline diff --git a/omni-reader/pipelines/batch_pipeline.py b/omni-reader/pipelines/batch_pipeline.py new file mode 100644 index 00000000..913ffe14 --- /dev/null +++ b/omni-reader/pipelines/batch_pipeline.py @@ -0,0 +1,105 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""OCR Batch Pipeline implementation for processing images with multiple models.""" + +import os +from typing import Any, Dict, List, Optional + +from dotenv import load_dotenv +from zenml import pipeline +from zenml.config import DockerSettings +from zenml.logger import get_logger + +from steps import ( + load_images, + run_ocr, +) + +load_dotenv() + +logger = get_logger(__name__) + +docker_settings = DockerSettings( + required_integrations=["s3", "aws"], + python_package_installer="uv", + requirements="requirements.txt", + environment={ + "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"), + "MISTRAL_API_KEY": os.getenv("MISTRAL_API_KEY"), + }, +) + + +@pipeline(settings={"docker": docker_settings}) +def batch_ocr_pipeline( + image_paths: Optional[List[str]] = None, + image_folder: Optional[str] = None, + custom_prompt: Optional[str] = None, + models: List[str] = None, +) -> None: + """Run OCR batch processing pipeline with multiple models. + + Args: + image_paths: Optional list of specific image paths to process + image_folder: Optional folder to search for images + custom_prompt: Optional custom prompt to use for the models + models: List of model names to use for OCR + """ + images = load_images( + image_paths=image_paths, + image_folder=image_folder, + ) + + run_ocr( + images=images, + models=models, + custom_prompt=custom_prompt, + ) + + +def run_batch_ocr_pipeline(config: Dict[str, Any]) -> None: + """Run the OCR batch pipeline from a configuration dictionary. + + Args: + config: Dictionary containing configuration + + Returns: + None + """ + pipeline_instance = batch_ocr_pipeline.with_options( + enable_cache=config.get("enable_cache", False), + ) + + load_images_params = config.get("steps", {}).get("load_images", {}).get("parameters", {}) + image_folder = load_images_params.get("image_folder") + image_paths = load_images_params.get("image_paths", []) + if not image_folder and len(image_paths) == 0: + raise ValueError("Either image_folder or image_paths must be provided") + + run_ocr_params = config.get("steps", {}).get("run_ocr", {}).get("parameters", {}) + custom_prompt = run_ocr_params.get("custom_prompt") + selected_models = run_ocr_params.get("models", []) + if not selected_models or len(selected_models) == 0: + raise ValueError( + "No models found in the run_ocr step of the batch_ocr_pipeline config file. At least one model must be specified in the 'models' parameter." + ) + + pipeline_instance( + image_paths=image_paths, + image_folder=image_folder, + custom_prompt=custom_prompt, + models=selected_models, + ) diff --git a/omni-reader/pipelines/evaluation_pipeline.py b/omni-reader/pipelines/evaluation_pipeline.py new file mode 100644 index 00000000..a1b7a21c --- /dev/null +++ b/omni-reader/pipelines/evaluation_pipeline.py @@ -0,0 +1,99 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""OCR Evaluation Pipeline implementation for comparing models using existing results.""" + +import os +from typing import Any, Dict, List, Optional + +from dotenv import load_dotenv +from zenml import pipeline +from zenml.config import DockerSettings +from zenml.logger import get_logger + +from steps import ( + evaluate_models, + load_ground_truth_texts, + load_ocr_results, +) + +load_dotenv() + +logger = get_logger(__name__) + +docker_settings = DockerSettings( + requirements="requirements.txt", + required_integrations=["s3", "aws"], + python_package_installer="uv", + environment={ + "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"), + "MISTRAL_API_KEY": os.getenv("MISTRAL_API_KEY"), + }, +) + + +@pipeline(settings={"docker": docker_settings}) +def ocr_evaluation_pipeline( + ground_truth_folder: Optional[str] = None, + ground_truth_files: Optional[List[str]] = None, +) -> None: + """Run OCR evaluation pipeline comparing existing model results.""" + if not ground_truth_folder and not ground_truth_files: + raise ValueError( + "Either ground_truth_folder or ground_truth_files must be provided for evaluation" + ) + + model_results = load_ocr_results(artifact_name="ocr_results") + + ground_truth_df = load_ground_truth_texts( + model_results=model_results, + ground_truth_folder=ground_truth_folder, + ground_truth_files=ground_truth_files, + ) + + evaluate_models( + model_results=model_results, + ground_truth_df=ground_truth_df, + ) + + +def run_ocr_evaluation_pipeline(config: Dict[str, Any]) -> None: + """Run the OCR evaluation pipeline from a configuration dictionary. + + Args: + config: Dictionary containing configuration + + Returns: + None + """ + mode = config.get("parameters", {}).get("mode", "evaluation") + if mode != "evaluation": + logger.warning(f"Expected mode 'evaluation', but got '{mode}'. Proceeding anyway.") + + pipeline_instance = ocr_evaluation_pipeline.with_options( + enable_artifact_metadata=config.get("enable_artifact_metadata", True), + enable_artifact_visualization=config.get("enable_artifact_visualization", True), + enable_cache=config.get("enable_cache", False), + enable_step_logs=config.get("enable_step_logs", True), + ) + + load_ground_truth_texts_params = ( + config.get("steps", {}).get("load_ground_truth_texts", {}).get("parameters", {}) + ) + + pipeline_instance( + ground_truth_folder=load_ground_truth_texts_params.get("ground_truth_folder"), + ground_truth_files=load_ground_truth_texts_params.get("ground_truth_files", []), + ) diff --git a/omni-reader/requirements.txt b/omni-reader/requirements.txt new file mode 100644 index 00000000..0c69b01d --- /dev/null +++ b/omni-reader/requirements.txt @@ -0,0 +1,16 @@ +instructor +jiwer +jiter +importlib-metadata<7.0,>=1.4.0 +litellm +mistralai==1.0.3 +numpy<2.0,>=1.9.0 +openai==1.69.0 +Pillow==11.1.0 +polars==1.26.0 +pyarrow>=7.0.0 +python-dotenv +streamlit==1.44.0 +pydantic>=2.8.2,<2.9.0 +tqdm==4.66.4 +zenml>=0.80.0 diff --git a/omni-reader/run.py b/omni-reader/run.py new file mode 100644 index 00000000..4c2eaa28 --- /dev/null +++ b/omni-reader/run.py @@ -0,0 +1,419 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Run OCR pipeline with or without ZenML tracking. + +This module provides two modes of operation: +1. UI Mode: Direct OCR with no metadata/artifact tracking (for Streamlit) +2. Pipeline Mode: Full ZenML pipeline with tracking +""" + +import argparse +import os +import time +from typing import Any, Dict, List, Optional, Union + +from dotenv import load_dotenv +from PIL import Image + +from pipelines.batch_pipeline import run_batch_ocr_pipeline +from pipelines.evaluation_pipeline import run_ocr_evaluation_pipeline +from utils.config import ( + get_image_paths, + list_available_ground_truth_files, + load_config, + override_batch_config, + override_evaluation_config, + print_config_summary, + select_config_path, + validate_batch_config, + validate_evaluation_config, +) +from utils.model_configs import DEFAULT_MODEL, MODEL_CONFIGS +from utils.ocr_processing import run_ocr + +load_dotenv() + + +def run_ocr_from_ui( + image: Union[str, Image.Image], + model: str, + custom_prompt: Optional[str] = None, +) -> Dict[str, Any]: + """Extract text directly using OCR model without ZenML tracking.""" + start_time = time.time() + + # Get model configuration based on model ID + if model not in MODEL_CONFIGS: + return { + "raw_text": f"Error: Model '{model}' not found in MODEL_CONFIGS", + "error": f"Invalid model: {model}", + "processing_time": 0, + "model": model, + } + + try: + # Use the unified run_ocr function with track_metadata=False for UI + result = run_ocr( + image_input=image, + model_ids=model, + custom_prompt=custom_prompt, + track_metadata=False, + ) + + # Ensure processing_time is properly set + if "processing_time" not in result: + result["processing_time"] = time.time() - start_time + + # Ensure model info is set + result["model"] = model + result["display_name"] = MODEL_CONFIGS[model].display + result["ocr_processor"] = MODEL_CONFIGS[model].ocr_processor + + return result + except Exception as e: + processing_time = time.time() - start_time + return { + "raw_text": f"Error: Failed to extract text - {str(e)}", + "error": str(e), + "processing_time": processing_time, + "model": model, + "display_name": MODEL_CONFIGS[model].display, + "ocr_processor": MODEL_CONFIGS[model].ocr_processor, + } + + +def run_models_in_parallel( + image_path: Union[str, Image.Image], + model_ids: List[str], + custom_prompt: Optional[str] = None, +) -> Dict[str, Dict[str, Any]]: + """Process an image with multiple models in parallel.""" + try: + # Use the unified run_ocr function for parallel processing + results = run_ocr( + image_input=image_path, + model_ids=model_ids, + custom_prompt=custom_prompt, + track_metadata=False, + ) + + # Display progress in CLI mode + print(f"Processed image with {len(model_ids)} models") + return results + except Exception as e: + print(f"Error processing models in parallel: {str(e)}") + results = {} + + # Create error results for each model + for model_id in model_ids: + results[model_id] = { + "raw_text": f"Error: {str(e)}", + "error": str(e), + "processing_time": 0, + "model": model_id, + } + + return results + + +def list_supported_models(): + """List all supported 
models.""" + print("\nSupported models:") + print("-" * 70) + print(f"{'Model ID':<25} {'Display Name':<30} {'OCR Processor':<15}") + print("-" * 70) + + for model_id, config in MODEL_CONFIGS.items(): + print(f"{model_id:<25} {config.display:<30} {config.ocr_processor:<15}") + if config.provider: + print(f"{'Provider':<25} {config.provider:<30}") + + print("\nDefault model:", DEFAULT_MODEL.name) + print("-" * 70) + + +def format_model_results(model_id, result): + """Format results for a specific model.""" + model_config = MODEL_CONFIGS.get(model_id, None) + model_display = model_config.display if model_config else model_id + + output = f"\n{model_display} results:" + + if "error" in result: + output += f"\n❌ Error: {result.get('error', 'Unknown error')}" + else: + text = result["raw_text"] + if len(text) > 150: + text = f"{text[:150]}..." + output += f"\n✅ Text: {text}" + + output += f"\n⏱️ Processing time: {result.get('processing_time', 0):.2f}s" + + if "confidence" in result and result["confidence"] is not None: + output += f"\n🎯 Confidence: {result['confidence']:.2%}" + + return output + + +def run_ui_mode(args, parser): + """Run the application in streamlit UI mode without ZenML tracking.""" + if args.list_models: + list_supported_models() + return + + if args.image_paths: + image_path = args.image_paths[0] # Take the first image for UI mode + elif args.image_folder: + image_paths = get_image_paths(args.image_folder) + if not image_paths: + print(f"No images found in directory: {args.image_folder}") + return + image_path = image_paths[0] # Take the first image for UI mode + else: + parser.error("Error: Please provide an image path or folder") + return + + if not os.path.exists(image_path): + parser.error(f"Error: Image file '{image_path}' not found.") + return + + start_time = time.time() + + if args.models == "all": + # Run all models in parallel + print(f"Processing image with all {len(MODEL_CONFIGS)} models in parallel...") + results = run_models_in_parallel( + image_path, + list(MODEL_CONFIGS.keys()), + args.custom_prompt, + ) + + successful_models = sum(1 for result in results.values() if "error" not in result) + failed_models = len(results) - successful_models + + print("\n" + "=" * 50) + print(f"OCR COMPARISON RESULTS ({successful_models} successful, {failed_models} failed)") + print("=" * 50) + + # individual model results + for model_id, result in results.items(): + print(format_model_results(model_id, result)) + + print(f"\n⏱️ Total time: {time.time() - start_time:.2f}s") + print("=" * 50) + + elif "," in args.models: + # Run specific models in parallel + model_ids = [model_id.strip() for model_id in args.models.split(",")] + + invalid_models = [model_id for model_id in model_ids if model_id not in MODEL_CONFIGS] + if invalid_models: + print(f"Error: The following models are not supported: {', '.join(invalid_models)}") + print("Use --list-models to see all supported models.") + return + + print(f"Processing image with {len(model_ids)} selected models in parallel...") + results = run_models_in_parallel( + image_path, + model_ids, + args.custom_prompt, + ) + + successful_models = sum(1 for result in results.values() if "error" not in result) + failed_models = len(results) - successful_models + + print("\n" + "=" * 50) + print(f"OCR COMPARISON RESULTS ({successful_models} successful, {failed_models} failed)") + print("=" * 50) + + # individual model results + for model_id, result in results.items(): + print(format_model_results(model_id, result)) + + print(f"\n⏱️ Total time: 
{time.time() - start_time:.2f}s") + print("=" * 50) + + else: + # Run a single model + model_id = args.models + if model_id not in MODEL_CONFIGS: + print(f"Error: Model '{model_id}' not supported.") + print("Use --list-models to see all supported models.") + return + + print(f"\nProcessing with {model_id} model...") + result = run_ocr_from_ui( + image_path, + model_id, + args.custom_prompt, + ) + + print("\n" + "=" * 50) + print(f"OCR RESULT FOR {model_id}") + print("=" * 50) + print(format_model_results(model_id, result)) + print("=" * 50) + + +def run_pipeline_mode(args, parser): + """Run the application in full pipeline mode with ZenML tracking.""" + # List available ground truth files if requested + if args.list_ground_truth_files: + gt_files = list_available_ground_truth_files(directory=args.ground_truth_dir) + if gt_files: + print("Available ground truth files:") + for i, file in enumerate(gt_files): + print(f" {i + 1}. {file}") + else: + print(f"No ground truth files found in '{args.ground_truth_dir}'") + return + + # Determine pipeline mode and select config path + evaluation_mode = args.eval + + if args.config: + config_path = args.config + else: + config_path = select_config_path(evaluation_mode) + print(f"Auto-selecting config file: {config_path}") + + if not os.path.exists(config_path): + parser.error(f"Config file not found: {config_path}") + return + + # Load the configuration + try: + config = load_config(config_path) + except (ValueError, FileNotFoundError) as e: + parser.error(f"Error loading configuration: {str(e)}") + return + + cli_args = { + "image_paths": args.image_paths, + "image_folder": args.image_folder, + "custom_prompt": args.custom_prompt, + "ground_truth_dir": args.ground_truth_dir, + } + + # Override configuration with CLI arguments if provided + try: + if evaluation_mode: + config = override_evaluation_config(config, cli_args) + validate_evaluation_config(config) + else: + config = override_batch_config(config, cli_args) + validate_batch_config(config) + except ValueError as e: + parser.error(f"Configuration error: {str(e)}") + return + + print_config_summary(config, is_evaluation_config=evaluation_mode) + + try: + if evaluation_mode: + print("Running OCR Evaluation Pipeline...") + run_ocr_evaluation_pipeline(config) + else: + print("Running OCR Batch Pipeline...") + run_batch_ocr_pipeline(config) + except Exception as e: + print(f"Error running pipeline: {str(e)}") + return + + +def main(): + """Main entry point for the OCR tool.""" + parser = argparse.ArgumentParser( + description="Run OCR between vision models with or without ZenML tracking" + ) + + # Mode selection + parser.add_argument( + "--ui_mode", + action="store_true", + help="Run in UI mode without ZenML tracking (for Streamlit)", + ) + + # Config file options (pipeline mode) + config_group = parser.add_argument_group("Pipeline Mode Configuration") + config_group.add_argument( + "--config", + type=str, + help="Path to YAML configuration file (for pipeline mode)", + ) + config_group.add_argument( + "--eval", + action="store_true", + help="Run in evaluation pipeline mode (defaults to batch pipeline if not specified)", + ) + + # Ground truth utilities (pipeline mode) + gt_group = parser.add_argument_group("Ground truth utilities (pipeline mode)") + gt_group.add_argument( + "--list-ground-truth-files", + action="store_true", + help="List available ground truth files and exit", + ) + gt_group.add_argument( + "--ground-truth-dir", + type=str, + default="ground_truth_texts", + help="Directory to look 
for ground truth files (for --list-ground-truth-files)", + ) + + # Quick access options (shared between modes) + input_group = parser.add_argument_group("Input options (shared)") + input_group.add_argument( + "--image-paths", + nargs="+", + help="Paths to images to process", + ) + input_group.add_argument( + "--image-folder", + type=str, + help="Folder containing images to process", + ) + input_group.add_argument( + "--custom-prompt", + type=str, + dest="custom_prompt", + help="Custom prompt to use for OCR models", + ) + + # UI mode specific options + ui_group = parser.add_argument_group("UI Mode Options") + ui_group.add_argument( + "--models", + type=str, + default=DEFAULT_MODEL.name, + help="Model(s) to use: a specific model ID, 'all' to compare all, or a comma-separated list (for UI mode)", + ) + ui_group.add_argument( + "--list-models", + action="store_true", + help="List all supported models and exit (for UI mode)", + ) + + args = parser.parse_args() + + if args.ui_mode: + run_ui_mode(args, parser) + else: + run_pipeline_mode(args, parser) + + +if __name__ == "__main__": + main() diff --git a/omni-reader/schemas/__init__.py b/omni-reader/schemas/__init__.py new file mode 100644 index 00000000..0805615d --- /dev/null +++ b/omni-reader/schemas/__init__.py @@ -0,0 +1,20 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .image_description import ImageDescription +from .ocr_result import ( + OCRResult, + OCRResultMapping, +) diff --git a/omni-reader/schemas/image_description.py b/omni-reader/schemas/image_description.py new file mode 100644 index 00000000..f7a24e4a --- /dev/null +++ b/omni-reader/schemas/image_description.py @@ -0,0 +1,27 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Schemas for LLM responses.""" + +from typing import Optional + +from pydantic import BaseModel + + +class ImageDescription(BaseModel): + """Base model for OCR results.""" + + raw_text: str + confidence: Optional[float] = None diff --git a/omni-reader/schemas/ocr_result.py b/omni-reader/schemas/ocr_result.py new file mode 100644 index 00000000..43aa74bd --- /dev/null +++ b/omni-reader/schemas/ocr_result.py @@ -0,0 +1,36 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. 
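Putting the two modes of run.py together, a small usage sketch follows; the image path and model ID are placeholders, while the flags and config paths come straight from the argparse setup and select_config_path.

# CLI invocations:
#   python run.py --ui_mode --image-paths docs/sample.png --models all
#   python run.py --config configs/batch_pipeline.yaml
#   python run.py --eval --config configs/evaluation_pipeline.yaml
#
# Programmatic UI-mode call that skips ZenML tracking:
from run import run_ocr_from_ui

result = run_ocr_from_ui("docs/sample.png", model="example/model-a")
print(result["raw_text"], result.get("processing_time"))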
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Schemas for OCR results.""" + +from typing import Dict, List + +from pydantic import BaseModel, RootModel + + +class OCRResult(BaseModel): + """OCR result for a single image.""" + + id: int + image_name: str + raw_text: str + processing_time: float + confidence: float + + +class OCRResultMapping(RootModel): + """Each model name maps to a list of OCRResult entries.""" + + root: Dict[str, List[OCRResult]] diff --git a/omni-reader/steps/__init__.py b/omni-reader/steps/__init__.py new file mode 100644 index 00000000..ca6c1d0f --- /dev/null +++ b/omni-reader/steps/__init__.py @@ -0,0 +1,22 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .evaluate_models import evaluate_models +from .loaders import ( + load_ground_truth_texts, + load_images, + load_ocr_results, +) +from .run_ocr import run_ocr \ No newline at end of file diff --git a/omni-reader/steps/evaluate_models.py b/omni-reader/steps/evaluate_models.py new file mode 100644 index 00000000..a7d564c0 --- /dev/null +++ b/omni-reader/steps/evaluate_models.py @@ -0,0 +1,235 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains the steps for evaluating the OCR models.""" + +from typing import Dict + +import polars as pl +from typing_extensions import Annotated +from zenml import log_metadata, step +from zenml.types import HTMLString + +from utils import ( + calculate_custom_metrics, + calculate_model_similarities, + compare_multi_model, + create_model_comparison_card, + create_summary_visualization, + get_model_info, +) + + +@step(enable_cache=False) +def evaluate_models( + model_results: pl.DataFrame, + ground_truth_df: pl.DataFrame, +) -> Annotated[HTMLString, "ocr_visualization"]: + """Compare the performance of multiple configurable models with visualization. 
+ + Args: + model_results: Dictionary containing single or multiple model ocr results + ground_truth_df: DataFrame containing ground truth texts + + Returns: + HTML visualization of the evaluation results + """ + if model_results is None or len(model_results.columns) == 0: + raise ValueError("At least one model is required for evaluation") + + if ground_truth_df is None or ground_truth_df.is_empty(): + raise ValueError("Ground truth data is required for evaluation") + + gt_df = ground_truth_df + + # --- 2. Build model info for evaluation models --- + model_keys = list(model_results.columns) + model_info = {} + model_displays = [] + model_prefixes = {} + for model_name in model_keys: + display, prefix = get_model_info(model_name) + model_info[model_name] = (display, prefix) + model_displays.append(display) + model_prefixes[display] = prefix + + # --- 3. Convert DataFrame rows to dictionaries --- + model_results_dict = {} + for model_name in model_keys: + model_data = model_results[model_name].to_dicts() + model_results_dict[model_name] = pl.DataFrame(model_data) + + # --- 4. Merge evaluation models' results --- + base_model = model_keys[0] + base_display, base_prefix = model_info[base_model] + merged_results = model_results_dict[base_model].clone() + for i, model_name in enumerate(model_keys[1:], start=1): + disp, pref = model_info[model_name] + suffix = f"_{pref}" if i > 1 else "_right" + merged_results = merged_results.join( + model_results_dict[model_name], + on=["id", "image_name"], + how="inner", + suffix=suffix, + ) + + # --- 5. Join ground truth data if available --- + if gt_df is not None: + merged_results = merged_results.join( + gt_df, on=["id", "image_name"], how="inner", suffix="_gt" + ) + + # --- 6. Calculate processing times for evaluation models --- + all_model_times = {} + for model_name, df in model_results.items(): + disp, pref = model_info[model_name] + time_key = f"avg_{pref}_time" + all_model_times[time_key] = df.select("processing_time").to_series().mean() + all_model_times[f"{pref}_display"] = disp + + fastest_model_time, fastest_key = min( + [(time, key) for key, time in all_model_times.items() if not key.endswith("_display")], + key=lambda x: x[0], + ) + fastest_prefix = fastest_key.replace("avg_", "").replace("_time", "") + fastest_display = all_model_times.get(f"{fastest_prefix}_display", fastest_prefix) + + # --- 7. 
Per-image evaluation: compute metrics, error analysis, and build per-image cards --- + evaluation_metrics = [] + image_cards_html = "" + gt_text_col = "ground_truth_text" + + # Check if we have ground truth data in our joined dataset + if gt_text_col not in merged_results.columns and "raw_text_gt" in merged_results.columns: + gt_text_col = "raw_text_gt" # Fall back to legacy ground truth model format + + for row in merged_results.iter_rows(named=True): + if gt_text_col not in row: + continue + ground_truth = row[gt_text_col] + model_texts = {} + model_texts[base_display] = row["raw_text"] + for i, mkey in enumerate(model_keys[1:], start=1): + disp, pref = model_info[mkey] + col = "raw_text_right" if i == 1 else f"raw_text_{pref}" + if col in row: + model_texts[disp] = row[col] + row_metrics = calculate_custom_metrics(ground_truth, model_texts, list(model_texts.keys())) + error_analysis = compare_multi_model(ground_truth, model_texts) + result_metrics = {"id": row["id"], "image_name": row["image_name"]} + for disp in model_texts.keys(): + if disp in row_metrics: + for met_name, val in row_metrics[disp].items(): + result_metrics[f"{disp} {met_name}"] = val + for disp, errs in error_analysis.items(): + for met_name, val in errs.items(): + result_metrics[f"{disp} {met_name}"] = val + evaluation_metrics.append(result_metrics) + + # Merge GT Similarity values from row_metrics into error_analysis + for disp in model_texts.keys(): + if disp in row_metrics and "GT Similarity" in row_metrics[disp]: + if disp not in error_analysis: + error_analysis[disp] = {} + error_analysis[disp]["GT Similarity"] = row_metrics[disp]["GT Similarity"] + + comparison_card = create_model_comparison_card( + image_name=row["image_name"], + ground_truth=ground_truth, + model_texts=model_texts, + model_metrics=error_analysis, + ) + image_cards_html += comparison_card + + # --- 8. Compute average metrics for evaluation models --- + model_metric_averages = {d: {} for d in model_displays} + if evaluation_metrics: + df_eval = pl.DataFrame(evaluation_metrics) + for disp in model_displays: + cer_col = f"{disp} CER" + wer_col = f"{disp} WER" + sim_col = f"{disp} GT Similarity" + if cer_col in df_eval.columns: + model_metric_averages[disp]["CER"] = df_eval[cer_col].mean() + if wer_col in df_eval.columns: + model_metric_averages[disp]["WER"] = df_eval[wer_col].mean() + if sim_col in df_eval.columns: + model_metric_averages[disp]["GT Similarity"] = df_eval[sim_col].mean() + for disp in model_displays: + pref = model_prefixes[disp] + tkey = f"avg_{pref}_time" + if tkey in all_model_times: + model_metric_averages[disp]["Proc. Time"] = all_model_times[tkey] + + # --- 9. Calculate similarity matrix for evaluation models only --- + sim_results = [] + for row in merged_results.iter_rows(named=True): + texts_map = {} + texts_map[base_display] = row.get("raw_text", "") + for i, mkey in enumerate(model_keys[1:], start=1): + disp, pref = model_info[mkey] + col = "raw_text_right" if i == 1 else f"raw_text_{pref}" + texts_map[disp] = row.get(col, "") + sim_results.append(texts_map) + similarities = {} + if len(model_displays) > 1: + similarities = calculate_model_similarities( + results=[ + { + f"raw_text_{disp.lower().replace(' ', '_')}": texts_map[disp] + for disp in model_displays + } + for texts_map in sim_results + ], + model_displays=model_displays, + ) + + # --- 10. 
Build time comparison info --- + time_comparison = { + **all_model_times, + "fastest_model": fastest_display, + "model_count": len(model_keys), + } + if len(model_keys) >= 2: + d1, p1 = model_info[model_keys[0]] + d2, p2 = model_info[model_keys[1]] + tk1, tk2 = f"avg_{p1}_time", f"avg_{p2}_time" + if tk1 in all_model_times and tk2 in all_model_times: + time_comparison["time_difference"] = abs(all_model_times[tk1] - all_model_times[tk2]) + + # Log metadata (customize the metadata_dict as needed) + log_metadata( + metadata={ + "fastest_model": fastest_display, + "model_count": len(model_keys), + } + ) + + summary_html = create_summary_visualization( + model_metrics=model_metric_averages, + time_comparison=time_comparison, + similarities=similarities, + ) + + # --- 11. Combine summary and per-image details --- + final_html = f""" + {summary_html} +
    <div>
+        <h2>Sample Results</h2>
+        {image_cards_html}
+    </div>
+ """ + + return HTMLString(final_html) diff --git a/omni-reader/steps/loaders.py b/omni-reader/steps/loaders.py new file mode 100644 index 00000000..c24cf119 --- /dev/null +++ b/omni-reader/steps/loaders.py @@ -0,0 +1,204 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains the ground truth and OCR results loader steps.""" + +import glob +import os +from typing import Dict, List, Optional + +import polars as pl +from typing_extensions import Annotated +from zenml import log_metadata, step +from zenml.artifacts.utils import load_artifact +from zenml.client import Client +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@step() +def load_images( + image_paths: Optional[List[str]] = None, + image_folder: Optional[str] = None, +) -> List[str]: + """Load images for OCR processing. + + This step loads images from specified paths or by searching for + patterns in a given folder. + + Args: + image_paths: Optional list of specific image paths to load + image_folder: Optional folder to search for images. + + Returns: + List of validated image file paths + """ + all_images = [] + + if image_paths: + all_images.extend(image_paths) + logger.info(f"Added {len(image_paths)} directly specified images") + + if image_folder: + patterns_to_use = ["*.jpg", "*.jpeg", "*.png", "*.webp", "*.tiff"] + + for pattern in patterns_to_use: + full_pattern = os.path.join(image_folder, pattern) + matching_files = glob.glob(full_pattern) + if matching_files: + all_images.extend(matching_files) + logger.info( + f"Found {len(matching_files)} images matching pattern {pattern}" + ) + + # Validate image paths + valid_images = [] + for path in all_images: + if os.path.isfile(path): + valid_images.append(path) + else: + logger.warning(f"Image not found: {path}") + + # Log metadata about the loaded images + image_names = [os.path.basename(path) for path in valid_images] + image_extensions = [ + os.path.splitext(path)[1].lower() for path in valid_images + ] + + extension_counts = {} + for ext in image_extensions: + if ext in extension_counts: + extension_counts[ext] += 1 + else: + extension_counts[ext] = 1 + + log_metadata( + metadata={ + "loaded_images": { + "total_count": len(valid_images), + "extensions": extension_counts, + "image_names": image_names, + } + } + ) + + logger.info(f"Successfully loaded {len(valid_images)} valid images") + + return valid_images + + +@step(enable_cache=False) +def load_ground_truth_texts( + model_results: pl.DataFrame, + ground_truth_folder: Optional[str] = None, + ground_truth_files: Optional[List[str]] = None, +) -> Annotated[pl.DataFrame, "ground_truth"]: + """Load ground truth texts using image names found in model results.""" + if not ground_truth_folder and not ground_truth_files: + raise ValueError( + "Either ground_truth_folder or ground_truth_files must be provided" + ) + + # Get the first model column to extract image names + first_model_column = 
list(model_results.columns)[0] + + image_names = model_results[first_model_column]["image_name"].to_list() + + logger.info(f"Extracted {len(image_names)} image names") + + gt_path_map = {} + + if ground_truth_folder: + for f in os.listdir(ground_truth_folder): + if f.endswith(".txt"): + base = os.path.splitext(f)[0] + gt_path_map[base] = os.path.join(ground_truth_folder, f) + elif ground_truth_files: + for path in ground_truth_files: + base = os.path.splitext(os.path.basename(path))[0] + gt_path_map[base] = path + + data = [] + missing = [] + + for i, img_name in enumerate(image_names): + base_name = os.path.splitext(img_name)[0] + gt_path = gt_path_map.get(base_name) + + if not gt_path or not os.path.exists(gt_path): + missing.append(img_name) + continue + + try: + with open(gt_path, "r", encoding="utf-8") as f: + text = f.read().strip() + data.append( + { + "id": i, + "image_name": img_name, + "raw_text": text, + "processing_time": 0, + "confidence": 1.0, + } + ) + except Exception as e: + logger.warning(f"Failed to read ground truth for {img_name}: {e}") + + if missing: + logger.warning( + f"Missing ground truth files for {len(missing)} images: {missing[:5]}{'...' if len(missing) > 5 else ''}" + ) + + if not data: + raise ValueError("No ground truth files could be loaded.") + + return pl.DataFrame(data) + + +@step() +def load_ocr_results( + artifact_name: str = "ocr_results", + version: Optional[int] = None, +) -> Annotated[Dict[str, pl.DataFrame], "ocr_results"]: + """Load OCR results from ZenML artifact. + + Args: + artifact_name: Name of the ZenML artifact + version: Version of the ZenML artifact + + Returns: + dict: Dictionary mapping model names to OCR results DataFrames + + Raises: + ValueError: If the parameters are invalid or the dataset cannot be loaded + """ + try: + client = Client() + + artifact = client.get_artifact_version( + name_id_or_prefix=artifact_name, version=version + ) + + ocr_results = load_artifact(artifact.id) + + logger.info( + f"Successfully loaded OCR results for {len(ocr_results)} models" + ) + + return ocr_results + except Exception as e: + logger.error(f"Failed to load OCR results: {str(e)}") + raise diff --git a/omni-reader/steps/run_ocr.py b/omni-reader/steps/run_ocr.py new file mode 100644 index 00000000..46c926cb --- /dev/null +++ b/omni-reader/steps/run_ocr.py @@ -0,0 +1,128 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND. 
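One detail worth spelling out from the loader above: ground truth files are matched to OCR results by file base name. A small illustration, using the tire_serial_number ground truth file that ships with the project (the image file name itself is assumed):

import os

# load_ground_truth_texts() pairs each result's image_name with a .txt file of the
# same base name inside ground_truth_folder.
image_name = "tire_serial_number.jpg"
expected_gt = os.path.splitext(image_name)[0] + ".txt"
print(expected_gt)  # tire_serial_number.txt, found under ground_truth_texts/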
+"""Unified step for running OCR with a one or multiple models.""" + +import os +import time +from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional, Tuple + +import polars as pl +from tqdm import tqdm +from typing_extensions import Annotated +from zenml import log_metadata, step +from zenml.logger import get_logger +from zenml.types import HTMLString + +from utils.model_configs import MODEL_CONFIGS +from utils.ocr_processing import process_images_with_model +from utils.visualizations import create_ocr_batch_visualization + +logger = get_logger(__name__) + + +@step() +def run_ocr( + images: List[str], + models: List[str], + custom_prompt: Optional[str] = None, +) -> Tuple[ + Annotated[pl.DataFrame, "ocr_results"], + Annotated[HTMLString, "ocr_batch_visualization"], +]: + """Extract text from images using multiple models in parallel. + + Args: + images: List of paths to image files + models: List of model names to use + custom_prompt: Optional custom prompt to override the default prompt + + Returns: + pl.DataFrame: Combined results from all models with OCR results + + Raises: + ValueError: If any model_name is not supported + """ + for model in models: + if model not in MODEL_CONFIGS: + supported = ", ".join(MODEL_CONFIGS.keys()) + raise ValueError(f"Unsupported model: {model}. Supported models: {supported}") + + logger.info(f"Running OCR with {len(models)} models on {len(images)} images.") + + model_dfs = {} + performance_metrics = {} + + with ThreadPoolExecutor(max_workers=min(len(models), 5)) as executor: + futures = { + model: executor.submit( + process_images_with_model, + model_config=MODEL_CONFIGS[model], + images=images, + custom_prompt=custom_prompt, + track_metadata=True, + ) + for model in models + } + with tqdm(total=len(models), desc="Processing models") as pbar: + for model, future in futures.items(): + start = time.time() + try: + results = future.result() + results = results.with_columns( + pl.lit(model).alias("model_name"), + pl.lit(MODEL_CONFIGS[model].display).alias("model_display_name"), + ) + model_dfs[model] = results + + performance_metrics[model] = { + "total_time": time.time() - start, + "images_processed": len(images), + } + except Exception as e: + logger.error(f"Error processing {model}: {e}") + error_df = pl.DataFrame( + { + "id": list(range(len(images))), + "image_name": [os.path.basename(img) for img in images], + "raw_text": [f"Error: {e}"] * len(images), + "processing_time": [0.0] * len(images), + "confidence": [0.0] * len(images), + "error": [str(e)] * len(images), + "model_name": [model] * len(images), + "model_display_name": [MODEL_CONFIGS[model].display] * len(images), + } + ) + model_dfs[model] = error_df + performance_metrics[model] = { + "error": str(e), + "total_time": time.time() - start, + "images_processed": 0, + } + finally: + pbar.update(1) + + combined_results = pl.concat(list(model_dfs.values()), how="diagonal") + + # Generate HTML visualization + html_visualization = create_ocr_batch_visualization(combined_results) + + log_metadata( + metadata={ + "ocr_results_artifact_name": "ocr_results", + "ocr_results_artifact_type": "polars.DataFrame", + "ocr_batch_visualization_artifact_name": "ocr_batch_visualization", + "ocr_batch_visualization_artifact_type": "zenml.types.HTMLString", + }, + ) + + return combined_results, html_visualization diff --git a/omni-reader/utils/__init__.py b/omni-reader/utils/__init__.py new file mode 100644 index 00000000..f613bee3 --- /dev/null +++ b/omni-reader/utils/__init__.py @@ -0,0 
+1,54 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .encode_image import encode_image +from .metrics import ( + analyze_errors, + calculate_custom_metrics, + calculate_model_similarities, + compare_multi_model, + find_best_model, + normalize_text, +) +from .visualizations import ( + create_metrics_table, + create_comparison_table, + create_model_card_with_logo, + create_model_comparison_card, + create_model_similarity_matrix, + create_summary_visualization, + create_ocr_batch_visualization +) +from .ocr_processing import ( + log_image_metadata, + log_error_metadata, + log_summary_metadata, + process_images_with_model, + process_image, +) +from .prompt import ( + get_prompt, + ImageDescription, +) +from .model_configs import ( + MODEL_CONFIGS, + DEFAULT_MODEL, + get_model_info, + model_registry, + ModelConfig, + get_model_prefix, +) +from .extract_json import try_extract_json_from_response diff --git a/omni-reader/utils/config.py b/omni-reader/utils/config.py new file mode 100644 index 00000000..6bd7949e --- /dev/null +++ b/omni-reader/utils/config.py @@ -0,0 +1,271 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
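For reference, the run_ocr step above concatenates one DataFrame per model into a single "ocr_results" artifact with one row per image and model. A sketch of that schema, with illustrative values:

import polars as pl

# Columns follow the per-image records built in steps/run_ocr.py; values are illustrative.
ocr_results_sketch = pl.DataFrame(
    {
        "id": [0],
        "image_name": ["tire_serial_number.jpg"],
        "raw_text": ["3702692432"],
        "processing_time": [1.2],
        "confidence": [0.95],
        "model_name": ["example/model-a"],
        "model_display_name": ["Example Model A"],
    }
)
print(ocr_results_sketch)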
+"""Utilities for handling configuration.""" + +import glob +import os +from typing import Any, Dict, List, Optional + +import yaml + + +def load_config(config_path: str) -> Dict[str, Any]: + """Load configuration from YAML file.""" + if not os.path.isfile(config_path): + raise FileNotFoundError(f"Configuration file does not exist: {config_path}") + + with open(config_path, "r") as f: + try: + config = yaml.safe_load(f) + return config + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML configuration file: {e}") + + +def validate_batch_config(config: Dict[str, Any]) -> None: + """Validate batch pipeline configuration.""" + if "steps" not in config: + raise ValueError("Missing required 'steps' section in batch configuration") + + steps = config.get("steps", {}) + + if "load_images" not in steps: + raise ValueError("Missing required 'load_images' step in batch pipeline configuration") + + if "run_ocr" not in steps: + raise ValueError("Missing required 'run_ocr' step in batch pipeline configuration") + + load_images_params = steps.get("load_images", {}).get("parameters", {}) + if not load_images_params.get("image_folder") and not load_images_params.get("image_paths"): + raise ValueError( + "Either image_folder or image_paths must be provided in load_images.parameters" + ) + + image_folder = load_images_params.get("image_folder") + if image_folder and not os.path.isdir(image_folder): + raise ValueError(f"Image folder does not exist: {image_folder}") + + run_ocr_params = steps.get("run_ocr", {}).get("parameters", {}) + if "models" not in run_ocr_params or not run_ocr_params["models"]: + raise ValueError("At least one model must be specified in run_ocr.parameters.models") + + if "models_registry" not in config or not config["models_registry"]: + raise ValueError("models_registry section is required with at least one model definition") + + +def validate_evaluation_config(config: Dict[str, Any]) -> None: + """Validate evaluation pipeline configuration.""" + if "steps" not in config: + raise ValueError("Missing required 'steps' section in evaluation configuration") + + steps = config.get("steps", {}) + + if "load_ocr_results" not in steps: + raise ValueError( + "Missing required 'load_ocr_results' step in evaluation pipeline configuration" + ) + + if "load_ground_truth_texts" not in steps: + raise ValueError( + "Missing required 'load_ground_truth_texts' step in evaluation pipeline configuration" + ) + + gt_params = steps.get("load_ground_truth_texts", {}).get("parameters", {}) + gt_folder = gt_params.get("ground_truth_folder") + if gt_folder and not os.path.isdir(gt_folder): + raise ValueError(f"Ground truth folder does not exist: {gt_folder}") + + +def override_batch_config(config: Dict[str, Any], cli_args: Dict[str, Any]) -> Dict[str, Any]: + """Override batch pipeline configuration with command-line arguments.""" + modified_config = {**config} + + steps = modified_config.get("steps", {}) + + if "load_images" not in steps: + steps["load_images"] = {"parameters": {}} + elif "parameters" not in steps["load_images"]: + steps["load_images"]["parameters"] = {} + + if cli_args.get("image_paths"): + steps["load_images"]["parameters"]["image_paths"] = cli_args["image_paths"] + + if cli_args.get("image_folder"): + steps["load_images"]["parameters"]["image_folder"] = cli_args["image_folder"] + + if "run_ocr" not in steps: + steps["run_ocr"] = {"parameters": {}} + elif "parameters" not in steps["run_ocr"]: + steps["run_ocr"]["parameters"] = {} + + if cli_args.get("custom_prompt"): + 
steps["run_ocr"]["parameters"]["custom_prompt"] = cli_args["custom_prompt"] + + return modified_config + + +def override_evaluation_config(config: Dict[str, Any], cli_args: Dict[str, Any]) -> Dict[str, Any]: + """Override evaluation pipeline configuration with command-line arguments.""" + modified_config = {**config} + + steps = modified_config.get("steps", {}) + + if "load_ground_truth_texts" not in steps: + steps["load_ground_truth_texts"] = {"parameters": {}} + elif "parameters" not in steps["load_ground_truth_texts"]: + steps["load_ground_truth_texts"]["parameters"] = {} + + if cli_args.get("ground_truth_dir"): + steps["load_ground_truth_texts"]["parameters"]["ground_truth_folder"] = cli_args[ + "ground_truth_dir" + ] + + return modified_config + + +def print_batch_config_summary(config: Dict[str, Any]) -> None: + """Print a summary of the batch pipeline configuration.""" + steps = config.get("steps", {}) + + print("\n===== Batch OCR Pipeline Configuration =====") + print(f"Build: {config.get('build', 'N/A')}") + print(f"Run name: {config.get('run_name', 'N/A')}") + + # Print caching and logging info + print(f"Cache enabled: {config.get('enable_cache', False)}") + print(f"Step logs enabled: {config.get('enable_step_logs', False)}") + + # Print input information + load_params = steps.get("load_images", {}).get("parameters", {}) + + image_paths = load_params.get("image_paths", []) + if image_paths: + print(f"Input images: {len(image_paths)} specified") + + image_folder = load_params.get("image_folder") + if image_folder: + print(f"Input folder: {image_folder}") + try: + num_images = len(get_image_paths(image_folder)) + print(f"Found {num_images} images in folder") + except Exception as e: + print(f"Unable to access image folder: {e}") + + # Print model information + run_ocr_params = steps.get("run_ocr", {}).get("parameters", {}) + models = run_ocr_params.get("models", []) + if models: + print(f"Models to run: {', '.join(models)}") + + # Print custom prompt if provided + custom_prompt = run_ocr_params.get("custom_prompt") + if custom_prompt: + print(f"Custom prompt: {custom_prompt}") + + # List models from registry + models_registry = config.get("models_registry", []) + if models_registry: + print(f"\nModels in registry: {len(models_registry)}") + for model in models_registry: + print(f" - {model.get('name')} (shorthand: {model.get('shorthand')})") + + print("=" * 40 + "\n") + + +def print_evaluation_config_summary(config: Dict[str, Any]) -> None: + """Print a summary of the evaluation pipeline configuration.""" + steps = config.get("steps", {}) + + print("\n===== OCR Evaluation Pipeline Configuration =====") + print(f"Build: {config.get('build', 'N/A')}") + print(f"Run name: {config.get('run_name', 'N/A')}") + + # Print caching and logging info + print(f"Cache enabled: {config.get('enable_cache', False)}") + print(f"Step logs enabled: {config.get('enable_step_logs', False)}") + + # Print OCR results information + load_results_params = steps.get("load_ocr_results", {}).get("parameters", {}) + artifact_name = load_results_params.get("artifact_name", "ocr_results") + artifact_version = load_results_params.get("version", "latest") + print(f"Loading OCR results from: {artifact_name} (version: {artifact_version})") + + # Print ground truth information + gt_params = steps.get("load_ground_truth_texts", {}).get("parameters", {}) + gt_folder = gt_params.get("ground_truth_folder") + if gt_folder: + print(f"Ground truth folder: {gt_folder}") + gt_files = 
list_available_ground_truth_files(directory=gt_folder) + print(f"Found {len(gt_files)} ground truth text files") + + gt_files = gt_params.get("ground_truth_files", []) + if gt_files: + print(f"Using {len(gt_files)} specific ground truth files") + + print("=" * 45 + "\n") + + +def print_config_summary( + config: Dict[str, Any], + is_evaluation_config: bool = False, +) -> None: + """Print a summary of the ZenML configuration.""" + if is_evaluation_config: + print_evaluation_config_summary(config) + else: + print_batch_config_summary(config) + + +def get_image_paths(directory: str) -> List[str]: + """Get all image paths from a directory.""" + image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.webp"] + image_paths = [] + + for ext in image_extensions: + image_paths.extend(glob.glob(os.path.join(directory, ext))) + + return sorted(image_paths) + + +def list_available_ground_truth_files(directory: Optional[str] = "ground_truth_texts") -> List[str]: + """List available ground truth text files in the given directory. + + Args: + directory: Directory containing ground truth text files + + Returns: + List of paths to available ground truth text files + """ + if not directory or not os.path.isdir(directory): + return [] + + text_files = glob.glob(os.path.join(directory, "*.txt")) + return sorted(text_files) + + +def select_config_path(evaluation_mode: bool) -> str: + """Select the appropriate configuration file path based on the pipeline mode. + + Args: + evaluation_mode: Whether to use evaluation pipeline configuration + + Returns: + Path to the configuration file + """ + if evaluation_mode: + return "configs/evaluation_pipeline.yaml" + else: + return "configs/batch_pipeline.yaml" diff --git a/omni-reader/utils/encode_image.py b/omni-reader/utils/encode_image.py new file mode 100644 index 00000000..2420a55f --- /dev/null +++ b/omni-reader/utils/encode_image.py @@ -0,0 +1,74 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains utility functions for encoding images to base64 strings.""" + +import base64 +import mimetypes +from io import BytesIO + +from PIL import Image + + +def encode_pil_image(image: Image.Image, format: str = "JPEG") -> str: + """Encode a PIL Image object to a base64 string. + + Args: + image: PIL Image object + format: Image format for encoding (default: JPEG) + + Returns: + str: Base64 encoded string of the image + """ + buffered = BytesIO() + image.save(buffered, format=format) + image_data = buffered.getvalue() + image_base64 = base64.b64encode(image_data).decode("utf-8") + return image_base64 + + +def encode_image_from_path(image_path: str) -> str: + """Encode an image from a file path to a base64 string. 
+ + Args: + image_path: Path to the image file + + Returns: + str: Base64 encoded string of the image + """ + with open(image_path, "rb") as image_file: + image_data = image_file.read() + image_base64 = base64.b64encode(image_data).decode("utf-8") + return image_base64 + + +def encode_image(image: Image.Image | str) -> tuple[str, str]: + """Encode an image to a base64 string. + + Args: + image: Either a PIL Image object or a string path to an image file + + Returns: + tuple[str, str]: Image type and base64 encoded string + """ + if isinstance(image, str): + content_type = mimetypes.guess_type(image)[0] or "image/jpeg" + image_base64 = encode_image_from_path(image) + else: + img_format = image.format or "JPEG" + content_type = f"image/{img_format.lower()}" if img_format else "image/jpeg" + image_base64 = encode_pil_image(image, format=img_format if img_format else "JPEG") + + return content_type, image_base64 diff --git a/omni-reader/utils/extract_json.py b/omni-reader/utils/extract_json.py new file mode 100644 index 00000000..afdf4cef --- /dev/null +++ b/omni-reader/utils/extract_json.py @@ -0,0 +1,88 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for extracting JSON from a response, handling various formats.""" + +import json +import re +from typing import Any, Dict + + +def _ensure_string_raw_text(result: Dict) -> Dict: + """Ensure raw_text in result is a string and not a list. + + Args: + result: Dictionary that may contain a raw_text field + + Returns: + Dictionary with raw_text converted to string if needed + """ + if "raw_text" in result and isinstance(result["raw_text"], list): + result["raw_text"] = "\n".join(result["raw_text"]) + return result + + +def try_extract_json_from_response(response: Any) -> Dict: + """Extract JSON from a response, handling various formats efficiently. + + Args: + response: The response which could be a string, dict, or an object with content. + + Returns: + Dict with extracted data. 
+ """ + if isinstance(response, dict) and "raw_text" in response: + return _ensure_string_raw_text(response) + + response_text = "" + if hasattr(response, "choices") and response.choices: + msg = getattr(response.choices[0], "message", None) + if msg and hasattr(msg, "content"): + response_text = msg.content + elif isinstance(response, str): + response_text = response + elif hasattr(response, "raw_text"): + raw_text = response.raw_text + result = {"raw_text": raw_text, "confidence": getattr(response, "confidence", None)} + return _ensure_string_raw_text(result) + + response_text = response_text.strip() + + try: + parsed = json.loads(response_text) + if isinstance(parsed, dict): + return _ensure_string_raw_text(parsed) + except json.JSONDecodeError: + pass + + json_block = re.search(r"```json\s*(.*?)\s*```", response_text, re.DOTALL) + if json_block: + json_str = json_block.group(1).strip() + try: + parsed = json.loads(json_str) + return _ensure_string_raw_text(parsed) + except json.JSONDecodeError: + pass + + json_substring = re.search(r"\{.*\}", response_text, re.DOTALL) + if json_substring: + json_str = json_substring.group(0).strip() + try: + parsed = json.loads(json_str) + return _ensure_string_raw_text(parsed) + except json.JSONDecodeError: + pass + + return {"raw_text": response_text, "confidence": None} diff --git a/omni-reader/utils/metrics.py b/omni-reader/utils/metrics.py new file mode 100644 index 00000000..96a82890 --- /dev/null +++ b/omni-reader/utils/metrics.py @@ -0,0 +1,287 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains detailed error analysis and metrics for OCR results.""" + +import difflib +import re +from collections import Counter +from dataclasses import dataclass +from difflib import SequenceMatcher +from typing import Any, Dict, List, Union + +from jiwer import cer, wer + + +@dataclass +class ErrorAnalysis: + """Detailed error analysis results.""" + + total_errors: int + insertions: int + deletions: int + substitutions: int + common_substitutions: Dict[str, str] # actual -> predicted pairs + error_positions: Dict[str, int] # position categories -> counts + error_distribution: Dict[str, float] # percentages of error types + + +def analyze_errors(ground_truth: str, predicted: str) -> ErrorAnalysis: + """Perform detailed error analysis between ground truth and prediction. 
+ + Args: + ground_truth: The reference text (ground truth) + predicted: The OCR extracted text to analyze + + Returns: + ErrorAnalysis object with detailed error metrics + """ + # Clean up texts for comparison + ground_truth = re.sub(r"\s+", " ", ground_truth).strip() + predicted = re.sub(r"\s+", " ", predicted).strip() + + # Get character-level diff + d = difflib.SequenceMatcher(None, ground_truth, predicted) + + # Track error types + insertions = 0 + deletions = 0 + substitutions = 0 + substitution_pairs = [] + error_positions = Counter() + + # Process diff blocks + for tag, i1, i2, j1, j2 in d.get_opcodes(): + if tag == "replace": + substitutions += max(i2 - i1, j2 - j1) + # Track character pairs for substitution analysis + if i2 - i1 == j2 - j1: # 1:1 substitution + for idx in range(i2 - i1): + substitution_pairs.append((ground_truth[i1 + idx], predicted[j1 + idx])) + # Track position in text (beginning, middle, end) + if i1 < len(ground_truth) * 0.2: + error_positions["beginning"] += 1 + elif i1 > len(ground_truth) * 0.8: + error_positions["end"] += 1 + else: + error_positions["middle"] += 1 + elif tag == "delete": + deletions += i2 - i1 + elif tag == "insert": + insertions += j2 - j1 + + # Analyze common substitutions + common_subs = Counter(substitution_pairs).most_common(10) + common_substitutions = {} + for (gt, pred), count in common_subs: + common_substitutions[gt] = pred + + # Calculate total errors and distribution + total_errors = insertions + deletions + substitutions + + # Calculate error distribution (percentages) + dist = {} + if total_errors > 0: + dist = { + "insertions": insertions / total_errors * 100, + "deletions": deletions / total_errors * 100, + "substitutions": substitutions / total_errors * 100, + } + + return ErrorAnalysis( + total_errors=total_errors, + insertions=insertions, + deletions=deletions, + substitutions=substitutions, + common_substitutions=common_substitutions, + error_positions=dict(error_positions), + error_distribution=dist, + ) + + +def levenshtein_ratio(s1: str, s2: str) -> float: + """Calculate the Levenshtein ratio between two strings.""" + return SequenceMatcher(None, s1, s2).ratio() + + +def find_best_model( + model_metrics: Dict[str, Dict[str, float]], + metric: str, + lower_is_better: bool = True, +) -> str: + """Find the best performing model(s) for a given metric, showing ties when they occur.""" + best_models = [] + best_value = None + + for model, metrics in model_metrics.items(): + if metric in metrics: + value = metrics[metric] + if ( + best_value is None + or (lower_is_better and value < best_value) + or (not lower_is_better and value > best_value) + ): + best_value = value + if best_value is not None: + for model, metrics in model_metrics.items(): + if metric in metrics: + value = metrics[metric] + if (lower_is_better and abs(value - best_value) < 1e-6) or ( + not lower_is_better and abs(value - best_value) < 1e-6 + ): + best_models.append(model) + + if not best_models: + return "N/A" + elif len(best_models) == 1: + return best_models[0] + else: + # return ties as a comma-separated list + return ", ".join(best_models) + + +def calculate_custom_metrics( + ground_truth_text: str, + model_texts: Dict[str, str], + model_displays: List[str], +) -> Dict[str, Dict[str, float]]: + """Calculate metrics for each model and between model pairs.""" + all_metrics = {} + model_pairs = [] + for i, model1 in enumerate(model_displays): + if model1 not in all_metrics: + all_metrics[model1] = {} + text1 = model_texts.get(model1, "") + if 
ground_truth_text: + all_metrics[model1]["CER"] = cer(ground_truth_text, text1) + all_metrics[model1]["WER"] = wer(ground_truth_text, text1) + all_metrics[model1]["GT Similarity"] = levenshtein_ratio(ground_truth_text, text1) + for j, model2 in enumerate(model_displays): + if i < j: + model_pairs.append((model1, model2)) + for model1, model2 in model_pairs: + text1 = model_texts.get(model1, "") + text2 = model_texts.get(model2, "") + similarity = levenshtein_ratio(text1, text2) + pair_key = f"{model1}_{model2}" + all_metrics[pair_key] = similarity + return all_metrics + + +def compare_multi_model( + ground_truth: str, + model_texts: Dict[str, str], +) -> Dict[str, Dict[str, Union[float, int, Dict]]]: + """Compares OCR results from multiple models with the ground truth. + + Args: + ground_truth (str): The ground truth text. + model_texts (Dict[str, str]): Dictionary mapping model display names to extracted text. + + Returns: + Dict[str, Dict[str, Union[float, int, Dict]]]: A dictionary of model names to metrics. + """ + results = {} + + for model_display, text in model_texts.items(): + model_metrics = {} + + model_metrics["CER"] = cer(ground_truth, text) + model_metrics["WER"] = wer(ground_truth, text) + + model_analysis = analyze_errors(ground_truth, text) + + model_metrics.update( + { + "Insertions": model_analysis.insertions, + "Deletions": model_analysis.deletions, + "Substitutions": model_analysis.substitutions, + "Insertion Rate": model_analysis.error_distribution.get("insertions", 0), + "Deletion Rate": model_analysis.error_distribution.get("deletions", 0), + "Substitution Rate": model_analysis.error_distribution.get("substitutions", 0), + "Error Positions": model_analysis.error_positions, + "Common Substitutions": model_analysis.common_substitutions, + } + ) + + results[model_display] = model_metrics + + return results + + +def normalize_text(s: str) -> str: + """Normalize text for comparison.""" + s = s.lower() + s = re.sub(r"\s+", " ", s).strip() + s = s.replace("\n", " ") + # Normalize apostrophes and similar characters + s = re.sub(r"[''′`]", "'", s) + return s + + +def calculate_model_similarities( + results: List[Dict[str, Any]], model_displays: List[str] +) -> Dict[str, float]: + """Calculate the average pairwise Levenshtein ratio between model outputs. + + Expects each result to have keys formatted as: + "raw_text_{model_display}" + where model_display is converted to lowercase and spaces are replaced with underscores. + + Args: + results (List[Dict[str, Any]]): List of dictionaries containing model outputs. + model_displays (List[str]): List of model display names. + + Returns: + Dict[str, float]: Dictionary mapping each model pair (formatted as "Model1_Model2") + to their average similarity score. 
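+
+    Example (illustrative; the model names and texts below are made up):
+
+        results = [{"raw_text_gpt_4o": "total 42", "raw_text_pixtral": "total 42"}]
+        calculate_model_similarities(results, ["GPT 4o", "Pixtral"])
+        # -> {"GPT 4o_Pixtral": 1.0}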
+ """ + similarity_sums = {} + similarity_counts = {} + + for result in results: + # Map model display names to their corresponding text + model_texts = {} + for display in model_displays: + key = f"raw_text_{display.lower().replace(' ', '_')}" + text = result.get(key, "") + if isinstance(text, str): + text = normalize_text(text) + if text: + model_texts[display] = text + + # Only proceed if at least two models have valid text + if len(model_texts) < 2: + continue + + # Compute pairwise similarity for each combination + for i in range(len(model_displays)): + for j in range(i + 1, len(model_displays)): + model1 = model_displays[i] + model2 = model_displays[j] + if model1 not in model_texts or model2 not in model_texts: + continue + text1 = model_texts[model1] + text2 = model_texts[model2] + similarity = levenshtein_ratio(text1, text2) + pair_key = f"{model1}_{model2}" + similarity_sums[pair_key] = similarity_sums.get(pair_key, 0) + similarity + similarity_counts[pair_key] = similarity_counts.get(pair_key, 0) + 1 + + # Average the similarities for each pair + similarities = { + pair: similarity_sums[pair] / similarity_counts[pair] for pair in similarity_sums + } + return similarities diff --git a/omni-reader/utils/model_configs.py b/omni-reader/utils/model_configs.py new file mode 100644 index 00000000..2769af78 --- /dev/null +++ b/omni-reader/utils/model_configs.py @@ -0,0 +1,191 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Model configuration utilities for OCR operations.""" + +import re +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Tuple + +from utils.config import load_config + + +@dataclass +class ModelConfig: + """Configuration for OCR models.""" + + name: str + ocr_processor: str + provider: Optional[str] = None + shorthand: Optional[str] = None + display: Optional[str] = None + prefix: Optional[str] = None + logo: Optional[str] = None + additional_params: Dict[str, Any] = field(default_factory=dict) + default_confidence: float = 0.5 + + +class ModelRegistry: + """Registry for OCR model configurations.""" + + def __init__(self, config_path: str = "configs/batch_pipeline.yaml"): + """Initialize the model registry from configuration YAML.""" + self.models = {} + self.default_model = None + self.load_from_config(config_path) + + def load_from_config(self, config_path: str) -> None: + """Load model registry from configuration file.""" + config = load_config(config_path) + + # Process models from the registry + if "models_registry" in config: + for model_entry in config["models_registry"]: + model_config = ModelConfig(**model_entry) + self._infer_missing_properties(model_config) + + # Add to registry by name and shorthand + self.models[model_config.name] = model_config + if model_config.shorthand: + self.models[model_config.shorthand] = model_config + + # Process selected models list and set default + if "ocr" in config and "selected_models" in config["ocr"]: + selected = config["ocr"]["selected_models"] + if selected and selected[0] in self.models: + self.default_model = self.models[selected[0]] + + # Fallback for default model + if not self.default_model and self.models: + self.default_model = next(iter(self.models.values())) + + def _infer_missing_properties(self, model_config: ModelConfig) -> None: + """Fill in missing properties based on model name patterns.""" + if not model_config.display: + model_config.display = self._generate_display_name(model_config.name) + + if not model_config.prefix: + model_config.prefix = self._generate_prefix(model_config.display) + + if not model_config.logo: + model_config.logo = self._infer_logo(model_config.name) + + def _infer_logo(self, model_name: str) -> str: + """Infer the logo based on the model name.""" + model_name = model_name.lower() + + if any(n in model_name for n in ["gpt", "openai"]): + return "openai.svg" + elif any(n in model_name for n in ["mistral", "pixtral"]): + return "mistral.svg" + elif "gemma" in model_name: + return "gemma.svg" + elif "llava" in model_name: + return "microsoft.svg" + elif any(n in model_name for n in ["moondream", "phi", "granite"]): + return "ollama.svg" + + return "default.svg" + + def _generate_display_name(self, model_name: str) -> str: + """Generate a human-readable display name.""" + if "/" in model_name: + model_name = model_name.split("/")[1] + + parts = re.split(r"[-_:.]", model_name) + + formatted = [] + for part in parts: + if re.match(r"^\d+b$", part.lower()): # Size (7b, 11b) + formatted.append(part.upper()) + elif re.match(r"^\d+(\.\d+)*$", part): # Version numbers + formatted.append(part) + elif part.lower() in ["gpt", "llm"]: + formatted.append(part.upper()) + else: + formatted.append(part.capitalize()) + + return " ".join(formatted) + + def _generate_prefix(self, display_name: str) -> str: + """Generate a file prefix from display name.""" + prefix = display_name.lower().replace(" ", "_").replace("-", "_") + prefix = re.sub(r"[^a-z0-9_]", "", prefix) + return prefix + + 
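+    # Illustrative note: with the helpers above, a registry entry named
+    # "mistral/pixtral-12b-2409" (a made-up example) would be given the display
+    # name "Pixtral 12B 2409", the prefix "pixtral_12b_2409", and mistral.svg as its logo.
+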
def get_model(self, model_id: str) -> Optional[ModelConfig]: + """Get a model configuration by ID or shorthand.""" + return self.models.get(model_id) + + def get_model_by_prefix(self, prefix: str) -> Optional[ModelConfig]: + """Get a model configuration by its prefix.""" + for model in self.models.values(): + if model.prefix == prefix: + return model + return None + + +def get_model_info(model_name: str) -> Tuple[str, str]: + """Returns a tuple (display, prefix) for a given model name. + + Args: + model_name: The name of the model + + Returns: + A tuple (display, prefix) + """ + model = model_registry.get_model(model_name) + if model: + return model.display, model.prefix + + # Generate fallback values + if "/" in model_name: + model_part = model_name.split("/")[-1] + if ":" in model_part: + display = model_part.split(":")[0] + else: + display = model_part + display = display.replace("-", " ").title() + else: + display = model_name.split("-")[0] + if ":" in display: + display = display.split(":")[0] + display = display.title() + + prefix = display.lower().replace(" ", "_").replace("-", "_") + return display, prefix + + +def get_model_prefix(model_name: str) -> str: + """Get standardized prefix from model name.""" + if "/" in model_name: + model_name = model_name.split("/")[1] + + if ":" in model_name: + model_name = model_name.replace(":", "_") + + prefix = model_name.lower().replace("-", "_").replace(".", "_") + prefix = re.sub(r"[^a-z0-9_]", "", prefix) + return prefix + + +# global instance of the ModelRegistry +model_registry = ModelRegistry() + +# Export the registry's models dict for compatibility +MODEL_CONFIGS = model_registry.models + +# Export the default model for compatibility +DEFAULT_MODEL = model_registry.default_model diff --git a/omni-reader/utils/ocr_processing.py b/omni-reader/utils/ocr_processing.py new file mode 100644 index 00000000..4e35383f --- /dev/null +++ b/omni-reader/utils/ocr_processing.py @@ -0,0 +1,590 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Utility functions for OCR operations across different models.""" + +import os +import statistics +import time +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import polars as pl +import requests +from dotenv import load_dotenv +from tqdm import tqdm +from zenml import log_metadata +from zenml.logger import get_logger + +from utils.encode_image import encode_image +from utils.extract_json import try_extract_json_from_response +from utils.model_configs import MODEL_CONFIGS, ModelConfig +from utils.prompt import ImageDescription, get_prompt + +load_dotenv() + +logger = get_logger(__name__) + + +# ============================================================================ +# Metadata Logging Functions +# ============================================================================ + + +def log_image_metadata( + prefix: str, + index: int, + image_name: str, + processing_time: float, + text_length: int, + confidence: float, +): + """Log metadata for a processed image.""" + log_metadata( + metadata={ + f"{prefix}_image_{index}": { + "image_name": image_name, + "processing_time_seconds": processing_time, + "text_length": text_length, + "confidence": confidence, + } + } + ) + + +def log_error_metadata( + prefix: str, + index: int, + image_name: str, + error: str, +): + """Log error metadata for a failed image processing.""" + log_metadata( + metadata={ + f"{prefix}_error_image_{index}": { + "image_name": image_name, + "error": error, + } + } + ) + + +def log_summary_metadata( + prefix: str, + model_name: str, + images_count: int, + processing_times: List[float], + confidence_scores: List[float], +): + """Log summary metadata for all processed images.""" + avg_time = statistics.mean(processing_times) if processing_times else 0 + max_time = max(processing_times) if processing_times else 0 + min_time = min(processing_times) if processing_times else 0 + avg_confidence = statistics.mean(confidence_scores) if confidence_scores else 0 + + log_metadata( + metadata={ + f"{prefix}_ocr_summary": { + "model": model_name, + "images_processed": images_count, + "avg_processing_time": avg_time, + "min_processing_time": min_time, + "max_processing_time": max_time, + "avg_confidence": avg_confidence, + "total_processing_time": sum(processing_times), + } + } + ) + + +# ============================================================================ +# Model Processing Functions +# ============================================================================ + + +def process_ollama_based(model_name: str, prompt: str, image_base64: str) -> Dict[str, Any]: + """Process an image with an Ollama model.""" + BASE_URL = os.getenv("OLLAMA_HOST") or "http://localhost:11434/api/generate" + + payload = { + "model": model_name, + "prompt": prompt, + "stream": False, + "images": [image_base64], + } + + try: + response = requests.post( + BASE_URL, + json=payload, + timeout=120, # 2mins, in case of really complex images + ) + response.raise_for_status() + res = response.json().get("response", "") + result_json = try_extract_json_from_response(res) + + return result_json + except Exception as e: + logger.error(f"Error processing with Ollama model {model_name}: {str(e)}") + return {"raw_text": f"Error: {str(e)}", "confidence": 0.0} + + +def process_openai_based(model_name: str, prompt: str, image_url: str) -> Dict[str, Any]: + """Process an image with an API-based model (OpenAI).""" + import instructor + from openai import OpenAI 
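+    # instructor patches the OpenAI client so that chat.completions.create can
+    # take response_model=ImageDescription and return a parsed Pydantic object.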
+ + openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + client = instructor.from_openai(openai_client) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + ], + } + ] + + try: + response = client.chat.completions.create( + model=model_name, + messages=messages, + response_model=ImageDescription, + temperature=0.0, + ) + + result_json = try_extract_json_from_response(response) + return result_json + except Exception as e: + logger.error(f"Error processing with {model_name}: {str(e)}") + return {"raw_text": f"Error: {str(e)}", "confidence": 0.0} + + +def process_litellm_based(model_config: ModelConfig, prompt: str, image_url: str) -> Dict[str, Any]: + """Process an image with a Litellm model.""" + from litellm import completion + + os.environ["MISTRAL_API_KEY"] = os.getenv("MISTRAL_API_KEY") + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt, + }, + { + "type": "image_url", + "image_url": image_url, + }, + ], + }, + ] + + try: + response = completion( + model=model_config.name, + messages=messages, + custom_llm_provider=model_config.provider, + temperature=0.0, + ) + + result_text = response["choices"][0]["message"]["content"] + ocr_result = try_extract_json_from_response(result_text) + return ocr_result + except Exception as e: + logger.error(f"Error processing with {model_config.name}: {str(e)}") + return {"raw_text": f"Error: {str(e)}", "confidence": 0.0} + + +# ============================================================================ +# Core Processing Functions +# ============================================================================ + + +def process_image( + model_config: ModelConfig, prompt: str, image_base64: str, content_type: str = "image/jpeg" +) -> Dict[str, Any]: + """Process an image with the specified model.""" + model_name = model_config.name + image_url = f"data:{content_type};base64,{image_base64}" + + processors = { + "litellm": lambda: process_litellm_based(model_config, prompt, image_url), + "ollama": lambda: process_ollama_based(model_name, prompt, image_base64), + "openai": lambda: process_openai_based(model_name, prompt, image_url), + } + + processor = processors.get(model_config.ocr_processor) + if not processor: + raise ValueError(f"Unsupported ocr_processor: {model_config.ocr_processor}") + + return processor() + + +def process_single_image( + model_id: str, + image: str, + custom_prompt: Optional[str] = None, + track_metadata: bool = True, + index: Optional[int] = None, +) -> Tuple[str, Dict[str, Any], Optional[int]]: + """Process a single image with a specific model. 
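+
+    Example (illustrative; the model ID and image path are assumptions and must
+    exist in the model registry and on disk):
+
+        model_id, result, _ = process_single_image("gpt-4o-mini", "sample.png")
+        print(result.get("raw_text"), result.get("processing_time"))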
+ + Args: + model_id: Model ID + image: Image path or PIL Image + custom_prompt: Optional custom prompt + track_metadata: Whether to track metadata with ZenML + index: Optional index for batch processing + + Returns: + Tuple of (model_id, result_dict, index) + """ + start_time = time.time() + + try: + model_config = MODEL_CONFIGS[model_id] + content_type, image_base64 = encode_image(image) + prompt = custom_prompt if custom_prompt else get_prompt() + + result_json = process_image(model_config, prompt, image_base64, content_type) + + processing_time = time.time() - start_time + if "processing_time" not in result_json: + result_json["processing_time"] = processing_time + + result_json["model"] = model_id + result_json["display_name"] = model_config.display + result_json["ocr_processor"] = model_config.ocr_processor + + return model_id, result_json, index + + except Exception as e: + processing_time = time.time() - start_time + error_result = { + "raw_text": f"Error: {str(e)}", + "error": str(e), + "processing_time": processing_time, + "model": model_id, + "display_name": MODEL_CONFIGS[model_id].display + if model_id in MODEL_CONFIGS + else model_id, + "ocr_processor": MODEL_CONFIGS[model_id].ocr_processor + if model_id in MODEL_CONFIGS + else "unknown", + } + + return model_id, error_result, index + + +def process_single_model_task(args): + """Wrapper function for ThreadPoolExecutor to unpack arguments.""" + return process_single_image(*args) + + +def process_result_and_track_metrics( + model_config: ModelConfig, + result: Dict[str, Any], + index: int, + images: List[str], + results_list: List[Dict[str, Any]], + processing_times: List[float], + confidence_scores: List[float], + track_metadata: bool = True, +): + """Process a result, log metrics, and track statistics.""" + prefix = model_config.prefix + display = model_config.display + + image_name = os.path.basename(images[index]) + processing_time = result.get("processing_time", 0) + + formatted_result = { + "id": index, + "image_name": image_name, + "raw_text": result.get("raw_text", "No text found"), + "processing_time": processing_time, + "confidence": result.get("confidence", model_config.default_confidence), + } + + if "error" in result: + formatted_result["error"] = result["error"] + + if track_metadata: + log_error_metadata( + prefix=prefix, + index=index, + image_name=image_name, + error=result["error"], + ) + else: + confidence = formatted_result["confidence"] + if confidence is None: + confidence = model_config.default_confidence + formatted_result["confidence"] = confidence + + confidence_scores.append(confidence) + + text_length = len(formatted_result["raw_text"]) + + if track_metadata: + log_image_metadata( + prefix=prefix, + index=index, + image_name=image_name, + processing_time=processing_time, + text_length=text_length, + confidence=confidence, + ) + + logger.info( + f"{display} OCR [{index + 1}/{len(images)}]: {image_name} - " + f"{text_length} chars, " + f"confidence: {confidence:.2f}, " + f"{processing_time:.2f} seconds" + ) + + results_list.append(formatted_result) + processing_times.append(processing_time) + + +# ============================================================================ +# Public API Functions +# ============================================================================ + + +def process_models_parallel( + image_input: Union[str, List[str]], + model_ids: List[str], + custom_prompt: Optional[str] = None, + max_workers: int = 5, + track_metadata: bool = True, +) -> Dict[str, Any]: + """Process 
image(s) with multiple models in parallel. + + Args: + image_input: Either a single image (path/PIL) or a list of image paths + model_ids: List of model IDs to process + custom_prompt: Optional custom prompt + max_workers: Maximum number of parallel workers + track_metadata: Whether to track metadata with ZenML + + Returns: + Dictionary mapping model IDs to their results + """ + effective_workers = min(len(model_ids), max_workers) + is_single_image = not isinstance(image_input, list) + results = {} + + tasks = [] + if is_single_image: + # For a single image, create tasks for each model + for model_id in model_ids: + tasks.append((model_id, image_input, custom_prompt, track_metadata, None)) + else: + # For multiple images, create tasks for each image/model combination + for index, image in enumerate(image_input): + for model_id in model_ids: + tasks.append((model_id, image, custom_prompt, track_metadata, index)) + + with ThreadPoolExecutor(max_workers=effective_workers) as executor: + futures = list(executor.map(process_single_model_task, tasks)) + + # Process results + for model_id, result, index in futures: + if is_single_image: + # For single image, just store the result + results[model_id] = result + else: + # For multiple images, group results by model + if model_id not in results: + results[model_id] = [] + results[model_id].append((index, result)) + + # For multiple images with multiple models, sort results by index + if not is_single_image: + for model_id in results: + results[model_id] = [r for _, r in sorted(results[model_id], key=lambda x: x[0])] + + return results + + +def process_images_with_model( + model_config: ModelConfig, + images: List[str], + custom_prompt: Optional[str] = None, + batch_size: int = 5, + track_metadata: bool = True, +) -> pl.DataFrame: + """Process multiple images with a specific model configuration. 
+ + Args: + model_config: Model configuration + images: List of image paths + custom_prompt: Optional custom prompt + batch_size: Number of images to process in parallel + track_metadata: Whether to track metadata with ZenML + + Returns: + DataFrame with OCR results + """ + model_name = model_config.name + prefix = model_config.prefix + display = model_config.display + + logger.info(f"Running {display} OCR with model: {model_name}") + logger.info(f"Processing {len(images)} images with batch size: {batch_size}") + + # Track processing metrics + results_list = [] + processing_times = [] + confidence_scores = [] + + # Process images in batches to control memory usage + effective_batch_size = min(batch_size, len(images)) + + with tqdm(total=len(images), desc=f"Processing with {display}") as pbar: + for batch_start in range(0, len(images), effective_batch_size): + batch_end = min(batch_start + effective_batch_size, len(images)) + batch = images[batch_start:batch_end] + + logger.info( + f"Processing batch {batch_start // effective_batch_size + 1}/" + f"{(len(images) + effective_batch_size - 1) // effective_batch_size} " + f"with {len(batch)} images" + ) + + batch_results = process_models_parallel( + image_input=batch, + model_ids=[model_name], + custom_prompt=custom_prompt, + max_workers=min(effective_batch_size, 10), + track_metadata=track_metadata, + ) + + if model_name in batch_results: + for i, result in enumerate(batch_results[model_name]): + actual_index = batch_start + i + process_result_and_track_metrics( + model_config=model_config, + result=result, + index=actual_index, + images=images, + results_list=results_list, + processing_times=processing_times, + confidence_scores=confidence_scores, + track_metadata=track_metadata, + ) + pbar.update(1) + + if track_metadata: + log_summary_metadata( + prefix=prefix, + model_name=model_name, + images_count=len(images), + processing_times=processing_times, + confidence_scores=confidence_scores, + ) + + results_df = pl.DataFrame(results_list) + return results_df + + +def run_ocr( + image_input: Union[str, List[str]], + model_ids: Union[str, List[str]], + custom_prompt: Optional[str] = None, + batch_size: int = 5, + track_metadata: bool = False, +) -> Union[Dict[str, Any], pl.DataFrame, Dict[str, pl.DataFrame]]: + """Unified interface for running OCR on images with different modes. 
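+
+    Example (illustrative; the image paths and model IDs are assumptions and
+    must be registered in configs/batch_pipeline.yaml):
+
+        df = run_ocr(["invoice_1.png", "invoice_2.png"], "gpt-4o-mini")
+        per_model = run_ocr("invoice_1.png", ["gpt-4o-mini", "pixtral-12b-2409"])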
+ + This function handles different combinations of inputs: + - Single image + single model + - Single image + multiple models + - Multiple images + single model + - Multiple images + multiple models + + Args: + image_input: Single image path/object or list of image paths + model_ids: Single model ID or list of model IDs + custom_prompt: Optional custom prompt + batch_size: Batch size for parallel processing + track_metadata: Whether to track metadata with ZenML + + Returns: + - Single image + single model: Dict result + - Single image + multiple models: Dict mapping model IDs to results + - Multiple images + single model: DataFrame with results + - Multiple images + multiple models: Dict mapping model IDs to DataFrames + """ + is_single_image = not isinstance(image_input, list) + is_single_model = not isinstance(model_ids, list) + + if is_single_model: + model_ids = [model_ids] + + if is_single_image and is_single_model: + # Single image + single model + _, result, _ = process_single_image( + model_id=model_ids[0], + image=image_input, + custom_prompt=custom_prompt, + track_metadata=track_metadata, + ) + return result + + elif is_single_image and not is_single_model: + # Single image + multiple models + return process_models_parallel( + image_input=image_input, + model_ids=model_ids, + custom_prompt=custom_prompt, + max_workers=min(len(model_ids), 10), + track_metadata=track_metadata, + ) + + elif not is_single_image and is_single_model: + # Multiple images + single model + model_config = MODEL_CONFIGS[model_ids[0]] + return process_images_with_model( + model_config=model_config, + images=image_input, + custom_prompt=custom_prompt, + batch_size=batch_size, + track_metadata=track_metadata, + ) + + else: + # Multiple images + multiple models + results = {} + for model_id in model_ids: + model_config = MODEL_CONFIGS[model_id] + results[model_id] = process_images_with_model( + model_config=model_config, + images=image_input, + custom_prompt=custom_prompt, + batch_size=batch_size, + track_metadata=track_metadata, + ) + return results diff --git a/omni-reader/utils/prompt.py b/omni-reader/utils/prompt.py new file mode 100644 index 00000000..e1ce3fb6 --- /dev/null +++ b/omni-reader/utils/prompt.py @@ -0,0 +1,40 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains the prompt and schema for the OCR model.""" + +from typing import Optional + +from pydantic import BaseModel + + +class ImageDescription(BaseModel): + """Base model for OCR results.""" + + raw_text: str + confidence: Optional[float] = None + + +def get_prompt(custom_prompt: Optional[str] = None) -> str: + """Default prompt for the OCR model.""" + if custom_prompt: + return custom_prompt + return """Extract all visible text from this image **without any changes**. + - Retain all spacing, punctuation, and formatting exactly as in the image. 
+ - If text is unclear or ambiguous (e.g., handwritten, blurry), use best judgment to **make educated guesses based on visual context** + - Return your response as a JSON object with the following fields: + - raw_text: The extracted text from the image + - confidence: The confidence score in the extracted text as a float between 0 and 1 + """ diff --git a/omni-reader/utils/visualizations.py b/omni-reader/utils/visualizations.py new file mode 100644 index 00000000..d54b6f02 --- /dev/null +++ b/omni-reader/utils/visualizations.py @@ -0,0 +1,608 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains the functions for creating the HTML visualizations.""" + +import base64 +import os +from typing import Any, Dict, List + +import polars as pl +from zenml import get_step_context +from zenml.logger import get_logger +from zenml.types import HTMLString + +from utils.metrics import find_best_model +from utils.model_configs import MODEL_CONFIGS + +logger = get_logger(__name__) + + +def load_svg_logo(logo_name: str) -> str: + """Load an SVG logo as base64 encoded string.""" + logo_path = os.path.join("./assets/logos", logo_name) + try: + if os.path.exists(logo_path): + with open(logo_path, "rb") as f: + return base64.b64encode(f.read()).decode() + except Exception: + pass + return "" + + +def create_metrics_table(metrics: Dict[str, float]) -> str: + """Create an HTML table for metrics.""" + rows = "" + for key, value in metrics.items(): + rows += f""" + + {key} + {value:.4f} + + """ + table = f""" + + + + + + + + + {rows} + +
MetricValue
+ """ + return table + + +def create_comparison_table( + models: List[str], metric_data: Dict[str, Dict[str, float]], metric_names: List[str] +) -> str: + """Create an HTML table comparing metrics across models.""" + headers = "".join([f'{model}' for model in models]) + rows = "" + for metric in metric_names: + row_data = "" + best_value = None + best_index = -1 + for i, model in enumerate(models): + value = metric_data[model].get(metric, 0) + if "CER" in metric or "WER" in metric: # Lower is better + if best_value is None or value < best_value: + best_value = value + best_index = i + else: # Higher is better (e.g. similarity) + if best_value is None or value > best_value: + best_value = value + best_index = i + for i, model in enumerate(models): + value = metric_data[model].get(metric, 0) + highlight = "bg-green-100" if i == best_index else "" + row_data += f'{value:.4f}' + rows += f""" + + {metric} + {row_data} + + """ + table = f""" + + + + + {headers} + + + + {rows} + +
Metric
+ """ + return table + + +def create_model_card_with_logo(model_display: str, metrics: Dict[str, float]) -> str: + """Create a card for a model with its logo and metrics.""" + logo_path = None + for _, config in MODEL_CONFIGS.items(): + if config.display == model_display: + logo_path = config.logo + break + if not logo_path or not os.path.exists(os.path.join("./assets/logos", logo_path)): + logo_path = "default.svg" + logo_b64 = load_svg_logo(logo_path) + logo_html = f'{model_display} logo' + metrics_rows = "" + metric_keys = ["CER", "WER", "GT Similarity", "Proc. Time"] + for key in metric_keys: + if key in metrics: + metrics_rows += f""" +
+
{key}:
+
{metrics[key]:.4f}
+
+ """ + card = f""" +
+

+ {logo_html} {model_display} Metrics +

+ {metrics_rows} +
+ """ + return card + + +def create_model_comparison_card( + image_name: str, + ground_truth: str, + model_texts: Dict[str, str], + model_metrics: Dict[str, Dict[str, Any]], +) -> str: + """Create a card for comparing OCR results for a specific image across models. + + Args: + image_name: Name of the image + ground_truth: Ground truth text + model_texts: Dictionary mapping model names to their extracted text + model_metrics: Dictionary of metrics for each model + + Returns: + HTML card as a string + """ + model_sections = "" + num_models = len(model_texts) + cols_class = "grid-cols-1" + if num_models <= 3: + cols_class = f"grid-cols-1 md:grid-cols-{num_models + 1}" # +1 for ground truth + else: + cols_per_row = min(3, (num_models + 1) // 2) + cols_class = f"grid-cols-1 md:grid-cols-{cols_per_row}" + for model_display, text in model_texts.items(): + logo_path = None + for _, config in MODEL_CONFIGS.items(): + if config.display == model_display: + logo_path = config.logo + break + if not logo_path: + logo_path = "default.svg" + logo_b64 = load_svg_logo(logo_path) + logo_html = f'{model_display} logo' + model_sections += f""" +
+

+ {logo_html} {model_display} Output +

+
{text}
+
+ """ + metrics_table = create_comparison_table( + list(model_texts.keys()), model_metrics, ["CER", "WER", "GT Similarity"] + ) + error_sections = "" + error_cols = min(3, len(model_texts)) + for model_display, metrics in model_metrics.items(): + error_sections += f""" +
+

{model_display} Errors

+
    +
  • Insertions: {metrics.get("Insertions", 0)} ({metrics.get("Insertion Rate", 0):.1f}%)
  • Deletions: {metrics.get("Deletions", 0)} ({metrics.get("Deletion Rate", 0):.1f}%)
  • Substitutions: {metrics.get("Substitutions", 0)} ({metrics.get("Substitution Rate", 0):.1f}%)
+
+ """ + + # Simple header for ground truth text files + ground_truth_header = '

📄 Ground Truth

' + + card = f""" +
+

{image_name}

+ +
+
+ {ground_truth_header} +
{ground_truth}
+
+ + {model_sections} +
+ +
+

📊 Key Metrics

+ {metrics_table} +
+ +
+ {error_sections} +
+
+ """ + return card + + +def create_model_similarity_matrix(models: List[str], similarities: Dict[str, float]) -> str: + """Create a matrix showing similarity between model outputs.""" + headers = "".join([f'{model}' for model in models]) + rows = "" + for i, model1 in enumerate(models): + row_cells = "" + for j, model2 in enumerate(models): + if i == j: + row_cells += '1.0000' + continue + key1 = f"{model1}_{model2}" + key2 = f"{model2}_{model1}" + similarity = similarities.get(key1, similarities.get(key2, 0)) + row_cells += f'{similarity:.4f}' + rows += f""" + + {model1} + {row_cells} + + """ + table = f""" +
+

📈 Model Similarity Matrix

+

Higher values indicate more similar outputs between models

+ + + + + {headers} + + + + {rows} + +
+
+ """ + return table + + +def create_summary_visualization( + model_metrics: Dict[str, Dict[str, float]], + time_comparison: Dict[str, Any], + similarities: Dict[str, float] = None, +) -> HTMLString: + """Create an HTML visualization of evaluation results for multiple models.""" + step_context = get_step_context() + pipeline_run_name = step_context.pipeline_run.name + models = list(model_metrics.keys()) + + model_cards = "" + cols_per_row = min(3, len(models)) + for model_display, metrics in model_metrics.items(): + prefix = None + for _, config in MODEL_CONFIGS.items(): + if config.display == model_display: + prefix = config.prefix + break + if not prefix: + prefix = model_display.lower().replace(" ", "_") + time_key = f"avg_{prefix}_time" + if time_key in time_comparison: + metrics["Proc. Time"] = time_comparison[time_key] + model_cards += create_model_card_with_logo(model_display, metrics) + + fastest_model = time_comparison["fastest_model"] + + best_cer = find_best_model(model_metrics, "CER", lower_is_better=True) + best_wer = find_best_model(model_metrics, "WER", lower_is_better=True) + best_similarity = find_best_model(model_metrics, "GT Similarity", lower_is_better=False) + + metrics_grid = f""" +
+ {model_cards} +
+

+ 🏆 Overall Best +

+
+
Fastest Model:
+
{fastest_model}
+
Best CER:
+
{best_cer}
+
Best WER:
+
{best_wer}
+
Best Similarity:
+
{best_similarity}
+
+
+
+ """ + + comparison_metrics = ["CER", "WER", "GT Similarity"] + comparison_table = create_comparison_table(models, model_metrics, comparison_metrics) + similarity_matrix = "" + if similarities and len(models) > 1: + similarity_matrix = create_model_similarity_matrix(models, similarities) + metrics_section = f""" +
+

📊 OCR Model Performance Metrics

+ {metrics_grid} + +
+

📈 Model Comparison

+ {comparison_table} + {similarity_matrix} +
+
+ """ + html = f""" + + + + + + OCR Model Evaluation Results + + + + +
+
+

🔍 OCR Model Evaluation Dashboard

+

Pipeline Run: {pipeline_run_name}

+
+ {metrics_section} +
+ + + """ + return HTMLString(html) + + +def create_ocr_batch_visualization(df: pl.DataFrame) -> HTMLString: + """Create an HTML visualization of batch OCR processing results.""" + # Extract metrics + total_results = len(df) + # Ensure all raw_text values are strings + raw_texts = [] + for txt in df["raw_text"].to_list(): + if isinstance(txt, list): + raw_texts.append("\n".join(txt)) + else: + raw_texts.append(str(txt)) + total_chars = sum(len(txt) for txt in raw_texts) + avg_conf = df["confidence"].mean() if "confidence" in df.columns else 0 + total_proc_time = df["processing_time"].sum() if "processing_time" in df.columns else 0 + avg_proc_time = df["processing_time"].mean() if "processing_time" in df.columns else 0 + + # Get model-specific metrics + model_metrics = {} + model_displays = [] + + if "model_name" in df.columns: + for model in df["model_name"].unique().to_list(): + mdf = df.filter(pl.col("model_name") == model) + # Ensure all model-specific raw_text values are strings + m_raw_texts = [] + for txt in mdf["raw_text"].to_list(): + if isinstance(txt, list): + m_raw_texts.append("\n".join(txt)) + else: + m_raw_texts.append(str(txt)) + m_chars = sum(len(txt) for txt in m_raw_texts) + m_conf = mdf["confidence"].mean() if "confidence" in mdf.columns else 0 + m_total_time = mdf["processing_time"].sum() if "processing_time" in mdf.columns else 0 + m_avg_time = mdf["processing_time"].mean() if "processing_time" in mdf.columns else 0 + + model_metrics[model] = { + "total_images": len(mdf), + "total_chars": m_chars, + "avg_confidence": m_conf, + "total_time": m_total_time, + "avg_time": m_avg_time, + "char_per_second": m_chars / m_total_time if m_total_time > 0 else 0, + } + model_displays.append(model) + + # Create model cards HTML + model_cards = "" + if model_displays: + cols_per_row = min(3, len(model_displays)) + for model in model_displays: + metrics = model_metrics[model] + + # Try to get the model logo if available + logo_path = None + logo_html = "" + try: + for _, config in MODEL_CONFIGS.items(): + if config.display == model: + logo_path = config.logo + break + if logo_path and os.path.exists(os.path.join("./assets/logos", logo_path)): + logo_b64 = load_svg_logo(logo_path) + logo_html = f'{model} logo' + except Exception as e: + logger.warning(f"Error loading logo for {model}: {e}") + # Default to just the model name without logo if there's an issue + pass + + model_cards += f""" +
+

+ {logo_html} {model} +

+
+
Images:
+
{metrics["total_images"]}
+
Characters:
+
{metrics["total_chars"]}
+
Avg Confidence:
+
{metrics["avg_confidence"]:.2f}
+
Total Time:
+
{metrics["total_time"]:.2f}s
+
Avg Time/Image:
+
{metrics["avg_time"]:.2f}s
+
Chars/Second:
+
{metrics["char_per_second"]:.1f}
+
+
+ """ + + model_grid = f""" +
+ {model_cards} +
+ """ + else: + # Single model view + model_grid = f""" +
+
+

+ OCR Processing Summary +

+
+
Total Images:
+
{total_results}
+
Total Characters:
+
{total_chars}
+
Avg Confidence:
+
{avg_conf:.2f}
+
Total Process Time:
+
{total_proc_time:.2f}s
+
Avg Time/Image:
+
{avg_proc_time:.2f}s
+
Chars/Second:
+
{total_chars / total_proc_time if total_proc_time > 0 else 0:.1f}
+
+
+
+ """ + + # Create results table with sample data + sample_size = min(10, total_results) # Show up to 10 samples + + # Create table HTML + table_rows = "" + sample_df = df.head(sample_size) + + for i in range(sample_df.height): + row = sample_df.row(i, named=True) + model_col = ( + f'{row["model_name"]}' if "model_name" in df.columns else "" + ) + + # Ensure raw_text is a string and limit displayed text length + raw_text = row["raw_text"] + if isinstance(raw_text, list): + raw_text = "\n".join(raw_text) + text_preview = str(raw_text)[:100] + ("..." if len(str(raw_text)) > 100 else "") + + # Calculate the length properly + text_length = len(str(raw_text)) if raw_text is not None else 0 + + table_rows += f""" + + {row["image_name"]} + {model_col} + {text_length} + {row.get("confidence", 0):.2f} + {row.get("processing_time", 0):.2f}s + {text_preview} + + """ + + model_header = 'Model' if "model_name" in df.columns else "" + + results_table = f""" +
+

Sample Results ({sample_size} of {total_results})

+ + + + + {model_header} + + + + + + + + {table_rows} + +
ImageCharsConfidenceTime (s)Text Preview
+
+ """ + + # Final HTML + html = f""" + + + + + + OCR Batch Processing Results + + + + +
+
+

📝 OCR Batch Processing Results

+

Processed {total_results} images with {total_chars} total characters in {total_proc_time:.2f}s

+
+ +
+

Processing Summary

+ {model_grid} + {results_table} +
+
+ + + """ + + return HTMLString(html)
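+
+
+# Illustrative usage sketch (commented out; not part of the pipeline). The column
+# names mirror the frames built by utils.ocr_processing.process_images_with_model:
+#
+#   import polars as pl
+#   df = pl.DataFrame(
+#       {
+#           "image_name": ["sample.png"],
+#           "raw_text": ["extracted text"],
+#           "confidence": [0.9],
+#           "processing_time": [1.2],
+#       }
+#   )
+#   html = create_ocr_batch_visualization(df)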