Skip to content

Commit 127c9e2

Browse files
committed
fix: close #210
1 parent dbe8891 commit 127c9e2

File tree

1 file changed

+179
-16
lines changed

1 file changed

+179
-16
lines changed

src/rtichoke/helpers/plotly_helper_functions.py

Lines changed: 179 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,19 @@
99
import numpy as np
1010
from rtichoke.performance_data.performance_data import prepare_performance_data
1111

12+
# Human-readable hover labels keyed by performance-data column name.
# _bold_hover_metrics looks metrics up here to find which hover line to
# emphasise; the label strings match the line prefixes that
# _build_hover_text writes.
_HOVER_LABELS: dict[str, str] = {
    "false_positive_rate": "1 - Specificity (FPR)",
    "sensitivity": "Sensitivity",
    "specificity": "Specificity",
    "lift": "Lift",
    "ppv": "PPV",
    "npv": "NPV",
    "net_benefit": "NB",
    "net_benefit_interventions_avoided": "Interventions Avoided (per 100)",
    "chosen_cutoff": "Prob. Threshold",
    "ppcr": "Predicted Positives",
}
24+
1225

1326
def _create_rtichoke_plotly_curve_binary(
1427
probs: Dict[str, np.ndarray],
@@ -887,6 +900,149 @@ def _check_if_multiple_populations_are_being_validated(
887900
return aj_estimates["aj_estimate"].unique().len() > 1
888901

889902

903+
def _infer_performance_data_type(performance_data: pl.DataFrame) -> str:
    """Classify performance data by which grouping column it carries.

    Returns ``"several models"`` when a ``model`` column is present,
    otherwise ``"several populations"`` when a ``population`` column is
    present, otherwise ``"single model"``. ``model`` takes precedence when
    both columns exist.
    """
    column_names = performance_data.columns
    # Order matters: the first matching marker column decides the type.
    for marker_column, data_type in (
        ("model", "several models"),
        ("population", "several populations"),
    ):
        if marker_column in column_names:
            return data_type
    return "single model"
909+
910+
911+
def _bold_hover_metrics(text: str, metrics: Sequence[str]) -> str:
    """Wrap the hover-text lines that report ``metrics`` in ``<b>`` tags.

    Args:
        text: Hover text whose lines are separated by ``<br>`` and have the
            form ``"<label>: <value>"``.
        metrics: Performance-data column names to emphasise. Each is
            translated to its display label through ``_HOVER_LABELS``; a
            name with no entry is used as the label itself.

    Returns:
        The text with every matching line wrapped in ``<b>...</b>``. Lines
        that already contain ``<b>`` are left untouched.
    """
    # Match on the "<label>:" prefix rather than on a substring. A substring
    # test makes "Specificity" also hit "1 - Specificity (FPR): ..." and
    # "Prob. Threshold" also hit "Odds of Prob. Threshold: ...".
    prefixes = tuple(f"{_HOVER_LABELS.get(metric, metric)}:" for metric in metrics)
    return "<br>".join(
        f"<b>{line}</b>"
        if line.startswith(prefixes) and "<b>" not in line
        else line
        for line in text.split("<br>")
    )
920+
921+
922+
def _add_model_population_text(text: str, row: dict, perf_dat_type: str) -> str:
    """Prefix hover text with a bold model or population header line.

    A ``Model: ...`` header is prepended when ``perf_dat_type`` is
    ``"several models"`` and ``row`` has a ``model`` key; a
    ``Population: ...`` header when it is ``"several populations"`` and
    ``row`` has a ``population`` key. In every other case ``text`` is
    returned unchanged.
    """
    if perf_dat_type == "several models":
        if "model" not in row:
            return text
        model_name = row["model"]
        return f"<b>Model: {model_name}</b><br>{text}"

    if perf_dat_type == "several populations":
        if "population" not in row:
            return text
        population_name = row["population"]
        return f"<b>Population: {population_name}</b><br>{text}"

    return text
928+
929+
930+
def _round_val(value: Any, digits: int = 3) -> Any:
    """Prepare a single row value for display in hover text.

    Args:
        value: The raw value taken from a performance-data row.
        digits: Number of decimal places kept for floating-point values.

    Returns:
        ``""`` for ``None``; a Python ``float`` rounded to ``digits`` places
        for ``float`` / ``numpy.floating`` input (NaN stays NaN); any other
        value unchanged.
    """
    if value is None:
        return ""
    # Only real floats are rounded. Integers (and bools) are passed through
    # so they are not coerced to float and rendered with a spurious ".0"
    # (e.g. "TP: 5.0").
    if isinstance(value, (float, np.floating)):
        return round(float(value), digits)
    return value
939+
940+
941+
def _hover_value(value: Any, digits: int = 3) -> Any:
    """Round ``value`` via ``_round_val`` and render NaN as an empty string."""
    rounded = _round_val(value, digits)
    if isinstance(rounded, (float, np.floating)) and np.isnan(rounded):
        return ""
    return rounded


def _build_hover_text(
    row: dict,
    performance_metric_x: str,
    performance_metric_y: str,
    stratified_by: str,
    perf_dat_type: str,
) -> str:
    """Build the HTML hover text for one performance-data row.

    Args:
        row: Mapping of performance-data column name to value. The keys read
            are ``chosen_cutoff``, ``sensitivity``, ``false_positive_rate``,
            ``specificity``, ``lift``, ``ppv``, ``npv``, ``net_benefit``,
            ``net_benefit_interventions_avoided``, ``predicted_positives``,
            ``ppcr``, ``true_positives``, ``true_negatives``,
            ``false_positives`` and ``false_negatives``; missing keys render
            as empty values.
        performance_metric_x: Column plotted on the x axis; its line is bolded.
        performance_metric_y: Column plotted on the y axis; its line is
            bolded. ``"net_benefit_interventions_avoided"`` selects the
            shorter interventions-avoided layout.
        stratified_by: When ``"probability_threshold"``, the standard layout
            also shows the net benefit and the odds of the threshold.
        perf_dat_type: Passed to ``_add_model_population_text`` to decide
            whether a model/population header is prepended.

    Returns:
        ``<br>``-separated hover text. NaN metric values are rendered as
        empty strings.
    """
    interventions_avoided = performance_metric_y == "net_benefit_interventions_avoided"
    numeric_types = (int, float, np.floating)

    raw_probability_threshold = row.get("chosen_cutoff")
    raw_ppcr = row.get("ppcr")

    # NaN is blanked per value here, not by a string replace on the finished
    # text: a blanket replace of "nan" also mangles model/population names
    # that merely contain those letters (e.g. "finance").
    probability_threshold = _hover_value(raw_probability_threshold)
    net_benefit = _hover_value(row.get("net_benefit"))
    predicted_positives = _hover_value(row.get("predicted_positives"))
    ppcr_percent = (
        _hover_value(100 * raw_ppcr) if isinstance(raw_ppcr, numeric_types) else ""
    )
    tn = _hover_value(row.get("true_negatives"))
    fn = _hover_value(row.get("false_negatives"))

    # Odds are only reported for a numeric, non-zero threshold that yields a
    # finite ratio.
    odds_line = None
    if (
        isinstance(raw_probability_threshold, numeric_types)
        and raw_probability_threshold != 0
    ):
        raw_odds = (1 - raw_probability_threshold) / raw_probability_threshold
        if np.isfinite(raw_odds):
            odds_line = f"Odds of Prob. Threshold: 1:{_round_val(raw_odds, 2)}"

    predicted_positives_line = (
        f"Predicted Positives: {predicted_positives} ({ppcr_percent}%)"
    )

    if not interventions_avoided:
        sensitivity = _hover_value(row.get("sensitivity"))
        fpr = _hover_value(row.get("false_positive_rate"))
        specificity = _hover_value(row.get("specificity"))
        lift = _hover_value(row.get("lift"))
        ppv = _hover_value(row.get("ppv"))
        npv = _hover_value(row.get("npv"))
        tp = _hover_value(row.get("true_positives"))
        fp = _hover_value(row.get("false_positives"))

        text_lines = [
            f"Prob. Threshold: {probability_threshold}",
            f"Sensitivity: {sensitivity}",
            f"1 - Specificity (FPR): {fpr}",
            f"Specificity: {specificity}",
            f"Lift: {lift}",
            f"PPV: {ppv}",
            f"NPV: {npv}",
        ]
        if stratified_by == "probability_threshold":
            text_lines.append(f"NB: {net_benefit}")
            if odds_line is not None:
                text_lines.append(odds_line)
        text_lines.extend(
            [
                predicted_positives_line,
                f"TP: {tp}",
                f"TN: {tn}",
                f"FP: {fp}",
                f"FN: {fn}",
            ]
        )
    else:
        nb_interventions_avoided = _hover_value(
            row.get("net_benefit_interventions_avoided")
        )
        text_lines = [
            f"Prob. Threshold: {probability_threshold}",
            f"Interventions Avoided (per 100): {nb_interventions_avoided}",
            f"NB: {net_benefit}",
            predicted_positives_line,
            f"TN: {tn}",
            f"FN: {fn}",
        ]
        if odds_line is not None:
            # Shown directly under the threshold line in this layout.
            text_lines.insert(1, odds_line)

    text = "<br>".join(text_lines)
    text = _bold_hover_metrics(text, [performance_metric_x, performance_metric_y])
    return _add_model_population_text(text, row, perf_dat_type)
1021+
1022+
1023+
def _add_hover_text_to_performance_data(
    performance_data: pl.DataFrame,
    performance_metric_x: str,
    performance_metric_y: str,
    stratified_by: str,
    perf_dat_type: str,
) -> pl.DataFrame:
    """Return ``performance_data`` with a ``text`` hover column added.

    Each row is packed into a struct of all columns and passed, as a dict, to
    ``_build_hover_text`` together with the remaining arguments; the result
    is stored in a string column named ``text``. The same ``with_columns``
    call also rounds every floating-point column to 3 decimal places.

    Args:
        performance_data: Performance data, one row per plotted point.
        performance_metric_x: Column plotted on the x axis.
        performance_metric_y: Column plotted on the y axis.
        stratified_by: Stratification mode forwarded to ``_build_hover_text``.
        perf_dat_type: Data type label forwarded to ``_build_hover_text``.

    Returns:
        A new DataFrame with rounded float columns and the ``text`` column.
    """

    def _row_to_hover_text(row: dict) -> str:
        return _build_hover_text(
            row,
            performance_metric_x=performance_metric_x,
            performance_metric_y=performance_metric_y,
            stratified_by=stratified_by,
            perf_dat_type=perf_dat_type,
        )

    hover_text_expr = pl.struct(performance_data.columns).map_elements(
        _row_to_hover_text,
        return_dtype=pl.Utf8,
    )

    # Select float columns by explicit dtype: the pl.FLOAT_DTYPES group
    # constant is deprecated in Polars 1.x and covers these same two dtypes.
    rounded_floats = pl.col(pl.Float32, pl.Float64).round(3)

    return performance_data.with_columns(
        [rounded_floats, hover_text_expr.alias("text")]
    )
1044+
1045+
8901046
def _create_rtichoke_curve_list_binary(
8911047
performance_data: pl.DataFrame,
8921048
stratified_by: str,
@@ -904,8 +1060,18 @@ def _create_rtichoke_curve_list_binary(
9041060

9051061
x_metric, y_metric, x_label, y_label = _CURVE_CONFIG[curve]
9061062

1063+
perf_dat_type = _infer_performance_data_type(performance_data)
1064+
1065+
performance_data_with_hover_text = _add_hover_text_to_performance_data(
1066+
performance_data.sort("chosen_cutoff"),
1067+
performance_metric_x=x_metric,
1068+
performance_metric_y=y_metric,
1069+
stratified_by=stratified_by,
1070+
perf_dat_type=perf_dat_type,
1071+
)
1072+
9071073
performance_data_ready_for_curve = _select_and_rename_necessary_variables(
908-
performance_data.sort("chosen_cutoff"), x_metric, y_metric
1074+
performance_data_with_hover_text, x_metric, y_metric
9091075
)
9101076

9111077
aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data(
@@ -1013,6 +1179,7 @@ def _select_and_rename_necessary_variables(
10131179
pl.col("chosen_cutoff"),
10141180
pl.col(x_perf_metric).alias("x"),
10151181
pl.col(y_perf_metric).alias("y"),
1182+
pl.col("text"),
10161183
)
10171184

10181185

@@ -1047,6 +1214,9 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10471214
y=rtichoke_curve_list["performance_data_ready_for_curve"]
10481215
.filter(pl.col("reference_group") == group)["y"]
10491216
.to_list(),
1217+
text=rtichoke_curve_list["performance_data_ready_for_curve"].filter(
1218+
pl.col("reference_group") == group
1219+
)["text"],
10501220
mode="markers+lines",
10511221
name=group,
10521222
legendgroup=group,
@@ -1055,14 +1225,11 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10551225
"color": rtichoke_curve_list["colors_dictionary"].get(group),
10561226
},
10571227
hoverlabel=dict(
1058-
bgcolor=rtichoke_curve_list["colors_dictionary"].get(
1059-
group
1060-
), # <-- background = trace color
1061-
bordercolor=rtichoke_curve_list["colors_dictionary"].get(
1062-
group
1063-
), # <-- border = trace color
1064-
font_color="white", # <-- or "black" if your colors are light
1228+
bgcolor=rtichoke_curve_list["colors_dictionary"].get(group),
1229+
bordercolor=rtichoke_curve_list["colors_dictionary"].get(group),
1230+
font_color="white",
10651231
),
1232+
hoverinfo="text",
10661233
showlegend=True,
10671234
)
10681235
for group in rtichoke_curve_list["reference_group_keys"]
@@ -1085,16 +1252,12 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10851252
name=f"{group} @ cutoff",
10861253
legendgroup=group,
10871254
hoverlabel=dict(
1088-
bgcolor=rtichoke_curve_list["colors_dictionary"].get(
1089-
group
1090-
), # <-- background = trace color
1091-
bordercolor=rtichoke_curve_list["colors_dictionary"].get(
1092-
group
1093-
), # <-- border = trace color
1094-
font_color="white", # <-- or "black" if your colors are light
1255+
bgcolor=rtichoke_curve_list["colors_dictionary"].get(group),
1256+
bordercolor=rtichoke_curve_list["colors_dictionary"].get(group),
1257+
font_color="white",
10951258
),
10961259
showlegend=False,
1097-
hovertemplate=f"{group}<br>x=%{{x:.4f}}<br>y=%{{y:.4f}}<extra></extra>",
1260+
hoverinfo="text",
10981261
)
10991262
for group in rtichoke_curve_list["reference_group_keys"]
11001263
]

0 commit comments

Comments
 (0)