99import numpy as np
1010from rtichoke .performance_data .performance_data import prepare_performance_data
1111
# Display label used in the hover text for each performance-metric column.
# `_bold_hover_metrics` looks metrics up here to find the hover line to emphasise.
_HOVER_LABELS: dict[str, str] = {
    # Rates and ratios.
    "false_positive_rate": "1 - Specificity (FPR)",
    "sensitivity": "Sensitivity",
    "specificity": "Specificity",
    "lift": "Lift",
    "ppv": "PPV",
    "npv": "NPV",
    # Net-benefit metrics.
    "net_benefit": "NB",
    "net_benefit_interventions_avoided": "Interventions Avoided (per 100)",
    # Columns that can serve as the x axis.
    "chosen_cutoff": "Prob. Threshold",
    "ppcr": "Predicted Positives",
}
24+
1225
1326def _create_rtichoke_plotly_curve_binary (
1427 probs : Dict [str , np .ndarray ],
@@ -887,6 +900,160 @@ def _check_if_multiple_populations_are_being_validated(
887900 return aj_estimates ["aj_estimate" ].unique ().len () > 1
888901
889902
def _check_if_multiple_models_are_being_validated(aj_estimates: pl.DataFrame) -> bool:
    """Return True when ``aj_estimates`` has more than one distinct ``reference_group``.

    Args:
        aj_estimates: Frame holding a ``reference_group`` column.

    Returns:
        Whether the column contains more than one unique value.
    """
    distinct_reference_groups = aj_estimates["reference_group"].unique()
    return distinct_reference_groups.len() > 1
905+
906+
def _infer_performance_data_type(
    aj_estimates_from_performance_data: pl.DataFrame, multiple_populations: bool
) -> str:
    """Classify the performance data by what is being compared.

    Args:
        aj_estimates_from_performance_data: Frame passed to
            ``_check_if_multiple_models_are_being_validated``.
        multiple_populations: Whether several populations are being validated.

    Returns:
        ``"several populations"`` when ``multiple_populations`` is true,
        otherwise ``"several models"`` when the helper reports multiple models,
        otherwise ``"single model"``.
    """
    # Evaluated before branching, even when `multiple_populations` alone
    # already decides the answer.
    has_several_models = _check_if_multiple_models_are_being_validated(
        aj_estimates_from_performance_data
    )

    if multiple_populations:
        return "several populations"
    return "several models" if has_several_models else "single model"
920+
921+
def _bold_hover_metrics(
    text: str,
    metrics: Sequence[str],
    *,
    labels: "dict[str, str] | None" = None,
) -> str:
    """Wrap the hover lines belonging to ``metrics`` in ``<b>`` tags.

    ``text`` is a ``<br>``-separated list of ``"Label: value"`` lines. A line
    is emphasised only when the part before its first ``:`` equals the label
    of one of ``metrics``. Matching the whole label, rather than searching for
    it as a substring, keeps ``"Specificity"`` from also bolding
    ``"1 - Specificity (FPR)"`` and ``"Prob. Threshold"`` from also bolding
    ``"Odds of Prob. Threshold"``.

    Args:
        text: Hover text with lines separated by ``<br>``.
        metrics: Metric column names whose lines should be emphasised. A name
            missing from the label mapping is used as the label itself.
        labels: Optional metric-name -> label mapping; defaults to the
            module-level ``_HOVER_LABELS``.

    Returns:
        ``text`` with the matching lines wrapped in ``<b>...</b>``.
    """
    label_lookup = _HOVER_LABELS if labels is None else labels
    wanted_labels = {label_lookup.get(metric, metric) for metric in metrics}

    def _emphasise(line: str) -> str:
        line_label = line.partition(":")[0]
        # Lines that already carry a <b> tag are left alone so they are never
        # wrapped twice.
        if line_label in wanted_labels and "<b>" not in line:
            return f"<b>{line}</b>"
        return line

    return "<br>".join(_emphasise(line) for line in text.split("<br>"))
931+
932+
def _add_model_population_text(text: str, row: dict, perf_dat_type: str) -> str:
    """Prepend a bold model or population header to the hover text.

    Args:
        text: Hover text built so far.
        row: Row values; ``"model"`` or ``"population"`` is read when present.
        perf_dat_type: ``"several models"`` selects the model header and
            ``"several populations"`` the population header; any other value
            leaves ``text`` untouched.

    Returns:
        ``text`` with the header line in front, or ``text`` unchanged when the
        data type needs no header or the row lacks the matching key.
    """
    if perf_dat_type == "several models":
        key, title = "model", "Model"
    elif perf_dat_type == "several populations":
        key, title = "population", "Population"
    else:
        return text

    if key not in row:
        return text
    return f"<b>{title}: {row[key]}</b><br>{text}"
939+
940+
def _round_val(value: Any, digits: int = 3):
    """Prepare a single cell value for display in the hover text.

    Only real floating-point values are rounded. Integers (and booleans, which
    are ``int`` subclasses) are returned untouched, so a count of ``25``
    renders as ``25`` and not ``25.0``.

    Args:
        value: Cell value taken from a row of performance data.
        digits: Number of decimals kept when ``value`` is a float.

    Returns:
        ``""`` for ``None``; a Python ``float`` rounded to ``digits`` for
        ``float`` / ``numpy.floating`` input; ``value`` itself otherwise.
    """
    if value is None:
        return ""
    if isinstance(value, (float, np.floating)):
        # float() normalises NumPy scalars to a plain Python float.
        return round(float(value), digits)
    return value
950+
951+
def _build_hover_text(
    row: dict,
    performance_metric_x: str,
    performance_metric_y: str,
    stratified_by: str,
    perf_dat_type: str,
) -> str:
    """Compose the HTML hover label for one row of performance data.

    The label is a ``<br>``-separated list of ``"Label: value"`` lines. The
    lines of the two plotted metrics are emphasised via
    ``_bold_hover_metrics`` and a model/population header is prepended via
    ``_add_model_population_text``.

    Args:
        row: One row of performance data as a mapping of column name to
            value. Missing columns are read as ``None`` and rendered empty.
        performance_metric_x: Column plotted on the x axis.
        performance_metric_y: Column plotted on the y axis.
            ``"net_benefit_interventions_avoided"`` selects the shorter
            interventions-avoided layout.
        stratified_by: ``"probability_threshold"`` adds the net-benefit and
            odds lines to the standard layout.
        perf_dat_type: Forwarded to ``_add_model_population_text``.

    Returns:
        The hover text, with ``NaN``/``nan`` metric values blanked out.
    """
    numeric_types = (int, float, np.floating)

    raw_probability_threshold = row.get("chosen_cutoff")
    probability_threshold = _round_val(raw_probability_threshold)
    net_benefit = _round_val(row.get("net_benefit"))
    predicted_positives = _round_val(row.get("predicted_positives"))
    tn = _round_val(row.get("true_negatives"))
    fn = _round_val(row.get("false_negatives"))

    raw_ppcr = row.get("ppcr")
    ppcr_percent = (
        _round_val(100 * raw_ppcr) if isinstance(raw_ppcr, numeric_types) else ""
    )

    # Odds are expressed as 1:x with x = (1 - p) / p. The line is dropped when
    # the threshold is missing, zero, or yields a non-finite ratio.
    odds_line = None
    if (
        isinstance(raw_probability_threshold, numeric_types)
        and raw_probability_threshold != 0
    ):
        odds = _round_val(
            (1 - raw_probability_threshold) / raw_probability_threshold, 2
        )
        if math.isfinite(float(odds)):
            odds_line = f"Odds of Prob. Threshold: 1:{odds}"

    threshold_line = f"Prob. Threshold: {probability_threshold}"
    net_benefit_line = f"NB: {net_benefit}"
    predicted_positives_line = (
        f"Predicted Positives: {predicted_positives} ({ppcr_percent}%)"
    )

    if performance_metric_y == "net_benefit_interventions_avoided":
        nb_interventions_avoided = _round_val(
            row.get("net_benefit_interventions_avoided")
        )
        text_lines = [threshold_line]
        if odds_line is not None:
            text_lines.append(odds_line)
        text_lines.extend(
            [
                f"Interventions Avoided (per 100): {nb_interventions_avoided}",
                net_benefit_line,
                predicted_positives_line,
                f"TN: {tn}",
                f"FN: {fn}",
            ]
        )
    else:
        sensitivity = _round_val(row.get("sensitivity"))
        fpr = _round_val(row.get("false_positive_rate"))
        specificity = _round_val(row.get("specificity"))
        lift = _round_val(row.get("lift"))
        ppv = _round_val(row.get("ppv"))
        npv = _round_val(row.get("npv"))
        tp = _round_val(row.get("true_positives"))
        fp = _round_val(row.get("false_positives"))

        text_lines = [
            threshold_line,
            f"Sensitivity: {sensitivity}",
            f"1 - Specificity (FPR): {fpr}",
            f"Specificity: {specificity}",
            f"Lift: {lift}",
            f"PPV: {ppv}",
            f"NPV: {npv}",
        ]
        if stratified_by == "probability_threshold":
            text_lines.append(net_benefit_line)
            if odds_line is not None:
                text_lines.append(odds_line)
        text_lines.extend(
            [
                predicted_positives_line,
                f"TP: {tp}",
                f"TN: {tn}",
                f"FP: {fp}",
                f"FN: {fn}",
            ]
        )

    text = "<br>".join(text_lines)
    text = _bold_hover_metrics(text, [performance_metric_x, performance_metric_y])
    # Blank out NaN values *before* the header is prepended: a model or
    # population name that happens to contain "nan" (e.g. "Banana") must not
    # be altered by this clean-up.
    text = text.replace("NaN", "").replace("nan", "")
    return _add_model_population_text(text, row, perf_dat_type)
1032+
1033+
def _add_hover_text_to_performance_data(
    performance_data: pl.DataFrame,
    performance_metric_x: str,
    performance_metric_y: str,
    stratified_by: str,
    perf_dat_type: str,
) -> pl.DataFrame:
    """Attach a ``text`` hover column and round the float columns to 3 decimals.

    Args:
        performance_data: Frame whose rows are turned into hover labels.
        performance_metric_x: Column plotted on the x axis.
        performance_metric_y: Column plotted on the y axis.
        stratified_by: Forwarded to ``_build_hover_text``.
        perf_dat_type: Forwarded to ``_build_hover_text``.

    Returns:
        ``performance_data`` with every Float32/Float64 column rounded to
        three decimals and a string column ``text`` holding the hover label
        produced by ``_build_hover_text`` for each row.
    """

    def _row_to_hover_text(row: dict) -> str:
        return _build_hover_text(
            row,
            performance_metric_x=performance_metric_x,
            performance_metric_y=performance_metric_y,
            stratified_by=stratified_by,
            perf_dat_type=perf_dat_type,
        )

    # Pack every column into one struct so each row reaches the Python
    # callback as a single mapping.
    hover_text_expr = pl.struct(performance_data.columns).map_elements(
        _row_to_hover_text,
        return_dtype=pl.Utf8,
    )

    # The float dtypes are listed explicitly instead of using the deprecated
    # `pl.FLOAT_DTYPES` group constant. Both expressions sit in the same
    # `with_columns` call, as before.
    return performance_data.with_columns(
        pl.col(pl.Float32, pl.Float64).round(3),
        hover_text_expr.alias("text"),
    )
1055+
1056+
8901057def _create_rtichoke_curve_list_binary (
8911058 performance_data : pl .DataFrame ,
8921059 stratified_by : str ,
@@ -904,10 +1071,6 @@ def _create_rtichoke_curve_list_binary(
9041071
9051072 x_metric , y_metric , x_label , y_label = _CURVE_CONFIG [curve ]
9061073
907- performance_data_ready_for_curve = _select_and_rename_necessary_variables (
908- performance_data .sort ("chosen_cutoff" ), x_metric , y_metric
909- )
910-
9111074 aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data (
9121075 performance_data
9131076 )
@@ -916,6 +1079,28 @@ def _create_rtichoke_curve_list_binary(
9161079 aj_estimates_from_performance_data
9171080 )
9181081
1082+ multiple_models = _check_if_multiple_models_are_being_validated (
1083+ aj_estimates_from_performance_data
1084+ )
1085+
1086+ perf_dat_type = _infer_performance_data_type (
1087+ aj_estimates_from_performance_data , multiple_populations
1088+ )
1089+
1090+ multiple_reference_groups = multiple_populations or multiple_models
1091+
1092+ performance_data_with_hover_text = _add_hover_text_to_performance_data (
1093+ performance_data .sort ("chosen_cutoff" ),
1094+ performance_metric_x = x_metric ,
1095+ performance_metric_y = y_metric ,
1096+ stratified_by = stratified_by ,
1097+ perf_dat_type = perf_dat_type ,
1098+ )
1099+
1100+ performance_data_ready_for_curve = _select_and_rename_necessary_variables (
1101+ performance_data_with_hover_text , x_metric , y_metric
1102+ )
1103+
9191104 reference_data = _create_reference_lines_data (
9201105 curve = curve ,
9211106 aj_estimates_from_performance_data = aj_estimates_from_performance_data ,
@@ -976,7 +1161,9 @@ def _create_rtichoke_curve_list_binary(
9761161 ]
9771162 },
9781163 ** {
979- variant_key : (palette [group_index ] if multiple_populations else "#000000" )
1164+ variant_key : (
1165+ palette [group_index ] if multiple_reference_groups else "#000000"
1166+ )
9801167 for group_index , reference_group in enumerate (reference_group_keys )
9811168 for variant_key in [
9821169 reference_group ,
@@ -999,7 +1186,7 @@ def _create_rtichoke_curve_list_binary(
9991186 "reference_data" : reference_data ,
10001187 "cutoffs" : cutoffs ,
10011188 "colors_dictionary" : colors_dictionary ,
1002- "multiple_populations " : multiple_populations ,
1189+ "multiple_reference_groups " : multiple_reference_groups ,
10031190 }
10041191
10051192 return rtichoke_curve_list
@@ -1013,6 +1200,7 @@ def _select_and_rename_necessary_variables(
10131200 pl .col ("chosen_cutoff" ),
10141201 pl .col (x_perf_metric ).alias ("x" ),
10151202 pl .col (y_perf_metric ).alias ("y" ),
1203+ pl .col ("text" ),
10161204 )
10171205
10181206
@@ -1047,13 +1235,22 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10471235 y = rtichoke_curve_list ["performance_data_ready_for_curve" ]
10481236 .filter (pl .col ("reference_group" ) == group )["y" ]
10491237 .to_list (),
1238+ text = rtichoke_curve_list ["performance_data_ready_for_curve" ].filter (
1239+ pl .col ("reference_group" ) == group
1240+ )["text" ],
10501241 mode = "markers+lines" ,
10511242 name = group ,
10521243 legendgroup = group ,
10531244 line = {
10541245 "width" : 2 ,
10551246 "color" : rtichoke_curve_list ["colors_dictionary" ].get (group ),
10561247 },
1248+ hoverlabel = dict (
1249+ bgcolor = rtichoke_curve_list ["colors_dictionary" ].get (group ),
1250+ bordercolor = rtichoke_curve_list ["colors_dictionary" ].get (group ),
1251+ font_color = "white" ,
1252+ ),
1253+ hoverinfo = "text" ,
10571254 showlegend = True ,
10581255 )
10591256 for group in rtichoke_curve_list ["reference_group_keys" ]
@@ -1068,15 +1265,20 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10681265 "size" : 12 ,
10691266 "color" : (
10701267 rtichoke_curve_list ["colors_dictionary" ].get (group )
1071- if rtichoke_curve_list ["multiple_populations " ]
1268+ if rtichoke_curve_list ["multiple_reference_groups " ]
10721269 else "#f6e3be"
10731270 ),
10741271 "line" : {"width" : 3 , "color" : "black" },
10751272 },
10761273 name = f"{ group } @ cutoff" ,
10771274 legendgroup = group ,
1275+ hoverlabel = dict (
1276+ bgcolor = rtichoke_curve_list ["colors_dictionary" ].get (group ),
1277+ bordercolor = rtichoke_curve_list ["colors_dictionary" ].get (group ),
1278+ font_color = "white" ,
1279+ ),
10781280 showlegend = False ,
1079- hovertemplate = f" { group } <br>x=%{{x:.4f}}<br>y=%{{y:.4f}}<extra></extra> " ,
1281+ hoverinfo = "text " ,
10801282 )
10811283 for group in rtichoke_curve_list ["reference_group_keys" ]
10821284 ]
@@ -1097,6 +1299,11 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10971299 color = rtichoke_curve_list ["colors_dictionary" ].get (group ),
10981300 width = 1.5 ,
10991301 ),
1302+ hoverlabel = dict (
1303+ bgcolor = rtichoke_curve_list ["colors_dictionary" ].get (group ),
1304+ bordercolor = rtichoke_curve_list ["colors_dictionary" ].get (group ),
1305+ font_color = "white" ,
1306+ ),
11001307 hoverinfo = "text" ,
11011308 text = rtichoke_curve_list ["reference_data" ]
11021309 .filter (pl .col ("reference_group" ) == group )["text" ]
@@ -1183,7 +1390,7 @@ def _create_curve_layout(
11831390 curve_layout = {
11841391 "xaxis" : {"showgrid" : False },
11851392 "yaxis" : {"showgrid" : False },
1186- "template" : "none " ,
1393+ "template" : "plotly " ,
11871394 "plot_bgcolor" : "rgba(0, 0, 0, 0)" ,
11881395 "paper_bgcolor" : "rgba(0, 0, 0, 0)" ,
11891396 "showlegend" : True ,
@@ -1198,7 +1405,7 @@ def _create_curve_layout(
11981405 },
11991406 "height" : size + 50 ,
12001407 "width" : size ,
1201- "hoverlabel" : {"bgcolor" : "rgba(0,0,0,0)" , "bordercolor" : "rgba(0,0,0,0)" },
1408+ # "hoverlabel": {"bgcolor": "rgba(0,0,0,0)", "bordercolor": "rgba(0,0,0,0)"},
12021409 "updatemenus" : [
12031410 {
12041411 "type" : "buttons" ,
0 commit comments