99import numpy as np
1010from rtichoke .performance_data .performance_data import prepare_performance_data
1111
# Display labels for hover text, keyed by performance-metric column name.
# `_bold_hover_metrics` looks a metric up here to find the hover line to
# emphasise; every label is the text that precedes the colon on one of the
# lines assembled in `_build_hover_text`.
_HOVER_LABELS: dict[str, str] = {
    "false_positive_rate": "1 - Specificity (FPR)",
    "sensitivity": "Sensitivity",
    "specificity": "Specificity",
    "lift": "Lift",
    "ppv": "PPV",
    "npv": "NPV",
    "net_benefit": "NB",
    "net_benefit_interventions_avoided": "Interventions Avoided (per 100)",
    "chosen_cutoff": "Prob. Threshold",
    "ppcr": "Predicted Positives",
}
24+
1225
1326def _create_rtichoke_plotly_curve_binary (
1427 probs : Dict [str , np .ndarray ],
@@ -887,6 +900,149 @@ def _check_if_multiple_populations_are_being_validated(
887900 return aj_estimates ["aj_estimate" ].unique ().len () > 1
888901
889902
def _infer_performance_data_type(performance_data: pl.DataFrame) -> str:
    """Classify performance data by the grouping column it carries.

    Args:
        performance_data: Frame whose column names are inspected.

    Returns:
        ``"several models"`` if a ``model`` column is present, otherwise
        ``"several populations"`` if a ``population`` column is present,
        otherwise ``"single model"``. ``model`` takes precedence when both
        columns exist.
    """
    available_columns = performance_data.columns
    # Ordered by precedence: the first grouping column found decides the type.
    for grouping_column, data_type in (
        ("model", "several models"),
        ("population", "several populations"),
    ):
        if grouping_column in available_columns:
            return data_type
    return "single model"
909+
910+
def _bold_hover_metrics(text: str, metrics: Sequence[str]) -> str:
    """Wrap the hover lines belonging to the given metrics in ``<b>`` tags.

    ``text`` is split on ``<br>``; a line is emphasised when it starts with
    ``"<label>:"``, where the label is looked up in ``_HOVER_LABELS`` (falling
    back to the raw metric name). Matching on the line prefix rather than on
    any substring keeps a label such as "Prob. Threshold" from also bolding
    the "Odds of Prob. Threshold" line, and "Specificity" from also bolding
    "1 - Specificity (FPR)".

    Args:
        text: Hover text whose lines are separated by ``<br>``.
        metrics: Metric names whose lines should be emphasised.

    Returns:
        The hover text with matching lines wrapped in ``<b>...</b>``. Lines
        that already contain a ``<b>`` tag are left untouched.
    """
    # str.startswith accepts a tuple; an empty tuple simply never matches.
    bold_prefixes = tuple(
        f"{_HOVER_LABELS.get(metric, metric)}:" for metric in metrics
    )
    return "<br>".join(
        f"<b>{line}</b>"
        if "<b>" not in line and line.startswith(bold_prefixes)
        else line
        for line in text.split("<br>")
    )
920+
921+
922+ def _add_model_population_text (text : str , row : dict , perf_dat_type : str ) -> str :
923+ if perf_dat_type == "several models" and "model" in row :
924+ text = f"<b>Model: { row ['model' ]} </b><br>{ text } "
925+ if perf_dat_type == "several populations" and "population" in row :
926+ text = f"<b>Population: { row ['population' ]} </b><br>{ text } "
927+ return text
928+
929+
930+ def _round_val (value : Any , digits : int = 3 ):
931+ try :
932+ if value is None :
933+ return ""
934+ if isinstance (value , (int , float , np .floating )):
935+ return round (float (value ), digits )
936+ except (TypeError , ValueError ):
937+ pass
938+ return value
939+
940+
def _build_hover_text(
    row: dict,
    performance_metric_x: str,
    performance_metric_y: str,
    stratified_by: str,
    perf_dat_type: str,
) -> str:
    """Build the HTML hover text for one row of performance data.

    Args:
        row: Column-name to value mapping for a single row. Columns that are
            absent are read as ``None`` and rendered as empty values.
        performance_metric_x: Metric on the x axis; its hover line is bolded.
        performance_metric_y: Metric on the y axis; its hover line is bolded.
            ``"net_benefit_interventions_avoided"`` selects the reduced
            interventions-avoided layout.
        stratified_by: When ``"probability_threshold"``, the full layout also
            shows the net benefit and the odds of the probability threshold.
        perf_dat_type: Forwarded to ``_add_model_population_text`` to decide
            whether a model/population header is prepended.

    Returns:
        ``<br>``-separated hover text with the x/y metric lines in bold and
        NaN values blanked out.
    """
    numeric_types = (int, float, np.floating)
    interventions_avoided = performance_metric_y == "net_benefit_interventions_avoided"

    def rounded(column: str):
        return _round_val(row.get(column))

    raw_probability_threshold = row.get("chosen_cutoff")
    raw_ppcr = row.get("ppcr")

    ppcr_percent = (
        _round_val(100 * raw_ppcr) if isinstance(raw_ppcr, numeric_types) else ""
    )

    # The odds are undefined for a zero threshold and meaningless for a
    # missing/NaN one, so the line is only produced for finite results.
    odds_line = None
    if (
        isinstance(raw_probability_threshold, numeric_types)
        and raw_probability_threshold != 0
    ):
        odds = _round_val(
            (1 - raw_probability_threshold) / raw_probability_threshold, 2
        )
        if math.isfinite(float(odds)):
            odds_line = f"Odds of Prob. Threshold: 1:{odds}"

    threshold_line = f"Prob. Threshold: {rounded('chosen_cutoff')}"
    predicted_positives_line = (
        f"Predicted Positives: {rounded('predicted_positives')} ({ppcr_percent}%)"
    )

    if interventions_avoided:
        text_lines = [threshold_line]
        if odds_line is not None:
            text_lines.append(odds_line)
        text_lines.extend(
            [
                "Interventions Avoided (per 100): "
                f"{rounded('net_benefit_interventions_avoided')}",
                f"NB: {rounded('net_benefit')}",
                predicted_positives_line,
                f"TN: {rounded('true_negatives')}",
                f"FN: {rounded('false_negatives')}",
            ]
        )
    else:
        text_lines = [
            threshold_line,
            f"Sensitivity: {rounded('sensitivity')}",
            f"1 - Specificity (FPR): {rounded('false_positive_rate')}",
            f"Specificity: {rounded('specificity')}",
            f"Lift: {rounded('lift')}",
            f"PPV: {rounded('ppv')}",
            f"NPV: {rounded('npv')}",
        ]
        if stratified_by == "probability_threshold":
            text_lines.append(f"NB: {rounded('net_benefit')}")
            # NOTE(review): the odds line is treated as part of the
            # probability-threshold-only section - confirm.
            if odds_line is not None:
                text_lines.append(odds_line)
        text_lines.extend(
            [
                predicted_positives_line,
                f"TP: {rounded('true_positives')}",
                f"TN: {rounded('true_negatives')}",
                f"FP: {rounded('false_positives')}",
                f"FN: {rounded('false_negatives')}",
            ]
        )

    text = "<br>".join(text_lines)
    # Blank out missing values on the metric lines only. Doing this before the
    # model/population header is prepended keeps names that merely contain
    # "nan" (e.g. "pregnancy") from being mangled.
    text = text.replace("NaN", "").replace("nan", "")
    text = _bold_hover_metrics(text, [performance_metric_x, performance_metric_y])
    return _add_model_population_text(text, row, perf_dat_type)
1021+
1022+
def _add_hover_text_to_performance_data(
    performance_data: pl.DataFrame,
    performance_metric_x: str,
    performance_metric_y: str,
    stratified_by: str,
    perf_dat_type: str,
) -> pl.DataFrame:
    """Return the performance data with floats rounded and a ``text`` column.

    Args:
        performance_data: Frame with one row per point on the curve.
        performance_metric_x: Metric on the x axis, bolded in the hover text.
        performance_metric_y: Metric on the y axis, bolded in the hover text.
        stratified_by: Stratification mode forwarded to ``_build_hover_text``.
        perf_dat_type: Performance data type forwarded to
            ``_build_hover_text``.

    Returns:
        A new frame in which every float column is rounded to 3 decimals and
        a string column ``text`` holds the hover text built for each row.
    """

    def _row_to_hover_text(row: dict) -> str:
        return _build_hover_text(
            row,
            performance_metric_x=performance_metric_x,
            performance_metric_y=performance_metric_y,
            stratified_by=stratified_by,
            perf_dat_type=perf_dat_type,
        )

    # Pack every column into a struct so each row reaches the builder as a dict.
    hover_text_expr = pl.struct(performance_data.columns).map_elements(
        _row_to_hover_text,
        return_dtype=pl.Utf8,
    )

    return performance_data.with_columns(
        [
            # Explicit float dtypes instead of the deprecated dtype-group
            # constant `pl.FLOAT_DTYPES`.
            pl.col(pl.Float32, pl.Float64).round(3),
            hover_text_expr.alias("text"),
        ]
    )
1044+
1045+
8901046def _create_rtichoke_curve_list_binary (
8911047 performance_data : pl .DataFrame ,
8921048 stratified_by : str ,
@@ -904,8 +1060,18 @@ def _create_rtichoke_curve_list_binary(
9041060
9051061 x_metric , y_metric , x_label , y_label = _CURVE_CONFIG [curve ]
9061062
1063+ perf_dat_type = _infer_performance_data_type (performance_data )
1064+
1065+ performance_data_with_hover_text = _add_hover_text_to_performance_data (
1066+ performance_data .sort ("chosen_cutoff" ),
1067+ performance_metric_x = x_metric ,
1068+ performance_metric_y = y_metric ,
1069+ stratified_by = stratified_by ,
1070+ perf_dat_type = perf_dat_type ,
1071+ )
1072+
9071073 performance_data_ready_for_curve = _select_and_rename_necessary_variables (
908- performance_data . sort ( "chosen_cutoff" ) , x_metric , y_metric
1074+ performance_data_with_hover_text , x_metric , y_metric
9091075 )
9101076
9111077 aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data (
@@ -1013,6 +1179,7 @@ def _select_and_rename_necessary_variables(
10131179 pl .col ("chosen_cutoff" ),
10141180 pl .col (x_perf_metric ).alias ("x" ),
10151181 pl .col (y_perf_metric ).alias ("y" ),
1182+ pl .col ("text" ),
10161183 )
10171184
10181185
@@ -1047,6 +1214,9 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10471214 y = rtichoke_curve_list ["performance_data_ready_for_curve" ]
10481215 .filter (pl .col ("reference_group" ) == group )["y" ]
10491216 .to_list (),
1217+ text = rtichoke_curve_list ["performance_data_ready_for_curve" ].filter (
1218+ pl .col ("reference_group" ) == group
1219+ )["text" ],
10501220 mode = "markers+lines" ,
10511221 name = group ,
10521222 legendgroup = group ,
@@ -1055,14 +1225,11 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10551225 "color" : rtichoke_curve_list ["colors_dictionary" ].get (group ),
10561226 },
10571227 hoverlabel = dict (
1058- bgcolor = rtichoke_curve_list ["colors_dictionary" ].get (
1059- group
1060- ), # <-- background = trace color
1061- bordercolor = rtichoke_curve_list ["colors_dictionary" ].get (
1062- group
1063- ), # <-- border = trace color
1064- font_color = "white" , # <-- or "black" if your colors are light
1228+ bgcolor = rtichoke_curve_list ["colors_dictionary" ].get (group ),
1229+ bordercolor = rtichoke_curve_list ["colors_dictionary" ].get (group ),
1230+ font_color = "white" ,
10651231 ),
1232+ hoverinfo = "text" ,
10661233 showlegend = True ,
10671234 )
10681235 for group in rtichoke_curve_list ["reference_group_keys" ]
@@ -1085,16 +1252,12 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10851252 name = f"{ group } @ cutoff" ,
10861253 legendgroup = group ,
10871254 hoverlabel = dict (
1088- bgcolor = rtichoke_curve_list ["colors_dictionary" ].get (
1089- group
1090- ), # <-- background = trace color
1091- bordercolor = rtichoke_curve_list ["colors_dictionary" ].get (
1092- group
1093- ), # <-- border = trace color
1094- font_color = "white" , # <-- or "black" if your colors are light
1255+ bgcolor = rtichoke_curve_list ["colors_dictionary" ].get (group ),
1256+ bordercolor = rtichoke_curve_list ["colors_dictionary" ].get (group ),
1257+ font_color = "white" ,
10951258 ),
10961259 showlegend = False ,
1097- hovertemplate = f" { group } <br>x=%{{x:.4f}}<br>y=%{{y:.4f}}<extra></extra> " ,
1260+ hoverinfo = "text " ,
10981261 )
10991262 for group in rtichoke_curve_list ["reference_group_keys" ]
11001263 ]
0 commit comments