
Commit 58ac407

Author: marwan37 (committed)

update steps for multi-model ocr changes

1 parent 73c3aba · commit 58ac407

File tree

4 files changed: +197 -123 lines changed


omni-reader/steps/evaluate_models.py

Lines changed: 88 additions & 46 deletions

@@ -235,8 +235,8 @@ def create_summary_visualization(
             <div class="text-right font-medium">{avg_metrics["avg_models_similarity"]:.4f}</div>
             <div class="text-gray-600">Time Diff:</div>
             <div class="text-right font-medium">{time_comparison["time_difference"]:.2f}s</div>
-            <div class="text-gray-600">Faster Model:</div>
-            <div class="text-right font-medium">{time_comparison["faster_model"]}</div>
+            <div class="text-gray-600">Fastest Model:</div>
+            <div class="text-right font-medium">{time_comparison["fastest_model"]}</div>
             <div class="text-gray-600">Better CER:</div>
             <div class="text-right font-medium">
                 {model1_display if avg_metrics[f"avg_{model1_prefix}_cer"] < avg_metrics[f"avg_{model2_prefix}_cer"] else model2_display}
@@ -267,8 +267,8 @@ def create_summary_visualization(
                 <div class="text-xl font-bold">{time_comparison["time_difference"]:.2f}s</div>
             </div>
             <div>
-                <div class="text-gray-600 mb-1">Faster Model</div>
-                <div class="text-xl font-bold">{time_comparison["faster_model"]}</div>
+                <div class="text-gray-600 mb-1">Fastest Model</div>
+                <div class="text-xl font-bold">{time_comparison["fastest_model"]}</div>
             </div>
         </div>
     </div>
@@ -325,34 +325,62 @@ def create_summary_visualization(
 
 @step()
 def evaluate_models(
-    model1_df: pl.DataFrame,
-    model2_df: pl.DataFrame,
+    model_results: Dict[str, pl.DataFrame],
     ground_truth_df: Optional[pl.DataFrame] = None,
-    model1_name: str = "ollama/gemma3:27b",
-    model2_name: str = "pixtral-12b-2409",
+    primary_models: Optional[List[str]] = None,
 ) -> Annotated[HTMLString, "ocr_visualization"]:
-    """Compare the performance of two configurable models with visualization.
+    """Compare the performance of multiple configurable models with visualization.
 
     Args:
-        model1_df: First model results DataFrame
-        model2_df: Second model results DataFrame
+        model_results: Dictionary mapping model names to their results DataFrames
         ground_truth_df: Optional ground truth results DataFrame
-        model1_name: Name of the first model (default: ollama/gemma3:27b)
-        model2_name: Name of the second model (default: pixtral-12b-2409)
-        model1_display: Display name for the first model (default: Gemma)
-        model2_display: Display name for the second model (default: Mistral)
+        primary_models: Optional list of model names to highlight in comparison.
+            If None or less than 2 models, uses the first two models from model_results.
 
     Returns:
         HTMLString visualization of the results
     """
+    # Ensure we have at least two models for comparison
+    if len(model_results) < 2:
+        raise ValueError("At least two models are required for comparison")
+
+    # If primary_models not specified or invalid, use the first two models
+    if not primary_models or len(primary_models) < 2:
+        primary_models = list(model_results.keys())[:2]
+
+    # Extract the primary models for main comparison
+    model1_name = primary_models[0]
+    model2_name = primary_models[1]
+
+    model1_df = model_results[model1_name]
+    model2_df = model_results[model2_name]
+
     model1_display, model1_prefix = get_model_info(model1_name)
     model2_display, model2_prefix = get_model_info(model2_name)
 
     # Join results
-    results = model1_df.join(model2_df, on=["id", "image_name"], how="inner")
+    results = model1_df.join(model2_df, on=["id", "image_name"], how="inner", suffix="_right")
     evaluation_metrics = []
     processed_results = []
 
+    # Calculate processing times for all models
+    all_model_times = {}
+    for model_name, df in model_results.items():
+        display, prefix = get_model_info(model_name)
+        time_key = f"avg_{prefix}_time"
+        all_model_times[time_key] = df.select("processing_time").to_series().mean()
+        all_model_times[f"{prefix}_display"] = display
+
+    # Find fastest model
+    fastest_model_time = min(
+        [(time, model) for model, time in all_model_times.items() if not model.endswith("_display")]
+    )
+    fastest_model_key = fastest_model_time[1]
+    fastest_model_prefix = fastest_model_key.replace("avg_", "").replace("_time", "")
+    fastest_model_display = all_model_times.get(
        f"{fastest_model_prefix}_display", fastest_model_prefix
    )
+
     if ground_truth_df is not None:
         results = results.join(
             ground_truth_df,
@@ -412,38 +440,49 @@ def evaluate_models(
             ].mean(),
         }
 
-        model1_times = model1_df.select("processing_time").to_series().mean()
-        model2_times = model2_df.select("processing_time").to_series().mean()
         model1_time_key = f"avg_{model1_prefix}_time"
         model2_time_key = f"avg_{model2_prefix}_time"
+
+        # Combine processing times with other metrics
         time_comparison = {
-            model1_time_key: model1_times,
-            model2_time_key: model2_times,
-            "time_difference": abs(model1_times - model2_times),
-            "faster_model": model1_display if model1_times < model2_times else model2_display,
+            **all_model_times,
+            "time_difference": abs(
+                all_model_times[model1_time_key] - all_model_times[model2_time_key]
+            ),
+            "fastest_model": fastest_model_display,
         }
 
-        # Log metadata for ZenML dashboard
-        log_metadata(
-            metadata={
+        # Prepare metadata for ZenML dashboard
+        metadata_dict = {
+            **{
+                f"avg_{model}_time": float(time)
+                for model, time in all_model_times.items()
+                if not model.endswith("_display")
+            },
+            "fastest_model": fastest_model_display,
+            "model_count": len(model_results),
+            "avg_models_similarity": float(avg_metrics["avg_models_similarity"]),
+        }
+
+        # Add accuracy metrics for primary models
+        metadata_dict.update(
+            {
                 f"avg_{model1_prefix}_cer": float(avg_metrics[f"avg_{model1_prefix}_cer"]),
                 f"avg_{model1_prefix}_wer": float(avg_metrics[f"avg_{model1_prefix}_wer"]),
                 f"avg_{model2_prefix}_cer": float(avg_metrics[f"avg_{model2_prefix}_cer"]),
                 f"avg_{model2_prefix}_wer": float(avg_metrics[f"avg_{model2_prefix}_wer"]),
-                "avg_models_similarity": float(avg_metrics["avg_models_similarity"]),
                 f"avg_{model1_prefix}_gt_similarity": float(
                     avg_metrics[f"avg_{model1_prefix}_gt_similarity"]
                 ),
                 f"avg_{model2_prefix}_gt_similarity": float(
                     avg_metrics[f"avg_{model2_prefix}_gt_similarity"]
                 ),
-                model1_time_key: float(time_comparison[model1_time_key]),
-                model2_time_key: float(time_comparison[model2_time_key]),
-                "time_difference": float(time_comparison["time_difference"]),
-                "faster_model": time_comparison["faster_model"],
             }
         )
 
+        # Log metadata for ZenML dashboard
+        log_metadata(metadata=metadata_dict)
+
         html_visualization = create_summary_visualization(
             avg_metrics=avg_metrics,
             time_comparison=time_comparison,
@@ -456,30 +495,33 @@ def evaluate_models(
         return html_visualization
 
     # FALLBACK: if no ground truth metrics, only use processing times.
-    model1_times = model1_df.select("processing_time").to_series().mean()
-    model2_times = model2_df.select("processing_time").to_series().mean()
-    model1_time_key = f"avg_{model1_prefix}_time"
-    model2_time_key = f"avg_{model2_prefix}_time"
     time_comparison = {
-        model1_time_key: model1_times,
-        model2_time_key: model2_times,
-        "time_difference": abs(model1_times - model2_times),
-        "faster_model": model1_display if model1_times < model2_times else model2_display,
+        **all_model_times,
+        "time_difference": abs(
+            all_model_times[f"avg_{model1_prefix}_time"]
+            - all_model_times[f"avg_{model2_prefix}_time"]
+        ),
+        "fastest_model": fastest_model_display,
     }
+
     html_visualization = create_summary_visualization(
         avg_metrics={},
         time_comparison=time_comparison,
         model1_name=model1_name,
        model2_name=model2_name,
    )
 
-    log_metadata(
-        metadata={
-            model1_time_key: float(time_comparison[model1_time_key]),
-            model2_time_key: float(time_comparison[model2_time_key]),
-            "time_difference": float(time_comparison["time_difference"]),
-            "faster_model": time_comparison["faster_model"],
-        }
-    )
+    # Prepare metadata for ZenML dashboard
+    metadata_dict = {
+        **{
+            f"avg_{model}_time": float(time)
+            for model, time in all_model_times.items()
+            if not model.endswith("_display")
+        },
+        "fastest_model": fastest_model_display,
+        "model_count": len(model_results),
+    }
+
+    log_metadata(metadata=metadata_dict)
 
     return html_visualization
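
For reference, the fastest-model lookup added above works on a flat all_model_times dict that mixes avg_<prefix>_time floats with <prefix>_display strings. Below is a minimal standalone sketch of that pattern; the get_model_info stand-in and the timing values are made up for illustration and are not the pipeline's real helper or data.

import polars as pl

# Illustrative per-model results; only "processing_time" matters for this sketch.
model_results = {
    "ollama/gemma3:27b": pl.DataFrame({"processing_time": [1.2, 1.4, 1.1]}),
    "pixtral-12b-2409": pl.DataFrame({"processing_time": [0.9, 1.0, 0.8]}),
}

def get_model_info(name: str):
    # Hypothetical stand-in for the repo's get_model_info helper.
    display = name.split("/")[-1]
    prefix = display.replace(":", "_").replace("-", "_").replace(".", "_")
    return display, prefix

all_model_times = {}
for model_name, df in model_results.items():
    display, prefix = get_model_info(model_name)
    all_model_times[f"avg_{prefix}_time"] = df.select("processing_time").to_series().mean()
    all_model_times[f"{prefix}_display"] = display

# min() over (time, key) tuples: the smallest average time wins;
# "_display" entries are filtered out because their values are strings.
fastest_time, fastest_key = min(
    (time, key) for key, time in all_model_times.items() if not key.endswith("_display")
)
fastest_prefix = fastest_key.replace("avg_", "").replace("_time", "")
print(all_model_times.get(f"{fastest_prefix}_display", fastest_prefix), fastest_time)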

omni-reader/steps/load_files.py

Lines changed: 3 additions & 3 deletions

@@ -97,14 +97,14 @@ def load_images(
 @step
 def load_ground_truth_file(
     filepath: str,
-) -> Annotated[Dict[str, pl.DataFrame], "ground_truth"]:
+) -> Annotated[pl.DataFrame, "ground_truth"]:
     """Load ground truth data from a JSON file.
 
     Args:
         filepath: Path to the ground truth file
 
     Returns:
-        Dictionary containing ground truth results
+        pl.DataFrame containing ground truth results
     """
     from utils.io_utils import load_ocr_data_from_json
 
@@ -115,4 +115,4 @@ def load_ground_truth_file(
 
     log_metadata(metadata={"ground_truth_loaded": {"path": filepath, "count": len(df)}})
 
-    return {"ground_truth_results": df}
+    return df
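
The practical effect of this change is on the caller side: the step artifact is now the DataFrame itself rather than a one-key dict. A rough sketch under the assumption of a simple ground-truth layout (the column names below are hypothetical, not taken from the repo):

import polars as pl

# Hypothetical ground-truth frame; the real columns come from load_ocr_data_from_json.
gt = pl.DataFrame({"image_name": ["a.png", "b.png"], "text": ["hello", "world"]})

# Before this commit, callers had to unpack a wrapper dict:
#     ground_truth_df = load_ground_truth_file(path)["ground_truth_results"]
# After it, the step output is the DataFrame directly:
#     ground_truth_df = load_ground_truth_file(path)
print(len(gt))  # the step's metadata logs this count under "ground_truth_loaded"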

omni-reader/steps/run_ocr.py

Lines changed: 58 additions & 25 deletions

@@ -15,53 +15,86 @@
 # limitations under the License.
 """This module contains a unified OCR step that works with multiple models."""
 
-from typing import List, Optional
+import os
+from typing import Dict, List, Optional
 
 import polars as pl
 from typing_extensions import Annotated
 from zenml import step
 from zenml.logger import get_logger
 
-from utils import (
-    MODEL_CONFIGS,
-    process_images_with_model,
-)
+from utils.model_configs import MODEL_CONFIGS
+from utils.ocr_processing import process_images_with_model
 
 logger = get_logger(__name__)
 
 
-@step(enable_cache=False)
+@step()
 def run_ocr(
-    images: List[str], model_name: str, custom_prompt: Optional[str] = None
-) -> Annotated[pl.DataFrame, "ocr_results"]:
-    """Extract text from images using the specified model.
+    images: List[str], model_names: List[str], custom_prompt: Optional[str] = None
+) -> Annotated[Dict[str, pl.DataFrame], "ocr_results"]:
+    """Extract text from images using multiple models in parallel.
 
     Args:
         images: List of paths to image files
-        model_name: Name of the model to use (e.g., "gpt-4o-mini", "ollama/gemma3:27b", "pixtral-12b-2409")
+        model_names: List of model names to use
         custom_prompt: Optional custom prompt to override the default prompt
 
     Returns:
-        Dict: Containing results dataframe with OCR results
+        Dict: Mapping of model name to results dataframe with OCR results
 
     Raises:
-        ValueError: If the model_name is not supported
+        ValueError: If any model_name is not supported
     """
-    if model_name not in MODEL_CONFIGS:
-        supported_models = ", ".join(MODEL_CONFIGS.keys())
-        raise ValueError(
-            f"Unsupported model: {model_name}. Supported models are: {supported_models}"
-        )
+    from concurrent.futures import ThreadPoolExecutor
 
-    model_config = MODEL_CONFIGS[model_name]
+    from tqdm import tqdm
 
-    logger.info(f"Running OCR with model: {model_name}")
+    # Validate all models
+    for model_name in model_names:
+        if model_name not in MODEL_CONFIGS:
+            supported_models = ", ".join(MODEL_CONFIGS.keys())
+            raise ValueError(
+                f"Unsupported model: {model_name}. Supported models are: {supported_models}"
+            )
+
+    logger.info(f"Running OCR with {len(model_names)} models: {', '.join(model_names)}")
     logger.info(f"Processing {len(images)} images")
 
-    results_df = process_images_with_model(
-        model_config=model_config,
-        images=images,
-        custom_prompt=custom_prompt,
-    )
+    results = {}
+
+    with ThreadPoolExecutor(max_workers=min(len(model_names), 5)) as executor:
+        futures = {
+            model_name: executor.submit(
+                process_images_with_model,
+                model_config=MODEL_CONFIGS[model_name],
+                images=images,
+                custom_prompt=custom_prompt,
+            )
+            for model_name in model_names
+        }
+
+        with tqdm(total=len(model_names), desc="Processing models") as pbar:
+            for model_name, future in futures.items():
+                try:
+                    results_df = future.result()
+                    results[model_name] = results_df
+                    logger.info(f"Completed processing with model: {model_name}")
+                except Exception as e:
+                    logger.error(f"Error processing model {model_name}: {str(e)}")
+                    # empty dataframe with error message to avoid pipeline failure
+                    results[model_name] = pl.DataFrame(
+                        {
+                            "id": range(len(images)),
+                            "image_name": [os.path.basename(img) for img in images],
+                            "raw_text": [f"Error processing with {model_name}: {str(e)}"]
+                            * len(images),
+                            "processing_time": [0.0] * len(images),
+                            "confidence": [0.0] * len(images),
+                            "error": [str(e)] * len(images),
+                        }
+                    )
+                finally:
+                    pbar.update(1)
 
-    return results_df
+    return results
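
The core of this change is the fan-out: one worker per model submitted to a ThreadPoolExecutor, results collected into a dict keyed by model name, and a per-model fallback frame so a single failing model does not fail the whole step. Below is a standalone sketch of that pattern with a fake worker standing in for process_images_with_model; the model names and image paths are made up for illustration.

import os
from concurrent.futures import ThreadPoolExecutor

import polars as pl

def fake_ocr(model_name: str, images: list) -> pl.DataFrame:
    # Stand-in worker: one model is made to fail so the fallback path is exercised.
    if model_name == "broken-model":
        raise RuntimeError("simulated failure")
    return pl.DataFrame(
        {
            "id": list(range(len(images))),
            "image_name": [os.path.basename(img) for img in images],
            "raw_text": [f"text from {model_name}"] * len(images),
            "processing_time": [0.1] * len(images),
        }
    )

images = ["samples/a.png", "samples/b.png"]
model_names = ["model-a", "broken-model"]
results = {}

with ThreadPoolExecutor(max_workers=min(len(model_names), 5)) as executor:
    futures = {name: executor.submit(fake_ocr, name, images) for name in model_names}
    for name, future in futures.items():
        try:
            results[name] = future.result()
        except Exception as exc:
            # Mirror the fallback above: keep an error-marked frame instead of raising.
            results[name] = pl.DataFrame(
                {
                    "id": list(range(len(images))),
                    "image_name": [os.path.basename(img) for img in images],
                    "raw_text": [f"Error processing with {name}: {exc}"] * len(images),
                    "processing_time": [0.0] * len(images),
                    "error": [str(exc)] * len(images),
                }
            )

for name, df in results.items():
    print(name, df.shape)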
