@@ -235,8 +235,8 @@ def create_summary_visualization(
                 <div class="text-right font-medium">{avg_metrics["avg_models_similarity"]:.4f}</div>
                 <div class="text-gray-600">Time Diff:</div>
                 <div class="text-right font-medium">{time_comparison["time_difference"]:.2f}s</div>
-                <div class="text-gray-600">Faster Model:</div>
-                <div class="text-right font-medium">{time_comparison["faster_model"]}</div>
+                <div class="text-gray-600">Fastest Model:</div>
+                <div class="text-right font-medium">{time_comparison["fastest_model"]}</div>
                 <div class="text-gray-600">Better CER:</div>
                 <div class="text-right font-medium">
                     {model1_display if avg_metrics[f"avg_{model1_prefix}_cer"] < avg_metrics[f"avg_{model2_prefix}_cer"] else model2_display}
@@ -267,8 +267,8 @@ def create_summary_visualization(
                     <div class="text-xl font-bold">{time_comparison["time_difference"]:.2f}s</div>
                 </div>
                 <div>
-                    <div class="text-gray-600 mb-1">Faster Model</div>
-                    <div class="text-xl font-bold">{time_comparison["faster_model"]}</div>
+                    <div class="text-gray-600 mb-1">Fastest Model</div>
+                    <div class="text-xl font-bold">{time_comparison["fastest_model"]}</div>
                 </div>
             </div>
         </div>
@@ -325,34 +325,62 @@ def create_summary_visualization(
 
 @step()
 def evaluate_models(
-    model1_df: pl.DataFrame,
-    model2_df: pl.DataFrame,
+    model_results: Dict[str, pl.DataFrame],
     ground_truth_df: Optional[pl.DataFrame] = None,
-    model1_name: str = "ollama/gemma3:27b",
-    model2_name: str = "pixtral-12b-2409",
+    primary_models: Optional[List[str]] = None,
 ) -> Annotated[HTMLString, "ocr_visualization"]:
-    """Compare the performance of two configurable models with visualization.
+    """Compare the performance of multiple configurable models with visualization.
 
     Args:
-        model1_df: First model results DataFrame
-        model2_df: Second model results DataFrame
+        model_results: Dictionary mapping model names to their results DataFrames
         ground_truth_df: Optional ground truth results DataFrame
-        model1_name: Name of the first model (default: ollama/gemma3:27b)
-        model2_name: Name of the second model (default: pixtral-12b-2409)
-        model1_display: Display name for the first model (default: Gemma)
-        model2_display: Display name for the second model (default: Mistral)
+        primary_models: Optional list of model names to highlight in the comparison.
+            If None or containing fewer than two names, the first two models in model_results are used.
 
     Returns:
         HTMLString visualization of the results
     """
+    # Ensure we have at least two models for comparison
+    if len(model_results) < 2:
+        raise ValueError("At least two models are required for comparison")
+
+    # If primary_models is not specified or too short, use the first two models
+    if not primary_models or len(primary_models) < 2:
+        primary_models = list(model_results.keys())[:2]
+
+    # Extract the primary models for the main comparison
+    model1_name = primary_models[0]
+    model2_name = primary_models[1]
+
+    model1_df = model_results[model1_name]
+    model2_df = model_results[model2_name]
+
     model1_display, model1_prefix = get_model_info(model1_name)
     model2_display, model2_prefix = get_model_info(model2_name)
 
     # Join results
-    results = model1_df.join(model2_df, on=["id", "image_name"], how="inner")
+    results = model1_df.join(model2_df, on=["id", "image_name"], how="inner", suffix="_right")
     evaluation_metrics = []
     processed_results = []
 
+    # Calculate average processing times for all models
+    all_model_times = {}
+    for model_name, df in model_results.items():
+        display, prefix = get_model_info(model_name)
+        time_key = f"avg_{prefix}_time"
+        all_model_times[time_key] = df.select("processing_time").to_series().mean()
+        all_model_times[f"{prefix}_display"] = display
+
+    # Find the fastest model: min() over (time, key) tuples picks the lowest average time
+    fastest_model_time = min(
+        [(time, model) for model, time in all_model_times.items() if not model.endswith("_display")]
+    )
+    fastest_model_key = fastest_model_time[1]
+    fastest_model_prefix = fastest_model_key.replace("avg_", "").replace("_time", "")
+    fastest_model_display = all_model_times.get(
+        f"{fastest_model_prefix}_display", fastest_model_prefix
+    )
+
     if ground_truth_df is not None:
         results = results.join(
             ground_truth_df,
@@ -412,38 +440,49 @@ def evaluate_models(
             ].mean(),
         }
 
-        model1_times = model1_df.select("processing_time").to_series().mean()
-        model2_times = model2_df.select("processing_time").to_series().mean()
         model1_time_key = f"avg_{model1_prefix}_time"
         model2_time_key = f"avg_{model2_prefix}_time"
+
+        # Combine processing times with other metrics
         time_comparison = {
-            model1_time_key: model1_times,
-            model2_time_key: model2_times,
-            "time_difference": abs(model1_times - model2_times),
-            "faster_model": model1_display if model1_times < model2_times else model2_display,
+            **all_model_times,
+            "time_difference": abs(
+                all_model_times[model1_time_key] - all_model_times[model2_time_key]
+            ),
+            "fastest_model": fastest_model_display,
         }
 
-        # Log metadata for ZenML dashboard
-        log_metadata(
-            metadata={
+        # Prepare metadata for ZenML dashboard
+        metadata_dict = {
+            **{
+                key: float(time)
+                for key, time in all_model_times.items()
+                if not key.endswith("_display")
+            },
+            "fastest_model": fastest_model_display,
+            "model_count": len(model_results),
+            "avg_models_similarity": float(avg_metrics["avg_models_similarity"]),
+        }
+
+        # Add accuracy metrics for primary models
+        metadata_dict.update(
+            {
                 f"avg_{model1_prefix}_cer": float(avg_metrics[f"avg_{model1_prefix}_cer"]),
                 f"avg_{model1_prefix}_wer": float(avg_metrics[f"avg_{model1_prefix}_wer"]),
                 f"avg_{model2_prefix}_cer": float(avg_metrics[f"avg_{model2_prefix}_cer"]),
                 f"avg_{model2_prefix}_wer": float(avg_metrics[f"avg_{model2_prefix}_wer"]),
-                "avg_models_similarity": float(avg_metrics["avg_models_similarity"]),
                 f"avg_{model1_prefix}_gt_similarity": float(
                     avg_metrics[f"avg_{model1_prefix}_gt_similarity"]
                 ),
                 f"avg_{model2_prefix}_gt_similarity": float(
                     avg_metrics[f"avg_{model2_prefix}_gt_similarity"]
                 ),
-                model1_time_key: float(time_comparison[model1_time_key]),
-                model2_time_key: float(time_comparison[model2_time_key]),
-                "time_difference": float(time_comparison["time_difference"]),
-                "faster_model": time_comparison["faster_model"],
             }
         )
 
+        # Log metadata for ZenML dashboard
+        log_metadata(metadata=metadata_dict)
+
         html_visualization = create_summary_visualization(
             avg_metrics=avg_metrics,
             time_comparison=time_comparison,
@@ -456,30 +495,33 @@ def evaluate_models(
         return html_visualization
 
     # FALLBACK: if no ground truth metrics, only use processing times.
-    model1_times = model1_df.select("processing_time").to_series().mean()
-    model2_times = model2_df.select("processing_time").to_series().mean()
-    model1_time_key = f"avg_{model1_prefix}_time"
-    model2_time_key = f"avg_{model2_prefix}_time"
     time_comparison = {
-        model1_time_key: model1_times,
-        model2_time_key: model2_times,
-        "time_difference": abs(model1_times - model2_times),
-        "faster_model": model1_display if model1_times < model2_times else model2_display,
+        **all_model_times,
+        "time_difference": abs(
+            all_model_times[f"avg_{model1_prefix}_time"]
+            - all_model_times[f"avg_{model2_prefix}_time"]
+        ),
+        "fastest_model": fastest_model_display,
    }
+
     html_visualization = create_summary_visualization(
         avg_metrics={},
         time_comparison=time_comparison,
         model1_name=model1_name,
         model2_name=model2_name,
     )
 
-    log_metadata(
-        metadata={
-            model1_time_key: float(time_comparison[model1_time_key]),
-            model2_time_key: float(time_comparison[model2_time_key]),
-            "time_difference": float(time_comparison["time_difference"]),
-            "faster_model": time_comparison["faster_model"],
-        }
-    )
+    # Prepare metadata for ZenML dashboard
+    metadata_dict = {
+        **{
+            key: float(time)
+            for key, time in all_model_times.items()
+            if not key.endswith("_display")
+        },
+        "fastest_model": fastest_model_display,
+        "model_count": len(model_results),
+    }
+
+    log_metadata(metadata=metadata_dict)
 
     return html_visualization
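
For context, a minimal sketch of how a pipeline might feed the refactored step after this change. The run_ocr_models step, its column layout, and the third model name are illustrative assumptions, not part of this commit; the real project presumably has its own OCR steps producing the per-model DataFrames.

from typing import Dict, List

import polars as pl
from zenml import pipeline, step


@step()
def run_ocr_models(model_names: List[str]) -> Dict[str, pl.DataFrame]:
    # Hypothetical stand-in for the project's OCR steps: each DataFrame only
    # needs the columns evaluate_models relies on here (id, image_name,
    # processing_time), plus whatever text columns the accuracy metrics use.
    return {
        name: pl.DataFrame(
            {
                "id": [1, 2],
                "image_name": ["page_1.png", "page_2.png"],
                "processing_time": [1.2 + i, 1.4 + i],
            }
        )
        for i, name in enumerate(model_names)
    }


@pipeline
def ocr_comparison_pipeline():
    # Any number of models now flows through a single Dict[str, pl.DataFrame]
    # artifact; primary_models selects the two that drive the detailed CER/WER
    # comparison, while every entry contributes to the processing-time and
    # fastest-model metadata.
    model_results = run_ocr_models(
        model_names=["ollama/gemma3:27b", "pixtral-12b-2409", "tesseract"]
    )
    evaluate_models(
        model_results=model_results,
        primary_models=["ollama/gemma3:27b", "pixtral-12b-2409"],
    )

Passing the dict as a single artifact keeps the step signature stable as models are added or removed; only the upstream step and the primary_models argument need to change.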