Skip to content

Commit 1286284

Browse files
authored
Merge pull request #211 from uriahf/195-update-functions-for-discrimination-plots
195 update functions for discrimination plots
2 parents 22b48f4 + 676a3bc commit 1286284

File tree

3 files changed

+219
-12
lines changed

3 files changed

+219
-12
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies = [
2222
"marimo>=0.17.0",
2323
]
2424
name = "rtichoke"
25-
version = "0.1.17"
25+
version = "0.1.18"
2626
description = "interactive visualizations for performance of predictive models"
2727
readme = "README.md"
2828

src/rtichoke/helpers/plotly_helper_functions.py

Lines changed: 217 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,19 @@
99
import numpy as np
1010
from rtichoke.performance_data.performance_data import prepare_performance_data
1111

12+
_HOVER_LABELS = {
13+
"false_positive_rate": "1 - Specificity (FPR)",
14+
"sensitivity": "Sensitivity",
15+
"specificity": "Specificity",
16+
"lift": "Lift",
17+
"ppv": "PPV",
18+
"npv": "NPV",
19+
"net_benefit": "NB",
20+
"net_benefit_interventions_avoided": "Interventions Avoided (per 100)",
21+
"chosen_cutoff": "Prob. Threshold",
22+
"ppcr": "Predicted Positives",
23+
}
24+
1225

