 import base64
 import os
 import re
-from typing import Any, Dict, List, Optional
+from difflib import SequenceMatcher
+from typing import Any, Dict, List
 
 import polars as pl
 from jiwer import cer, wer
-from textdistance import jaccard
 from typing_extensions import Annotated
 from zenml import get_step_context, log_metadata, step
 from zenml.types import HTMLString
 from utils.model_configs import MODEL_CONFIGS
 
 
+def levenshtein_ratio(s1: str, s2: str) -> float:
+    """Calculate the Levenshtein ratio between two strings."""
+    return SequenceMatcher(None, s1, s2).ratio()
+
+
 def load_svg_logo(logo_name: str) -> str:
     """Load an SVG logo as base64 encoded string."""
     logo_path = os.path.join("./assets/logos", logo_name)
@@ -137,7 +142,6 @@ def create_model_comparison_card(
     ground_truth: str,
     model_texts: Dict[str, str],
     model_metrics: Dict[str, Dict[str, Any]],
-    ground_truth_model: Optional[str] = None,
 ) -> str:
     """Create a card for comparing OCR results for a specific image across models."""
     model_sections = ""
@@ -183,18 +187,8 @@ def create_model_comparison_card(
         </div>
         """
 
-    # Add OpenAI logo to Ground Truth header if applicable
-    ground_truth_header = '<h4 class="font-bold mb-2 text-gray-700">Ground Truth</h4>'
-    if ground_truth_model and ground_truth_model in MODEL_CONFIGS:
-        gt_config = MODEL_CONFIGS[ground_truth_model]
-        if gt_config.logo:
-            logo_b64 = load_svg_logo(gt_config.logo)
-            ground_truth_header = f"""
-            <h4 class="font-bold mb-2 text-gray-700 flex items-center">
-                <img src="data:image/svg+xml;base64,{logo_b64}" width="20" class="inline mr-1" alt="{gt_config.display} logo">
-                Ground Truth
-            </h4>
-            """
+    # Simple header for ground truth text files
+    ground_truth_header = '<h4 class="font-bold mb-2 text-gray-700">📄 Ground Truth</h4>'
 
     card = f"""
     <div class="bg-white rounded-lg shadow-md p-6 mb-6 border border-gray-200">
@@ -266,22 +260,12 @@ def create_summary_visualization(
     model_metrics: Dict[str, Dict[str, float]],
     time_comparison: Dict[str, Any],
     similarities: Dict[str, float] = None,
-    ground_truth_model: str = None,
 ) -> HTMLString:
     """Create an HTML visualization of evaluation results for multiple models."""
     step_context = get_step_context()
     pipeline_run_name = step_context.pipeline_run.name
     models = list(model_metrics.keys())
 
-    # Exclude ground truth model from best model calculations
-    exclude_from_best = []
-    if ground_truth_model:
-        for display_name in models:
-            for model_id, config in MODEL_CONFIGS.items():
-                if config.display == display_name and model_id == ground_truth_model:
-                    exclude_from_best.append(display_name)
-                    break
-
     model_cards = ""
     cols_per_row = min(3, len(models))
     for model_display, metrics in model_metrics.items():
@@ -299,15 +283,9 @@ def create_summary_visualization(
 
     fastest_model = time_comparison["fastest_model"]
 
-    best_cer = find_best_model(
-        model_metrics, "CER", lower_is_better=True, exclude_model_names=exclude_from_best
-    )
-    best_wer = find_best_model(
-        model_metrics, "WER", lower_is_better=True, exclude_model_names=exclude_from_best
-    )
-    best_similarity = find_best_model(
-        model_metrics, "GT Similarity", lower_is_better=False, exclude_model_names=exclude_from_best
-    )
+    best_cer = find_best_model(model_metrics, "CER", lower_is_better=True)
+    best_wer = find_best_model(model_metrics, "WER", lower_is_better=True)
+    best_similarity = find_best_model(model_metrics, "GT Similarity", lower_is_better=False)
 
     metrics_grid = f"""
     <div class="grid grid-cols-1 md:grid-cols-{cols_per_row} gap-6 mb-6">
@@ -391,6 +369,7 @@ def normalize_text(s: str) -> str:
391369 """Normalize text for comparison."""
392370 s = s .lower ()
393371 s = re .sub (r"\s+" , " " , s ).strip ()
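+    # Clarifying comment: the re.sub above already collapses newlines into
+    # spaces, so the replace below acts only as a defensive no-op here.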
+    s = s.replace("\n", " ")
     # Normalize apostrophes and similar characters
     s = re.sub(r"[''′`]", "'", s)
     return s
@@ -399,7 +378,7 @@ def normalize_text(s: str) -> str:
 def calculate_model_similarities(
     results: List[Dict[str, Any]], model_displays: List[str]
 ) -> Dict[str, float]:
-    """Calculate the average pairwise Jaccard similarity between model outputs.
+    """Calculate the average pairwise Levenshtein ratio between model outputs.
 
     Expects each result to have keys formatted as:
         "raw_text_{model_display}"
@@ -440,7 +419,7 @@ def calculate_model_similarities(
                 continue
             text1 = model_texts[model1]
             text2 = model_texts[model2]
-            similarity = jaccard.normalized_similarity(text1, text2)
+            similarity = levenshtein_ratio(text1, text2)
             pair_key = f"{model1}_{model2}"
             similarity_sums[pair_key] = similarity_sums.get(pair_key, 0) + similarity
             similarity_counts[pair_key] = similarity_counts.get(pair_key, 0) + 1
@@ -456,16 +435,12 @@ def find_best_model(
     model_metrics: Dict[str, Dict[str, float]],
     metric: str,
     lower_is_better: bool = True,
-    exclude_model_names: List[str] = None,
 ) -> str:
     """Find the best performing model(s) for a given metric, showing ties when they occur."""
     best_models = []
     best_value = None
-    exclude_model_names = exclude_model_names or []
 
     for model, metrics in model_metrics.items():
-        if model in exclude_model_names:
-            continue
         if metric in metrics:
             value = metrics[metric]
             if (
@@ -476,8 +451,6 @@ def find_best_model(
                 best_value = value
     if best_value is not None:
         for model, metrics in model_metrics.items():
-            if model in exclude_model_names:
-                continue
             if metric in metrics:
                 value = metrics[metric]
                 if (lower_is_better and abs(value - best_value) < 1e-6) or (
                     not lower_is_better and abs(value - best_value) < 1e-6
 
 
 def calculate_custom_metrics(
-    ground_truth_text: str, model_texts: Dict[str, str], model_displays: List[str]
+    ground_truth_text: str,
+    model_texts: Dict[str, str],
+    model_displays: List[str],
 ) -> Dict[str, Dict[str, float]]:
     """Calculate metrics for each model and between model pairs."""
     all_metrics = {}
@@ -507,16 +482,14 @@ def calculate_custom_metrics(
         if ground_truth_text:
             all_metrics[model1]["CER"] = cer(ground_truth_text, text1)
             all_metrics[model1]["WER"] = wer(ground_truth_text, text1)
-            all_metrics[model1]["GT Similarity"] = jaccard.normalized_similarity(
-                ground_truth_text, text1
-            )
+            all_metrics[model1]["GT Similarity"] = levenshtein_ratio(ground_truth_text, text1)
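+            # GT Similarity is a 0-1 character-level sequence similarity, so
+            # higher values indicate closer agreement with the ground truth.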
         for j, model2 in enumerate(model_displays):
             if i < j:
                 model_pairs.append((model1, model2))
     for model1, model2 in model_pairs:
         text1 = model_texts.get(model1, "")
         text2 = model_texts.get(model2, "")
-        similarity = jaccard.normalized_similarity(text1, text2)
+        similarity = levenshtein_ratio(text1, text2)
         pair_key = f"{model1}_{model2}"
         all_metrics[pair_key] = similarity
     return all_metrics
@@ -525,28 +498,24 @@ def calculate_custom_metrics(
 @step(enable_cache=False)
 def evaluate_models(
     model_results: Dict[str, pl.DataFrame],
-    ground_truth_df: Optional[pl.DataFrame] = None,
-    ground_truth_model: Optional[str] = None,
+    ground_truth_df: pl.DataFrame,
 ) -> Annotated[HTMLString, "ocr_visualization"]:
     """Compare the performance of multiple configurable models with visualization.
 
-    The ground truth model is separated from evaluation models so that it is used only
-    for displaying the reference text. All metric calculations, similarity computations,
-    and best model indicators are performed solely on evaluation models.
+    Args:
+        model_results: Dictionary mapping model names to results DataFrames
+        ground_truth_df: DataFrame containing ground truth texts
+
+    Returns:
+        HTML visualization of the evaluation results
     """
     if not model_results:
         raise ValueError("At least one model is required for evaluation")
 
-    # --- 1. Separate the ground truth model from evaluation models ---
-    if ground_truth_model and ground_truth_model in model_results:
-        gt_df = model_results[ground_truth_model].clone()
-        del model_results[ground_truth_model]
-    else:
-        gt_df = None
+    if ground_truth_df is None or ground_truth_df.is_empty():
+        raise ValueError("Ground truth data is required for evaluation")
 
-    # If a separate ground_truth_df is provided, that overrides any GT model data.
-    if ground_truth_df is not None:
-        gt_df = ground_truth_df
+    gt_df = ground_truth_df
 
     # --- 2. Build model info for evaluation models ---
     model_keys = list(model_results.keys())
@@ -598,9 +567,10 @@ def evaluate_models(
     evaluation_metrics = []
     image_cards_html = ""
     gt_text_col = "ground_truth_text"
-    if gt_text_col not in merged_results.columns:
-        if "raw_text_gt" in merged_results.columns:
-            gt_text_col = "raw_text_gt"
+
+    # Check if we have ground truth data in our joined dataset
+    if gt_text_col not in merged_results.columns and "raw_text_gt" in merged_results.columns:
+        gt_text_col = "raw_text_gt"  # Fall back to legacy ground truth model format
 
     for row in merged_results.iter_rows(named=True):
         if gt_text_col not in row:
@@ -637,7 +607,6 @@ def evaluate_models(
             ground_truth=ground_truth,
             model_texts=model_texts,
             model_metrics=error_analysis,
-            ground_truth_model=ground_truth_model,
         )
         image_cards_html += comparison_card
 
@@ -697,24 +666,16 @@ def evaluate_models(
     if tk1 in all_model_times and tk2 in all_model_times:
         time_comparison["time_difference"] = abs(all_model_times[tk1] - all_model_times[tk2])
 
-    # --- 10. Exclude GT model from final summary metrics ---
-    if ground_truth_model:
-        for model_id, cfg in MODEL_CONFIGS.items():
-            if model_id == ground_truth_model and cfg.display in model_metric_averages:
-                del model_metric_averages[cfg.display]
-                break
-
     # Log metadata (customize the metadata_dict as needed)
     log_metadata(metadata={"fastest_model": fastest_display, "model_count": len(model_keys)})
 
     summary_html = create_summary_visualization(
         model_metrics=model_metric_averages,
         time_comparison=time_comparison,
         similarities=similarities,
-        ground_truth_model=ground_truth_model,
     )
 
-    # --- 11. Combine summary and per-image details ---
+    # --- 10. Combine summary and per-image details ---
     final_html = f"""
     {summary_html}
     <div class="container mx-auto px-4">