
Commit 991e89f

marwan37 committed: update pipeline for multi-model ocr and update docker settings

1 parent 4d96ae3

File tree

1 file changed (+66, -32 lines)


omni-reader/pipelines/ocr_pipeline.py

Lines changed: 66 additions & 32 deletions
@@ -15,9 +15,11 @@
 # limitations under the License.
 """OCR Comparison Pipeline implementation with YAML configuration support."""
 
+import os
 from typing import Any, Dict, List, Literal, Optional
 
-from zenml import pipeline
+from dotenv import load_dotenv
+from zenml import pipeline, step
 from zenml.config import DockerSettings
 from zenml.logger import get_logger
 
@@ -30,6 +32,8 @@
     save_visualization,
 )
 
+load_dotenv()
+
 docker_settings = DockerSettings(
     dockerfile="Dockerfile",
     requirements=[
@@ -44,18 +48,40 @@
         "ollama==0.4.7",
         "pyarrow>=7.0",
     ],
+    environment={
+        "OLLAMA_HOST": "${OLLAMA_HOST:-http://localhost:11434}",
+        "OLLAMA_MODELS": "/root/.ollama",
+        "OLLAMA_TIMEOUT": "600s",
+        "MISTRAL_API_KEY": "${MISTRAL_API_KEY}",
+        "OPENAI_API_KEY": "${OPENAI_API_KEY}",
+    },
 )
 
 logger = get_logger(__name__)
 
 
+@step
+def extract_ground_truth_df(ground_truth_results, ground_truth_model):
+    """Extract ground truth DataFrame from the results dictionary.
+
+    Args:
+        ground_truth_results: Dictionary with model results returned by run_ocr
+        ground_truth_model: Name of the ground truth model
+
+    Returns:
+        The ground truth DataFrame
+    """
+    if ground_truth_model in ground_truth_results:
+        return ground_truth_results[ground_truth_model]
+    return None
+
+
 @pipeline(settings={"docker": docker_settings})
 def ocr_comparison_pipeline(
     image_paths: Optional[List[str]] = None,
     image_folder: Optional[str] = None,
     custom_prompt: Optional[str] = None,
-    model1: str = "ollama/gemma3:27b",
-    model2: str = "pixtral-12b-2409",
+    models: Optional[List[str]] = None,
     ground_truth_model: str = "gpt-4o-mini",
     ground_truth_source: Literal["openai", "manual", "file", "none"] = "none",
     ground_truth_file: Optional[str] = None,
@@ -66,14 +92,13 @@ def ocr_comparison_pipeline(
     save_visualization_data: bool = False,
     visualization_output_dir: str = "visualizations",
 ) -> None:
-    """Run OCR comparison pipeline between two configurable models.
+    """Run OCR comparison pipeline between multiple configurable models.
 
     Args:
         image_paths: Optional list of specific image paths to process
         image_folder: Optional folder to search for images
         custom_prompt: Optional custom prompt to use for both models
-        model1: Name of the first model to use (default: ollama/gemma3:27b)
-        model2: Name of the second model to use (default: pixtral-12b-2409)
+        models: List of model names to use (if None, uses default models)
         ground_truth_model: Name of the model to use for ground truth when source is "openai"
         ground_truth_source: Source of ground truth - "openai" to use configured model, "manual" for user-provided texts,
             "file" to load from a saved JSON file, or "none" to skip ground truth evaluation
@@ -89,43 +114,46 @@ def ocr_comparison_pipeline(
         None
     """
     images = load_images(image_paths=image_paths, image_folder=image_folder)
-    model_names = []
-
-    model1_results = run_ocr(images=images, model_name=model1, custom_prompt=custom_prompt)
-    model_names.append(model1)
 
-    model2_results = run_ocr(images=images, model_name=model2, custom_prompt=custom_prompt)
-    model_names.append(model2)
+    # Default to two models if none provided
+    if not models or len(models) < 2:
+        models = ["llama3.2-vision:11b", "pixtral-12b-2409"]
 
-    # Handle ground truth based on the selected source
-    ground_truth = None
-    openai_results = None
+    # Process all models in parallel
+    model_results = run_ocr(images=images, model_names=models, custom_prompt=custom_prompt)
 
+    # Process ground truth separately to avoid including it in the main comparison
+    ground_truth_df = None
     if ground_truth_source == "openai":
-        openai_results = run_ocr(
-            images=images, model_name=ground_truth_model, custom_prompt=custom_prompt
+        # Run OCR on the ground truth model
+        ground_truth_results = run_ocr(
+            images=images, model_names=[ground_truth_model], custom_prompt=custom_prompt
+        )
+
+        # Extract the ground truth DataFrame from the results dictionary
+        ground_truth_df = extract_ground_truth_df(
+            ground_truth_results=ground_truth_results, ground_truth_model=ground_truth_model
        )
-        ground_truth = openai_results
-        model_names.append(ground_truth_model)
+
+        models.append(ground_truth_model)
     elif ground_truth_source == "file" and ground_truth_file:
-        ground_truth = load_ground_truth_file(filepath=ground_truth_file)
+        ground_truth_df = load_ground_truth_file(filepath=ground_truth_file)
+
+    # Select the first two models as primary for visualization (maintaining backward compatibility)
+    primary_models = models[:2]
 
-    # Evaluate models
     visualization = evaluate_models(
-        model1_df=model1_results,
-        model2_df=model2_results,
-        ground_truth_df=ground_truth,
-        model1_name=model1,
-        model2_name=model2,
+        model_results=model_results,
+        ground_truth_df=ground_truth_df,
+        primary_models=primary_models,
    )
 
     # Save OCR results if requested
     if save_ocr_results_data or save_ground_truth_data:
         save_ocr_results(
-            model1_results=model1_results,
-            model2_results=model2_results,
-            ground_truth_results=ground_truth,
-            model_names=model_names,
+            ocr_results=model_results,
+            ground_truth_results=ground_truth_df,
+            model_names=models,
             output_dir=ocr_results_output_dir,
             ground_truth_output_dir=ground_truth_output_dir,
             save_ground_truth=save_ground_truth_data,
@@ -145,12 +173,18 @@ def run_ocr_pipeline(config: Dict[str, Any]) -> None:
     Returns:
         None
     """
+    models = config["models"].get("models")
+    if not models:
+        models = [
+            config["models"].get("model1", "llama3.2-vision:11b"),
+            config["models"].get("model2", "pixtral-12b-2409"),
+        ]
+
     ocr_comparison_pipeline(
         image_paths=config["input"].get("image_paths"),
         image_folder=config["input"].get("image_folder"),
         custom_prompt=config["models"].get("custom_prompt"),
-        model1=config["models"].get("model1", "ollama/gemma3:27b"),
-        model2=config["models"].get("model2", "pixtral-12b-2409"),
+        models=models,
         ground_truth_model=config["models"].get("ground_truth_model", "gpt-4o-mini"),
         ground_truth_source=config["ground_truth"].get("source", "none"),
         ground_truth_file=config["ground_truth"].get("file"),
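The new environment block wires the Ollama settings and the Mistral/OpenAI API keys into the container using shell-style ${VAR:-default} placeholders, and the module now calls load_dotenv() at import time. Assuming python-dotenv's default behavior of reading a .env file from the working directory, the keys can be supplied locally before the pipeline runs; a minimal sketch (the values shown are placeholders, not part of the commit):

import os

from dotenv import load_dotenv

# Reads a local .env file, e.g.:
#   OLLAMA_HOST=http://localhost:11434
#   MISTRAL_API_KEY=<your key>
#   OPENAI_API_KEY=<your key>
load_dotenv()

# Warn early if a key the containerized steps rely on is missing.
for key in ("OLLAMA_HOST", "MISTRAL_API_KEY", "OPENAI_API_KEY"):
    if not os.getenv(key):
        print(f"warning: {key} is not set; the corresponding model may fail")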
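With the signature change, any number of models can be compared in one run: run_ocr fans out over the models list and returns a dictionary keyed by model name, the first two entries become primary_models for visualization, and run_ocr_pipeline prefers a models list in the config while falling back to the legacy model1/model2 keys. A hypothetical direct invocation under the new signature (parameter names come from the diff; the folder path and model selection are illustrative):

ocr_comparison_pipeline(
    image_folder="data/samples",  # illustrative path, not from the commit
    models=[
        "llama3.2-vision:11b",
        "pixtral-12b-2409",
        "ollama/gemma3:27b",
    ],
    ground_truth_source="none",
)

When ground_truth_source is "openai", the ground-truth model runs separately and is appended to models afterwards, so it is included when results are saved but excluded from the main comparison.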
