Skip to content

Commit 8e70923

Browse files
author
marwan37
committed
separate batch ocr from evaluation into 2 pipelines
1 parent 92e3fa0 commit 8e70923

File tree

4 files changed

+262
-183
lines changed

4 files changed

+262
-183
lines changed

omni-reader/pipelines/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
"""OCR comparison pipelines package."""
17+
"""OCR pipelines package."""
1818

19-
from pipelines.ocr_pipeline import ocr_comparison_pipeline
20-
21-
__all__ = ["ocr_comparison_pipeline"]
19+
from pipelines.batch_pipeline import ocr_batch_pipeline, run_ocr_batch_pipeline
20+
from pipelines.evaluation_pipeline import ocr_evaluation_pipeline, run_ocr_evaluation_pipeline
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Apache Software License 2.0
2+
#
3+
# Copyright (c) ZenML GmbH 2025. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
"""OCR Batch Pipeline implementation for processing images with multiple models."""
17+
18+
from typing import Any, Dict, List, Optional
19+
20+
from dotenv import load_dotenv
21+
from zenml import pipeline
22+
from zenml.logger import get_logger
23+
24+
from steps import (
25+
load_images,
26+
run_ocr,
27+
save_ocr_results,
28+
)
29+
30+
load_dotenv()
31+
32+
logger = get_logger(__name__)
33+
34+
35+
@pipeline
36+
def ocr_batch_pipeline(
37+
image_paths: Optional[List[str]] = None,
38+
image_folder: Optional[str] = None,
39+
custom_prompt: Optional[str] = None,
40+
models: List[str] = None,
41+
save_ocr_results_data: bool = False,
42+
ocr_results_output_dir: str = "ocr_results",
43+
) -> None:
44+
"""Run OCR batch processing pipeline with multiple models.
45+
46+
Args:
47+
image_paths: Optional list of specific image paths to process
48+
image_folder: Optional folder to search for images
49+
custom_prompt: Optional custom prompt to use for the models
50+
models: List of model names to use for OCR
51+
save_ocr_results_data: Whether to save OCR results
52+
ocr_results_output_dir: Directory to save OCR results
53+
54+
Returns:
55+
None
56+
"""
57+
if not models or len(models) == 0:
58+
raise ValueError("At least one model must be specified for the batch pipeline")
59+
60+
images = load_images(
61+
image_paths=image_paths,
62+
image_folder=image_folder,
63+
)
64+
model_results = run_ocr(
65+
images=images,
66+
models=models,
67+
custom_prompt=custom_prompt,
68+
)
69+
70+
if save_ocr_results_data:
71+
save_ocr_results(
72+
ocr_results=model_results,
73+
model_names=models,
74+
output_dir=ocr_results_output_dir,
75+
)
76+
77+
78+
def run_ocr_batch_pipeline(config: Dict[str, Any]) -> None:
79+
"""Run the OCR batch pipeline from a configuration dictionary.
80+
81+
Args:
82+
config: Dictionary containing configuration
83+
84+
Returns:
85+
None
86+
"""
87+
# Check pipeline mode
88+
mode = config.get("parameters", {}).get("mode", "batch")
89+
if mode != "batch":
90+
logger.warning(f"Expected mode 'batch', but got '{mode}'. Proceeding anyway.")
91+
92+
# Get selected models from config
93+
selected_models = config.get("parameters", {}).get("selected_models", [])
94+
if not selected_models:
95+
raise ValueError(
96+
"No models selected in configuration. Add 'selected_models' to parameters section."
97+
)
98+
99+
# Create pipeline instance
100+
pipeline_instance = ocr_batch_pipeline.with_options(
101+
enable_cache=config.get("enable_cache", False),
102+
)
103+
104+
# Get params from config
105+
pipeline_params = config.get("parameters", {})
106+
pipeline_steps = config.get("steps", {})
107+
save_ocr_results_params = pipeline_steps.get("save_ocr_results", {}).get("parameters", {})
108+
109+
# Run the pipeline
110+
pipeline_instance(
111+
image_paths=pipeline_params.get("input_image_paths", []),
112+
image_folder=pipeline_params.get("input_image_folder"),
113+
custom_prompt=pipeline_steps.get("run_ocr", {}).get("parameters", {}).get("custom_prompt"),
114+
models=selected_models,
115+
save_ocr_results_data=save_ocr_results_params.get("save_locally", False),
116+
ocr_results_output_dir=save_ocr_results_params.get("output_dir", "ocr_results"),
117+
)
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Apache Software License 2.0
2+
#
3+
# Copyright (c) ZenML GmbH 2025. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
"""OCR Evaluation Pipeline implementation for comparing models using existing results."""
17+
18+
from typing import Any, Dict, List, Optional
19+
20+
import polars as pl
21+
from dotenv import load_dotenv
22+
from zenml import pipeline, step
23+
from zenml.logger import get_logger
24+
25+
from steps import (
26+
evaluate_models,
27+
load_ground_truth_texts,
28+
load_ocr_results,
29+
save_visualization,
30+
)
31+
32+
load_dotenv()
33+
34+
logger = get_logger(__name__)
35+
36+
37+
@pipeline
38+
def ocr_evaluation_pipeline(
39+
model_names: List[str] = None,
40+
results_dir: str = "ocr_results",
41+
result_files: Optional[List[str]] = None,
42+
ground_truth_folder: Optional[str] = None,
43+
ground_truth_files: Optional[List[str]] = None,
44+
save_visualization_data: bool = False,
45+
visualization_output_dir: str = "visualizations",
46+
) -> None:
47+
"""Run OCR evaluation pipeline comparing existing model results."""
48+
if not model_names or len(model_names) < 2:
49+
raise ValueError("At least two models are required for comparison")
50+
51+
if not ground_truth_folder and not ground_truth_files:
52+
raise ValueError(
53+
"Either ground_truth_folder or ground_truth_files must be provided for evaluation"
54+
)
55+
56+
model_results = load_ocr_results(
57+
model_names=model_names,
58+
results_dir=results_dir,
59+
result_files=result_files,
60+
)
61+
62+
ground_truth_df = load_ground_truth_texts(
63+
model_results=model_results,
64+
ground_truth_folder=ground_truth_folder,
65+
ground_truth_files=ground_truth_files,
66+
)
67+
68+
visualization = evaluate_models(
69+
model_results=model_results,
70+
ground_truth_df=ground_truth_df,
71+
)
72+
73+
if save_visualization_data:
74+
save_visualization(
75+
visualization,
76+
output_dir=visualization_output_dir,
77+
)
78+
79+
80+
def run_ocr_evaluation_pipeline(config: Dict[str, Any]) -> None:
81+
"""Run the OCR evaluation pipeline from a configuration dictionary.
82+
83+
Args:
84+
config: Dictionary containing configuration
85+
86+
Returns:
87+
None
88+
"""
89+
mode = config.get("parameters", {}).get("mode", "evaluation")
90+
if mode != "evaluation":
91+
logger.warning(f"Expected mode 'evaluation', but got '{mode}'. Proceeding anyway.")
92+
93+
selected_models = config.get("parameters", {}).get("selected_models", [])
94+
if len(selected_models) < 2:
95+
raise ValueError("At least two models are required for evaluation")
96+
97+
model_registry = config.get("models_registry", [])
98+
if not model_registry:
99+
raise ValueError("models_registry is required in the config")
100+
101+
# Get model names from registry by using the passed models (may be shorthands or full names)
102+
model_names = []
103+
shorthand_to_name = {
104+
m.get("shorthand"): m.get("name") for m in model_registry if "shorthand" in m
105+
}
106+
107+
for model_id in selected_models:
108+
if model_id in shorthand_to_name:
109+
model_names.append(shorthand_to_name[model_id])
110+
else:
111+
if any(m.get("name") == model_id for m in model_registry):
112+
model_names.append(model_id)
113+
else:
114+
logger.warning(f"Model '{model_id}' not found in registry, using as-is")
115+
model_names.append(model_id)
116+
117+
if len(selected_models) < 2:
118+
raise ValueError("At least two models are required for evaluation")
119+
120+
# Set up pipeline options
121+
pipeline_instance = ocr_evaluation_pipeline.with_options(
122+
enable_cache=config.get("enable_cache", False),
123+
enable_artifact_metadata=config.get("enable_artifact_metadata", True),
124+
enable_artifact_visualization=config.get("enable_artifact_visualization", True),
125+
)
126+
127+
evaluate_models_params = (
128+
config.get("steps", {}).get("evaluate_models", {}).get("parameters", {})
129+
)
130+
save_visualization_params = (
131+
config.get("steps", {}).get("save_visualization", {}).get("parameters", {})
132+
)
133+
134+
pipeline_instance(
135+
model_names=model_names,
136+
results_dir=evaluate_models_params.get("results_dir", "ocr_results"),
137+
result_files=evaluate_models_params.get("result_files"),
138+
ground_truth_folder=evaluate_models_params.get("ground_truth_folder"),
139+
ground_truth_files=evaluate_models_params.get("ground_truth_files", []),
140+
save_visualization_data=save_visualization_params.get("save_locally", False),
141+
visualization_output_dir=save_visualization_params.get("output_dir", "visualizations"),
142+
)

0 commit comments

Comments
 (0)