1326
def _create_rtichoke_plotly_curve_binary(
1427
probs: Dict[str, np.ndarray],
@@ -887,6 +900,160 @@ def _check_if_multiple_populations_are_being_validated(
887900
return aj_estimates["aj_estimate"].unique().len() > 1
888901

889902

903+
def _check_if_multiple_models_are_being_validated(aj_estimates: pl.DataFrame) -> bool:
    """Return True when ``reference_group`` holds more than one distinct value."""
    distinct_reference_groups = aj_estimates["reference_group"].unique()
    return distinct_reference_groups.len() > 1
905+
906+
907+
def _infer_performance_data_type(
    aj_estimates_from_performance_data: pl.DataFrame, multiple_populations: bool
) -> str:
    """Classify the performance data by what is being compared.

    Args:
        aj_estimates_from_performance_data: Frame passed on to
            ``_check_if_multiple_models_are_being_validated``.
        multiple_populations: Whether several populations are being validated.

    Returns:
        ``"several populations"`` when ``multiple_populations`` is truthy
        (this takes precedence), otherwise ``"several models"`` when the
        frame holds more than one reference group, otherwise
        ``"single model"``.
    """
    # Evaluated up front, before the population check, exactly as callers
    # have always experienced it.
    several_models = _check_if_multiple_models_are_being_validated(
        aj_estimates_from_performance_data
    )

    if multiple_populations:
        return "several populations"
    if several_models:
        return "several models"
    return "single model"
920+
921+
922+
def _bold_hover_metrics(text: str, metrics: Sequence[str]) -> str:
    """Wrap the hover lines belonging to ``metrics`` in ``<b>`` tags.

    ``text`` is a ``<br>``-separated list of ``"Label: value"`` lines. A line
    is bolded only when the text before its first colon equals the display
    label of one of ``metrics``. A substring test is not enough here:
    "Specificity" is contained in "1 - Specificity (FPR)" and
    "Prob. Threshold" in "Odds of Prob. Threshold", so those unrelated lines
    would be bolded as well.

    Args:
        text: Hover text whose lines are joined by ``<br>``.
        metrics: Performance-metric names; each is translated through
            ``_HOVER_LABELS`` and falls back to the name itself when it has
            no entry there.

    Returns:
        The hover text with the matching lines wrapped in ``<b>...</b>``.
    """
    labels = {_HOVER_LABELS.get(metric, metric) for metric in metrics}

    def emphasise(line: str) -> str:
        # Lines that already carry a <b> tag are left alone so a line is
        # never wrapped twice.
        if "<b>" in line:
            return line
        line_label = line.partition(":")[0]
        return f"<b>{line}</b>" if line_label in labels else line

    return "<br>".join(emphasise(line) for line in text.split("<br>"))
931+
932+
933+
def _add_model_population_text(text: str, row: dict, perf_dat_type: str) -> str:
934+
if perf_dat_type == "several models" and "model" in row:
935+
text = f"<b>Model: {row['model']}</b><br>{text}"
936+
if perf_dat_type == "several populations" and "population" in row:
937+
text = f"<b>Population: {row['population']}</b><br>{text}"
938+
return text
939+
940+
941+
def _round_val(value: Any, digits: int = 3):
942+
try:
943+
if value is None:
944+
return ""
945+
if isinstance(value, (int, float, np.floating)):
946+
return round(float(value), digits)
947+
except (TypeError, ValueError):
948+
pass
949+
return value
950+
951+
952+
def _build_hover_text(
    row: dict,
    performance_metric_x: str,
    performance_metric_y: str,
    stratified_by: str,
    perf_dat_type: str,
) -> str:
    """Build the HTML hover text for one row of performance data.

    Values are read with ``row.get`` and formatted by ``_round_val``, so a
    missing or ``None`` entry is rendered as an empty string. NaN values are
    blanked out of the metric lines.

    Args:
        row: Mapping of column name to value for a single row.
        performance_metric_x: Metric on the x axis; its line is bolded.
        performance_metric_y: Metric on the y axis; its line is bolded. The
            value ``"net_benefit_interventions_avoided"`` selects the shorter
            interventions-avoided layout.
        stratified_by: When equal to ``"probability_threshold"``, the
            standard layout also shows net benefit and the threshold odds.
        perf_dat_type: Passed to ``_add_model_population_text`` to decide
            whether a model or population header is prepended.

    Returns:
        The hover text, with lines joined by ``<br>``.
    """
    numeric_types = (int, float, np.floating)

    def rounded(column: str):
        return _round_val(row.get(column))

    raw_probability_threshold = row.get("chosen_cutoff")
    raw_ppcr = row.get("ppcr")

    ppcr_percent = (
        _round_val(100 * raw_ppcr) if isinstance(raw_ppcr, numeric_types) else ""
    )

    # Odds are (1 - p) / p, so they are undefined for a zero or non-numeric
    # threshold; a NaN threshold yields non-finite odds and is skipped too.
    odds = None
    if (
        isinstance(raw_probability_threshold, numeric_types)
        and raw_probability_threshold != 0
    ):
        odds = _round_val(
            (1 - raw_probability_threshold) / raw_probability_threshold, 2
        )
    show_odds = odds is not None and math.isfinite(float(odds))

    threshold_line = f"Prob. Threshold: {_round_val(raw_probability_threshold)}"
    odds_line = f"Odds of Prob. Threshold: 1:{odds}"
    net_benefit_line = f"NB: {rounded('net_benefit')}"
    predicted_positives_line = (
        f"Predicted Positives: {rounded('predicted_positives')} ({ppcr_percent}%)"
    )

    if performance_metric_y == "net_benefit_interventions_avoided":
        text_lines = [threshold_line]
        if show_odds:
            text_lines.append(odds_line)
        text_lines.extend(
            [
                "Interventions Avoided (per 100): "
                f"{rounded('net_benefit_interventions_avoided')}",
                net_benefit_line,
                predicted_positives_line,
                f"TN: {rounded('true_negatives')}",
                f"FN: {rounded('false_negatives')}",
            ]
        )
    else:
        text_lines = [
            threshold_line,
            f"Sensitivity: {rounded('sensitivity')}",
            f"1 - Specificity (FPR): {rounded('false_positive_rate')}",
            f"Specificity: {rounded('specificity')}",
            f"Lift: {rounded('lift')}",
            f"PPV: {rounded('ppv')}",
            f"NPV: {rounded('npv')}",
        ]
        if stratified_by == "probability_threshold":
            text_lines.append(net_benefit_line)
            if show_odds:
                text_lines.append(odds_line)
        text_lines.extend(
            [
                predicted_positives_line,
                f"TP: {rounded('true_positives')}",
                f"TN: {rounded('true_negatives')}",
                f"FP: {rounded('false_positives')}",
                f"FN: {rounded('false_negatives')}",
            ]
        )

    text = "<br>".join(text_lines)
    text = _bold_hover_metrics(text, [performance_metric_x, performance_metric_y])
    # Blank out NaN values BEFORE the model/population header is prepended:
    # stripping "nan" from the finished text would also mangle any name that
    # happens to contain it (e.g. "Financial" -> "Ficial").
    text = text.replace("NaN", "").replace("nan", "")
    return _add_model_population_text(text, row, perf_dat_type)
1032+
1033+
1034+
def _add_hover_text_to_performance_data(
    performance_data: pl.DataFrame,
    performance_metric_x: str,
    performance_metric_y: str,
    stratified_by: str,
    perf_dat_type: str,
) -> pl.DataFrame:
    """Add a ``text`` hover column and round float columns to 3 decimals.

    Each row is packed into a struct of all columns and handed, as a dict,
    to ``_build_hover_text``. Both expressions run in the same
    ``with_columns`` call.

    Args:
        performance_data: Frame with one row per point to be plotted.
        performance_metric_x: Metric on the x axis.
        performance_metric_y: Metric on the y axis.
        stratified_by: Stratification variable, forwarded unchanged.
        perf_dat_type: Performance data type, forwarded unchanged.

    Returns:
        ``performance_data`` with its float columns rounded and a string
        ``text`` column added (replacing any existing ``text`` column).
    """
    # Function-scope import: selectors replace the top-level dtype group
    # ``pl.FLOAT_DTYPES``, which newer Polars releases deprecate.
    import polars.selectors as cs

    def hover_text_for_row(row: dict) -> str:
        return _build_hover_text(
            row,
            performance_metric_x=performance_metric_x,
            performance_metric_y=performance_metric_y,
            stratified_by=stratified_by,
            perf_dat_type=perf_dat_type,
        )

    hover_text_expr = pl.struct(performance_data.columns).map_elements(
        hover_text_for_row,
        return_dtype=pl.Utf8,
    )

    return performance_data.with_columns(
        [cs.float().round(3), hover_text_expr.alias("text")]
    )
1055+
1056+
8901057
def _create_rtichoke_curve_list_binary(
8911058
performance_data: pl.DataFrame,
8921059
stratified_by: str,
@@ -904,10 +1071,6 @@ def _create_rtichoke_curve_list_binary(
9041071

9051072
x_metric, y_metric, x_label, y_label = _CURVE_CONFIG[curve]
9061073

907-
performance_data_ready_for_curve = _select_and_rename_necessary_variables(
908-
performance_data.sort("chosen_cutoff"), x_metric, y_metric
909-
)
910-
9111074
aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data(
9121075
performance_data
9131076
)
@@ -916,6 +1079,28 @@ def _create_rtichoke_curve_list_binary(
9161079
aj_estimates_from_performance_data
9171080
)
9181081

1082+
multiple_models = _check_if_multiple_models_are_being_validated(
1083+
aj_estimates_from_performance_data
1084+
)
1085+
1086+
perf_dat_type = _infer_performance_data_type(
1087+
aj_estimates_from_performance_data, multiple_populations
1088+
)
1089+
1090+
multiple_reference_groups = multiple_populations or multiple_models
1091+
1092+
performance_data_with_hover_text = _add_hover_text_to_performance_data(
1093+
performance_data.sort("chosen_cutoff"),
1094+
performance_metric_x=x_metric,
1095+
performance_metric_y=y_metric,
1096+
stratified_by=stratified_by,
1097+
perf_dat_type=perf_dat_type,
1098+
)
1099+
1100+
performance_data_ready_for_curve = _select_and_rename_necessary_variables(
1101+
performance_data_with_hover_text, x_metric, y_metric
1102+
)
1103+
9191104
reference_data = _create_reference_lines_data(
9201105
curve=curve,
9211106
aj_estimates_from_performance_data=aj_estimates_from_performance_data,
@@ -976,7 +1161,9 @@ def _create_rtichoke_curve_list_binary(
9761161
]
9771162
},
9781163
**{
979-
variant_key: (palette[group_index] if multiple_populations else "#000000")
1164+
variant_key: (
1165+
palette[group_index] if multiple_reference_groups else "#000000"
1166+
)
9801167
for group_index, reference_group in enumerate(reference_group_keys)
9811168
for variant_key in [
9821169
reference_group,
@@ -999,7 +1186,7 @@ def _create_rtichoke_curve_list_binary(
9991186
"reference_data": reference_data,
10001187
"cutoffs": cutoffs,
10011188
"colors_dictionary": colors_dictionary,
1002-
"multiple_populations": multiple_populations,
1189+
"multiple_reference_groups": multiple_reference_groups,
10031190
}
10041191

10051192
return rtichoke_curve_list
@@ -1013,6 +1200,7 @@ def _select_and_rename_necessary_variables(
10131200
pl.col("chosen_cutoff"),
10141201
pl.col(x_perf_metric).alias("x"),
10151202
pl.col(y_perf_metric).alias("y"),
1203+
pl.col("text"),
10161204
)
10171205

10181206

@@ -1047,13 +1235,22 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10471235
y=rtichoke_curve_list["performance_data_ready_for_curve"]
10481236
.filter(pl.col("reference_group") == group)["y"]
10491237
.to_list(),
1238+
text=rtichoke_curve_list["performance_data_ready_for_curve"].filter(
1239+
pl.col("reference_group") == group
1240+
)["text"],
10501241
mode="markers+lines",
10511242
name=group,
10521243
legendgroup=group,
10531244
line={
10541245
"width": 2,
10551246
"color": rtichoke_curve_list["colors_dictionary"].get(group),
10561247
},
1248+
hoverlabel=dict(
1249+
bgcolor=rtichoke_curve_list["colors_dictionary"].get(group),
1250+
bordercolor=rtichoke_curve_list["colors_dictionary"].get(group),
1251+
font_color="white",
1252+
),
1253+
hoverinfo="text",
10571254
showlegend=True,
10581255
)
10591256
for group in rtichoke_curve_list["reference_group_keys"]
@@ -1068,15 +1265,20 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10681265
"size": 12,
10691266
"color": (
10701267
rtichoke_curve_list["colors_dictionary"].get(group)
1071-
if rtichoke_curve_list["multiple_populations"]
1268+
if rtichoke_curve_list["multiple_reference_groups"]
10721269
else "#f6e3be"
10731270
),
10741271
"line": {"width": 3, "color": "black"},
10751272
},
10761273
name=f"{group} @ cutoff",
10771274
legendgroup=group,
1275+
hoverlabel=dict(
1276+
bgcolor=rtichoke_curve_list["colors_dictionary"].get(group),
1277+
bordercolor=rtichoke_curve_list["colors_dictionary"].get(group),
1278+
font_color="white",
1279+
),
10781280
showlegend=False,
1079-
hovertemplate=f"{group}<br>x=%{{x:.4f}}<br>y=%{{y:.4f}}<extra></extra>",
1281+
hoverinfo="text",
10801282
)
10811283
for group in rtichoke_curve_list["reference_group_keys"]
10821284
]
@@ -1097,6 +1299,11 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
10971299
color=rtichoke_curve_list["colors_dictionary"].get(group),
10981300
width=1.5,
10991301
),
1302+
hoverlabel=dict(
1303+
bgcolor=rtichoke_curve_list["colors_dictionary"].get(group),
1304+
bordercolor=rtichoke_curve_list["colors_dictionary"].get(group),
1305+
font_color="white",
1306+
),
11001307
hoverinfo="text",
11011308
text=rtichoke_curve_list["reference_data"]
11021309
.filter(pl.col("reference_group") == group)["text"]
@@ -1183,7 +1390,7 @@ def _create_curve_layout(
11831390
curve_layout = {
11841391
"xaxis": {"showgrid": False},
11851392
"yaxis": {"showgrid": False},
1186-
"template": "none",
1393+
"template": "plotly",
11871394
"plot_bgcolor": "rgba(0, 0, 0, 0)",
11881395
"paper_bgcolor": "rgba(0, 0, 0, 0)",
11891396
"showlegend": True,
@@ -1198,7 +1405,7 @@ def _create_curve_layout(
11981405
},
11991406
"height": size + 50,
12001407
"width": size,
1201-
"hoverlabel": {"bgcolor": "rgba(0,0,0,0)", "bordercolor": "rgba(0,0,0,0)"},
1408+
# "hoverlabel": {"bgcolor": "rgba(0,0,0,0)", "bordercolor": "rgba(0,0,0,0)"},
12021409
"updatemenus": [
12031410
{
12041411
"type": "buttons",

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)