Commit 6ecc9b5

marwan37 committed: update steps
1 parent 575a0d6 · commit 6ecc9b5

5 files changed: +57 -250 lines changed

omni-reader/steps/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -14,8 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .evaluate_models import evaluate_models
-from .load_files import (
+from .loaders import (
     load_ground_truth_file,
+    load_ground_truth_texts,
+    load_ocr_results,
     load_images,
 )
 from .run_ocr import run_ocr
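Because omni-reader/steps/load_files.py is deleted later in this commit (see below), any code that imported from the old module needs to target the new loaders module instead. A minimal migration sketch, using only names that appear in the diff above (their signatures are not shown in this commit):

    # Before this commit (old module, now deleted):
    # from steps.load_files import load_ground_truth_file, load_images

    # After this commit (new loaders module, including the two newly exported helpers):
    from steps.loaders import (
        load_ground_truth_file,
        load_ground_truth_texts,
        load_images,
        load_ocr_results,
    )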

omni-reader/steps/evaluate_models.py

Lines changed: 35 additions & 74 deletions
@@ -3,11 +3,11 @@
 import base64
 import os
 import re
-from typing import Any, Dict, List, Optional
+from difflib import SequenceMatcher
+from typing import Any, Dict, List
 
 import polars as pl
 from jiwer import cer, wer
-from textdistance import jaccard
 from typing_extensions import Annotated
 from zenml import get_step_context, log_metadata, step
 from zenml.types import HTMLString
@@ -16,6 +16,11 @@
 from utils.model_configs import MODEL_CONFIGS
 
 
+def levenshtein_ratio(s1: str, s2: str) -> float:
+    """Calculate the Levenshtein ratio between two strings."""
+    return SequenceMatcher(None, s1, s2).ratio()
+
+
 def load_svg_logo(logo_name: str) -> str:
     """Load an SVG logo as base64 encoded string."""
     logo_path = os.path.join("./assets/logos", logo_name)
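The new levenshtein_ratio helper replaces the textdistance Jaccard similarity used before this commit. It delegates to the standard library's difflib.SequenceMatcher, whose ratio() is based on Ratcliff/Obershelp longest-matching-block comparison, so it approximates rather than exactly reproduces a normalized Levenshtein score. A quick standalone illustration (the example strings are arbitrary):

    # Mirrors the helper added above; results lie in [0.0, 1.0], where 1.0 means identical.
    from difflib import SequenceMatcher

    def levenshtein_ratio(s1: str, s2: str) -> float:
        return SequenceMatcher(None, s1, s2).ratio()

    print(levenshtein_ratio("kitten", "kitten"))             # 1.0
    print(round(levenshtein_ratio("kitten", "sitting"), 3))  # 0.615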
@@ -137,7 +142,6 @@ def create_model_comparison_card(
     ground_truth: str,
     model_texts: Dict[str, str],
     model_metrics: Dict[str, Dict[str, Any]],
-    ground_truth_model: Optional[str] = None,
 ) -> str:
     """Create a card for comparing OCR results for a specific image across models."""
     model_sections = ""
@@ -183,18 +187,8 @@ def create_model_comparison_card(
183187
</div>
184188
"""
185189

186-
# Add OpenAI logo to Ground Truth header if applicable
187-
ground_truth_header = '<h4 class="font-bold mb-2 text-gray-700">Ground Truth</h4>'
188-
if ground_truth_model and ground_truth_model in MODEL_CONFIGS:
189-
gt_config = MODEL_CONFIGS[ground_truth_model]
190-
if gt_config.logo:
191-
logo_b64 = load_svg_logo(gt_config.logo)
192-
ground_truth_header = f"""
193-
<h4 class="font-bold mb-2 text-gray-700 flex items-center">
194-
<img src="data:image/svg+xml;base64,{logo_b64}" width="20" class="inline mr-1" alt="{gt_config.display} logo">
195-
Ground Truth
196-
</h4>
197-
"""
190+
# Simple header for ground truth text files
191+
ground_truth_header = '<h4 class="font-bold mb-2 text-gray-700">📄 Ground Truth</h4>'
198192

199193
card = f"""
200194
<div class="bg-white rounded-lg shadow-md p-6 mb-6 border border-gray-200">
@@ -266,22 +260,12 @@ def create_summary_visualization(
     model_metrics: Dict[str, Dict[str, float]],
     time_comparison: Dict[str, Any],
     similarities: Dict[str, float] = None,
-    ground_truth_model: str = None,
 ) -> HTMLString:
     """Create an HTML visualization of evaluation results for multiple models."""
     step_context = get_step_context()
     pipeline_run_name = step_context.pipeline_run.name
     models = list(model_metrics.keys())
 
-    # Exclude ground truth model from best model calculations
-    exclude_from_best = []
-    if ground_truth_model:
-        for display_name in models:
-            for model_id, config in MODEL_CONFIGS.items():
-                if config.display == display_name and model_id == ground_truth_model:
-                    exclude_from_best.append(display_name)
-                    break
-
     model_cards = ""
     cols_per_row = min(3, len(models))
     for model_display, metrics in model_metrics.items():
@@ -299,15 +283,9 @@ def create_summary_visualization(
 
     fastest_model = time_comparison["fastest_model"]
 
-    best_cer = find_best_model(
-        model_metrics, "CER", lower_is_better=True, exclude_model_names=exclude_from_best
-    )
-    best_wer = find_best_model(
-        model_metrics, "WER", lower_is_better=True, exclude_model_names=exclude_from_best
-    )
-    best_similarity = find_best_model(
-        model_metrics, "GT Similarity", lower_is_better=False, exclude_model_names=exclude_from_best
-    )
+    best_cer = find_best_model(model_metrics, "CER", lower_is_better=True)
+    best_wer = find_best_model(model_metrics, "WER", lower_is_better=True)
+    best_similarity = find_best_model(model_metrics, "GT Similarity", lower_is_better=False)
 
     metrics_grid = f"""
     <div class="grid grid-cols-1 md:grid-cols-{cols_per_row} gap-6 mb-6">
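The simplified find_best_model calls now rank every entry in model_metrics; nothing is excluded. A hedged sketch of the mapping these calls consume (model names and values are invented for illustration; the metric keys and the lower_is_better directions come from the calls above):

    # Illustrative input only; real values are produced upstream by calculate_custom_metrics.
    model_metrics = {
        "Model A": {"CER": 0.04, "WER": 0.11, "GT Similarity": 0.93},
        "Model B": {"CER": 0.07, "WER": 0.15, "GT Similarity": 0.88},
    }
    # Lower is better for CER and WER, higher is better for GT Similarity,
    # so these example values would make "Model A" best on all three metrics.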
@@ -391,6 +369,7 @@ def normalize_text(s: str) -> str:
     """Normalize text for comparison."""
     s = s.lower()
     s = re.sub(r"\s+", " ", s).strip()
+    s = s.replace("\n", " ")
     # Normalize apostrophes and similar characters
     s = re.sub(r"[''′`]", "'", s)
     return s
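A standalone copy of the updated normalize_text, for illustration of its effect only (same body as the diff above; the sample input is arbitrary):

    import re

    def normalize_text(s: str) -> str:
        """Normalize text for comparison."""
        s = s.lower()
        s = re.sub(r"\s+", " ", s).strip()
        s = s.replace("\n", " ")
        # Normalize apostrophes and similar characters
        s = re.sub(r"[''′`]", "'", s)
        return s

    print(normalize_text("It`s  a\nTEST "))  # -> it's a test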
@@ -399,7 +378,7 @@ def normalize_text(s: str) -> str:
 def calculate_model_similarities(
     results: List[Dict[str, Any]], model_displays: List[str]
 ) -> Dict[str, float]:
-    """Calculate the average pairwise Jaccard similarity between model outputs.
+    """Calculate the average pairwise Levenshtein ratio between model outputs.
 
     Expects each result to have keys formatted as:
         "raw_text_{model_display}"
@@ -440,7 +419,7 @@ def calculate_model_similarities(
                continue
            text1 = model_texts[model1]
            text2 = model_texts[model2]
-            similarity = jaccard.normalized_similarity(text1, text2)
+            similarity = levenshtein_ratio(text1, text2)
            pair_key = f"{model1}_{model2}"
            similarity_sums[pair_key] = similarity_sums.get(pair_key, 0) + similarity
            similarity_counts[pair_key] = similarity_counts.get(pair_key, 0) + 1
@@ -456,16 +435,12 @@ def find_best_model(
     model_metrics: Dict[str, Dict[str, float]],
     metric: str,
     lower_is_better: bool = True,
-    exclude_model_names: List[str] = None,
 ) -> str:
     """Find the best performing model(s) for a given metric, showing ties when they occur."""
     best_models = []
     best_value = None
-    exclude_model_names = exclude_model_names or []
 
     for model, metrics in model_metrics.items():
-        if model in exclude_model_names:
-            continue
         if metric in metrics:
             value = metrics[metric]
             if (
@@ -476,8 +451,6 @@ def find_best_model(
                best_value = value
     if best_value is not None:
         for model, metrics in model_metrics.items():
-            if model in exclude_model_names:
-                continue
             if metric in metrics:
                 value = metrics[metric]
                 if (lower_is_better and abs(value - best_value) < 1e-6) or (
@@ -495,7 +468,9 @@ def find_best_model(
 
 
 def calculate_custom_metrics(
-    ground_truth_text: str, model_texts: Dict[str, str], model_displays: List[str]
+    ground_truth_text: str,
+    model_texts: Dict[str, str],
+    model_displays: List[str],
 ) -> Dict[str, Dict[str, float]]:
     """Calculate metrics for each model and between model pairs."""
     all_metrics = {}
@@ -507,16 +482,14 @@ def calculate_custom_metrics(
         if ground_truth_text:
             all_metrics[model1]["CER"] = cer(ground_truth_text, text1)
             all_metrics[model1]["WER"] = wer(ground_truth_text, text1)
-            all_metrics[model1]["GT Similarity"] = jaccard.normalized_similarity(
-                ground_truth_text, text1
-            )
+            all_metrics[model1]["GT Similarity"] = levenshtein_ratio(ground_truth_text, text1)
         for j, model2 in enumerate(model_displays):
             if i < j:
                 model_pairs.append((model1, model2))
     for model1, model2 in model_pairs:
         text1 = model_texts.get(model1, "")
         text2 = model_texts.get(model2, "")
-        similarity = jaccard.normalized_similarity(text1, text2)
+        similarity = levenshtein_ratio(text1, text2)
         pair_key = f"{model1}_{model2}"
         all_metrics[pair_key] = similarity
     return all_metrics
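A hedged usage sketch of the reworked calculate_custom_metrics, showing the shape of the returned mapping (model names and texts are invented; the key layout follows the function body above, and the snippet assumes the module context with jiwer and levenshtein_ratio available):

    ground_truth_text = "the quick brown fox"
    model_texts = {
        "Model A": "the quick brown fox",
        "Model B": "the quik brown fox",
    }

    metrics = calculate_custom_metrics(ground_truth_text, model_texts, ["Model A", "Model B"])
    # metrics["Model A"] holds "CER" and "WER" (via jiwer) plus "GT Similarity" (levenshtein_ratio);
    # metrics["Model A_Model B"] holds a single pairwise similarity float.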
@@ -525,28 +498,24 @@ def calculate_custom_metrics(
 @step(enable_cache=False)
 def evaluate_models(
     model_results: Dict[str, pl.DataFrame],
-    ground_truth_df: Optional[pl.DataFrame] = None,
-    ground_truth_model: Optional[str] = None,
+    ground_truth_df: pl.DataFrame,
 ) -> Annotated[HTMLString, "ocr_visualization"]:
     """Compare the performance of multiple configurable models with visualization.
 
-    The ground truth model is separated from evaluation models so that it is used only
-    for displaying the reference text. All metric calculations, similarity computations,
-    and best model indicators are performed solely on evaluation models.
+    Args:
+        model_results: Dictionary mapping model names to results DataFrames
+        ground_truth_df: DataFrame containing ground truth texts
+
+    Returns:
+        HTML visualization of the evaluation results
     """
     if not model_results:
         raise ValueError("At least one model is required for evaluation")
 
-    # --- 1. Separate the ground truth model from evaluation models ---
-    if ground_truth_model and ground_truth_model in model_results:
-        gt_df = model_results[ground_truth_model].clone()
-        del model_results[ground_truth_model]
-    else:
-        gt_df = None
+    if ground_truth_df is None or ground_truth_df.is_empty():
+        raise ValueError("Ground truth data is required for evaluation")
 
-    # If a separate ground_truth_df is provided, that overrides any GT model data.
-    if ground_truth_df is not None:
-        gt_df = ground_truth_df
+    gt_df = ground_truth_df
 
     # --- 2. Build model info for evaluation models ---
     model_keys = list(model_results.keys())
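With the ground_truth_model path removed, evaluate_models now requires a populated ground truth DataFrame. A hedged sketch of what a caller might construct and pass, assuming polars: the "ground_truth_text" column matches the column the step reads further down, while "image_name" is an assumed join key not confirmed by this diff.

    import polars as pl

    ground_truth_df = pl.DataFrame(
        {
            "image_name": ["invoice_01.png", "receipt_02.png"],  # assumed join key
            "ground_truth_text": ["Total due: $42.00", "Paid in cash"],
        }
    )

    assert not ground_truth_df.is_empty()  # the step now rejects None or empty ground truth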
@@ -598,9 +567,10 @@ def evaluate_models(
     evaluation_metrics = []
     image_cards_html = ""
     gt_text_col = "ground_truth_text"
-    if gt_text_col not in merged_results.columns:
-        if "raw_text_gt" in merged_results.columns:
-            gt_text_col = "raw_text_gt"
+
+    # Check if we have ground truth data in our joined dataset
+    if gt_text_col not in merged_results.columns and "raw_text_gt" in merged_results.columns:
+        gt_text_col = "raw_text_gt"  # Fall back to legacy ground truth model format
 
     for row in merged_results.iter_rows(named=True):
         if gt_text_col not in row:
@@ -637,7 +607,6 @@ def evaluate_models(
            ground_truth=ground_truth,
            model_texts=model_texts,
            model_metrics=error_analysis,
-            ground_truth_model=ground_truth_model,
        )
        image_cards_html += comparison_card
 
@@ -697,24 +666,16 @@ def evaluate_models(
     if tk1 in all_model_times and tk2 in all_model_times:
         time_comparison["time_difference"] = abs(all_model_times[tk1] - all_model_times[tk2])
 
-    # --- 10. Exclude GT model from final summary metrics ---
-    if ground_truth_model:
-        for model_id, cfg in MODEL_CONFIGS.items():
-            if model_id == ground_truth_model and cfg.display in model_metric_averages:
-                del model_metric_averages[cfg.display]
-                break
-
     # Log metadata (customize the metadata_dict as needed)
     log_metadata(metadata={"fastest_model": fastest_display, "model_count": len(model_keys)})
 
     summary_html = create_summary_visualization(
         model_metrics=model_metric_averages,
         time_comparison=time_comparison,
         similarities=similarities,
-        ground_truth_model=ground_truth_model,
     )
 
-    # --- 11. Combine summary and per-image details ---
+    # --- 10. Combine summary and per-image details ---
     final_html = f"""
     {summary_html}
     <div class="container mx-auto px-4">

omni-reader/steps/load_files.py

Lines changed: 0 additions & 118 deletions
This file was deleted.
