# limitations under the License.
"""OCR Comparison Pipeline implementation with YAML configuration support."""

+ import os
from typing import Any, Dict, List, Literal, Optional

- from zenml import pipeline
+ from dotenv import load_dotenv
+ from zenml import pipeline, step
from zenml.config import DockerSettings
from zenml.logger import get_logger

@@ ... @@
    save_visualization,
)

+ load_dotenv()
+
docker_settings = DockerSettings(
    dockerfile="Dockerfile",
    requirements=[
@@ ... @@
        "ollama==0.4.7",
        "pyarrow>=7.0",
    ],
+     environment={
+         "OLLAMA_HOST": "${OLLAMA_HOST:-http://localhost:11434}",
+         "OLLAMA_MODELS": "/root/.ollama",
+         "OLLAMA_TIMEOUT": "600s",
+         "MISTRAL_API_KEY": "${MISTRAL_API_KEY}",
+         "OPENAI_API_KEY": "${OPENAI_API_KEY}",
+     },
)

logger = get_logger(__name__)


+ @step
+ def extract_ground_truth_df(ground_truth_results, ground_truth_model):
+     """Extract the ground truth DataFrame from the results dictionary.
+
+     Args:
+         ground_truth_results: Dictionary of per-model results returned by run_ocr
+         ground_truth_model: Name of the ground truth model
+
+     Returns:
+         The ground truth DataFrame, or None if the model is absent from the results
+     """
+     if ground_truth_model in ground_truth_results:
+         return ground_truth_results[ground_truth_model]
+     return None
+
+
@pipeline(settings={"docker": docker_settings})
def ocr_comparison_pipeline(
    image_paths: Optional[List[str]] = None,
    image_folder: Optional[str] = None,
    custom_prompt: Optional[str] = None,
-     model1: str = "ollama/gemma3:27b",
-     model2: str = "pixtral-12b-2409",
+     models: Optional[List[str]] = None,
    ground_truth_model: str = "gpt-4o-mini",
    ground_truth_source: Literal["openai", "manual", "file", "none"] = "none",
    ground_truth_file: Optional[str] = None,
@@ -66,14 +92,13 @@ def ocr_comparison_pipeline(
    save_visualization_data: bool = False,
    visualization_output_dir: str = "visualizations",
) -> None:
-     """Run OCR comparison pipeline between two configurable models.
+     """Run an OCR comparison pipeline across multiple configurable models.

    Args:
        image_paths: Optional list of specific image paths to process
        image_folder: Optional folder to search for images
        custom_prompt: Optional custom prompt to use for both models
-         model1: Name of the first model to use (default: ollama/gemma3:27b)
-         model2: Name of the second model to use (default: pixtral-12b-2409)
+         models: List of model names to use (if None or fewer than two, a default pair is used)
        ground_truth_model: Name of the model to use for ground truth when source is "openai"
        ground_truth_source: Source of ground truth - "openai" to use configured model, "manual" for user-provided texts,
            "file" to load from a saved JSON file, or "none" to skip ground truth evaluation
@@ -89,43 +114,46 @@ def ocr_comparison_pipeline(
        None
    """
    images = load_images(image_paths=image_paths, image_folder=image_folder)
-     model_names = []
-
-     model1_results = run_ocr(images=images, model_name=model1, custom_prompt=custom_prompt)
-     model_names.append(model1)

-     model2_results = run_ocr(images=images, model_name=model2, custom_prompt=custom_prompt)
-     model_names.append(model2)
+     # Fall back to the default pair when fewer than two models are provided
+     if not models or len(models) < 2:
+         models = ["llama3.2-vision:11b", "pixtral-12b-2409"]

-     # Handle ground truth based on the selected source
-     ground_truth = None
-     openai_results = None
+     # Process all models in parallel
+     model_results = run_ocr(images=images, model_names=models, custom_prompt=custom_prompt)

+     # Process ground truth separately to avoid including it in the main comparison
+     ground_truth_df = None
    if ground_truth_source == "openai":
-         openai_results = run_ocr(
-             images=images, model_name=ground_truth_model, custom_prompt=custom_prompt
+         # Run OCR on the ground truth model
+         ground_truth_results = run_ocr(
+             images=images, model_names=[ground_truth_model], custom_prompt=custom_prompt
+         )
+
+         # Extract the ground truth DataFrame from the results dictionary
+         ground_truth_df = extract_ground_truth_df(
+             ground_truth_results=ground_truth_results, ground_truth_model=ground_truth_model
        )
-         ground_truth = openai_results
-         model_names.append(ground_truth_model)
+
+         models.append(ground_truth_model)
    elif ground_truth_source == "file" and ground_truth_file:
-         ground_truth = load_ground_truth_file(filepath=ground_truth_file)
+         ground_truth_df = load_ground_truth_file(filepath=ground_truth_file)
+
+     # Select the first two models as primary for visualization (maintaining backward compatibility)
+     primary_models = models[:2]

-     # Evaluate models
    visualization = evaluate_models(
-         model1_df=model1_results,
-         model2_df=model2_results,
-         ground_truth_df=ground_truth,
-         model1_name=model1,
-         model2_name=model2,
+         model_results=model_results,
+         ground_truth_df=ground_truth_df,
+         primary_models=primary_models,
    )

    # Save OCR results if requested
    if save_ocr_results_data or save_ground_truth_data:
        save_ocr_results(
-             model1_results=model1_results,
-             model2_results=model2_results,
-             ground_truth_results=ground_truth,
-             model_names=model_names,
+             ocr_results=model_results,
+             ground_truth_results=ground_truth_df,
+             model_names=models,
            output_dir=ocr_results_output_dir,
            ground_truth_output_dir=ground_truth_output_dir,
            save_ground_truth=save_ground_truth_data,
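For orientation, here is a minimal sketch of calling the refactored pipeline directly with more than two models. The import path `pipelines.ocr_comparison` and the image folder are placeholders, not names from this repo:

```python
# Hypothetical import path -- adjust to wherever this module lives.
from pipelines.ocr_comparison import ocr_comparison_pipeline

# Compare three models in one run. The first two entries become the
# "primary" pair used for visualization, per the backward-compatibility
# note in the pipeline body above.
ocr_comparison_pipeline(
    image_folder="data/invoices",  # placeholder input folder
    models=[
        "llama3.2-vision:11b",
        "pixtral-12b-2409",
        "ollama/gemma3:27b",
    ],
    ground_truth_source="openai",
    ground_truth_model="gpt-4o-mini",
)
```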
@@ -145,12 +173,18 @@ def run_ocr_pipeline(config: Dict[str, Any]) -> None:
    Returns:
        None
    """
+     models = config["models"].get("models")
+     if not models:
+         models = [
+             config["models"].get("model1", "llama3.2-vision:11b"),
+             config["models"].get("model2", "pixtral-12b-2409"),
+         ]
+
    ocr_comparison_pipeline(
        image_paths=config["input"].get("image_paths"),
        image_folder=config["input"].get("image_folder"),
        custom_prompt=config["models"].get("custom_prompt"),
-         model1=config["models"].get("model1", "ollama/gemma3:27b"),
-         model2=config["models"].get("model2", "pixtral-12b-2409"),
+         models=models,
        ground_truth_model=config["models"].get("ground_truth_model", "gpt-4o-mini"),
        ground_truth_source=config["ground_truth"].get("source", "none"),
        ground_truth_file=config["ground_truth"].get("file"),
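Since `run_ocr_pipeline` indexes the `input`, `models`, and `ground_truth` sections directly, all three must be present in the config. Below is a minimal sketch of a matching config, assuming the YAML is parsed with PyYAML into the plain dict this function expects; the paths and model choices are placeholders:

```python
import yaml

# Minimal config covering the keys read above. The nested "models" list is
# optional; "model1"/"model2" remain as the backward-compatible fallback.
CONFIG_YAML = """
input:
  image_folder: data/samples   # placeholder path
models:
  models:
    - llama3.2-vision:11b
    - pixtral-12b-2409
  ground_truth_model: gpt-4o-mini
ground_truth:
  source: none                 # one of: openai, manual, file, none
"""

run_ocr_pipeline(config=yaml.safe_load(CONFIG_YAML))
```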