
Commit e61d62b

Merge pull request #230 from uriahf/199-create-create_rtichoke_curve_list_times-function
199: create create_rtichoke_curve_list_times function
2 parents 7a582ff + 6836add commit e61d62b

File tree

6 files changed: +255 / -22 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ dependencies = [
     "pyarrow>=21.0.0",
 ]
 name = "rtichoke"
-version = "0.1.22"
+version = "0.1.23"
 description = "interactive visualizations for performance of predictive models"
 readme = "README.md"

src/rtichoke/helpers/plotly_helper_functions.py

Lines changed: 187 additions & 0 deletions
@@ -898,6 +898,38 @@ def _get_aj_estimates_from_performance_data(
     )
 
 
+def _get_aj_estimates_from_performance_data_times(
+    performance_data: pl.DataFrame,
+) -> pl.DataFrame:
+    return (
+        performance_data.filter(
+            (pl.col("chosen_cutoff") == 0) | (pl.col("chosen_cutoff") == 1)
+        )
+        .select("reference_group", "fixed_time_horizon", "real_positives", "n")
+        .unique()
+        .with_columns((pl.col("real_positives") / pl.col("n")).alias("aj_estimate"))
+        .select(
+            pl.col("reference_group"),
+            pl.col("fixed_time_horizon"),
+            pl.col("aj_estimate"),
+        )
+        .sort(by=["reference_group", "fixed_time_horizon"])
+    )
+
+
+def _check_if_multiple_populations_are_being_validated_times(
+    aj_estimates: pl.DataFrame,
+) -> bool:
+    max_val = (
+        aj_estimates.group_by("fixed_time_horizon")
+        .agg(pl.col("aj_estimate").n_unique().alias("num_populations"))[
+            "num_populations"
+        ]
+        .max()
+    )
+    return max_val is not None and max_val > 1
+
+
 def _check_if_multiple_populations_are_being_validated(
     aj_estimates: pl.DataFrame,
 ) -> bool:
@@ -908,6 +940,21 @@ def _check_if_multiple_models_are_being_validated(aj_estimates: pl.DataFrame) ->
     return aj_estimates["reference_group"].unique().len() > 1
 
 
+def _infer_performance_data_type_times(
+    aj_estimates_from_performance_data: pl.DataFrame, multiple_populations: bool
+) -> str:
+    multiple_models = _check_if_multiple_populations_are_being_validated_times(
+        aj_estimates_from_performance_data
+    )
+
+    if multiple_populations:
+        return "several populations"
+    elif multiple_models:
+        return "several models"
+    else:
+        return "single model"
+
+
 def _infer_performance_data_type(
     aj_estimates_from_performance_data: pl.DataFrame, multiple_populations: bool
 ) -> str:
@@ -1058,6 +1105,146 @@ def _add_hover_text_to_performance_data(
     )
 
 
+def _create_rtichoke_curve_list_times(
+    performance_data: pl.DataFrame,
+    stratified_by: str,
+    size: int = 500,
+    color_value=None,
+    curve="roc",
+    min_p_threshold=0,
+    max_p_threshold=1,
+) -> dict[str, Any]:
+    animation_slider_cutoff_prefix = (
+        "Prob. Threshold: "
+        if stratified_by == "probability_threshold"
+        else "Predicted Positives (Rate):"
+    )
+
+    x_metric, y_metric, x_label, y_label = _CURVE_CONFIG[curve]
+
+    aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data_times(
+        performance_data
+    )
+
+    print("aj_estimates_from_performance_data", aj_estimates_from_performance_data)
+
+    multiple_populations = _check_if_multiple_populations_are_being_validated_times(
+        aj_estimates_from_performance_data
+    )
+
+    multiple_models = _check_if_multiple_models_are_being_validated(
+        aj_estimates_from_performance_data
+    )
+
+    perf_dat_type = _infer_performance_data_type_times(
+        aj_estimates_from_performance_data, multiple_populations
+    )
+
+    multiple_reference_groups = multiple_populations or multiple_models
+
+    performance_data_with_hover_text = _add_hover_text_to_performance_data(
+        performance_data.sort("chosen_cutoff"),
+        performance_metric_x=x_metric,
+        performance_metric_y=y_metric,
+        stratified_by=stratified_by,
+        perf_dat_type=perf_dat_type,
+    )
+
+    performance_data_ready_for_curve = _select_and_rename_necessary_variables(
+        performance_data_with_hover_text, x_metric, y_metric
+    )
+
+    reference_data = _create_reference_lines_data(
+        curve=curve,
+        aj_estimates_from_performance_data=aj_estimates_from_performance_data,
+        multiple_populations=multiple_populations,
+        min_p_threshold=min_p_threshold,
+        max_p_threshold=max_p_threshold,
+    )
+
+    axes_ranges = extract_axes_ranges(
+        performance_data_ready_for_curve,
+        curve=curve,
+        min_p_threshold=min_p_threshold,
+        max_p_threshold=max_p_threshold,
+    )
+
+    reference_group_keys = performance_data["reference_group"].unique().to_list()
+
+    cutoffs = (
+        performance_data_ready_for_curve.select(pl.col("chosen_cutoff"))
+        .drop_nulls()
+        .unique()
+        .sort("chosen_cutoff")
+        .to_series()
+        .to_list()
+    )
+
+    palette = [
+        "#1b9e77",
+        "#d95f02",
+        "#7570b3",
+        "#e7298a",
+        "#07004D",
+        "#E6AB02",
+        "#FE5F55",
+        "#54494B",
+        "#006E90",
+        "#BC96E6",
+        "#52050A",
+        "#1F271B",
+        "#BE7C4D",
+        "#63768D",
+        "#08A045",
+        "#320A28",
+        "#82FF9E",
+        "#2176FF",
+        "#D1603D",
+        "#585123",
+    ]
+
+    colors_dictionary = {
+        **{
+            key: "#BEBEBE"
+            for key in [
+                "random_guess",
+                "perfect_model",
+                "treat_none",
+                "treat_all",
+            ]
+        },
+        **{
+            variant_key: (
+                palette[group_index] if multiple_reference_groups else "#000000"
+            )
+            for group_index, reference_group in enumerate(reference_group_keys)
+            for variant_key in [
+                reference_group,
+                f"random_guess_{reference_group}",
+                f"perfect_model_{reference_group}",
+                f"treat_none_{reference_group}",
+                f"treat_all_{reference_group}",
+            ]
+        },
+    }
+
+    rtichoke_curve_list = {
+        "size": size,
+        "axes_ranges": axes_ranges,
+        "x_label": x_label,
+        "y_label": y_label,
+        "animation_slider_cutoff_prefix": animation_slider_cutoff_prefix,
+        "reference_group_keys": reference_group_keys,
+        "performance_data_ready_for_curve": performance_data_ready_for_curve,
+        "reference_data": reference_data,
+        "cutoffs": cutoffs,
+        "colors_dictionary": colors_dictionary,
+        "multiple_reference_groups": multiple_reference_groups,
+    }
+
+    return rtichoke_curve_list
+
+
 def _create_rtichoke_curve_list_binary(
     performance_data: pl.DataFrame,
     stratified_by: str,

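For orientation, a minimal usage sketch of the two smaller helpers added above. The toy performance_data frame and its values are invented for illustration; only the column names (reference_group, fixed_time_horizon, chosen_cutoff, real_positives, n), the helper names, and the module path come from this diff, and the import assumes a build that includes this commit.

import polars as pl

from rtichoke.helpers.plotly_helper_functions import (
    _check_if_multiple_populations_are_being_validated_times,
    _get_aj_estimates_from_performance_data_times,
)

# Hypothetical performance data: two models evaluated on the same population
# at two fixed time horizons, shown here only at cutoff 0.
performance_data = pl.DataFrame(
    {
        "reference_group": ["model_a", "model_a", "model_b", "model_b"],
        "fixed_time_horizon": [1.0, 2.0, 1.0, 2.0],
        "chosen_cutoff": [0.0, 0.0, 0.0, 0.0],  # only rows at cutoff 0 or 1 are kept
        "real_positives": [30, 45, 30, 45],
        "n": [100, 100, 100, 100],
    }
)

# One aj_estimate (= real_positives / n) per reference_group and fixed_time_horizon.
aj_estimates = _get_aj_estimates_from_performance_data_times(performance_data)
print(aj_estimates)

# False here: the estimates agree across groups at every horizon, so the data
# describe several models on one population rather than several populations.
print(_check_if_multiple_populations_are_being_validated_times(aj_estimates))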
src/rtichoke/helpers/sandbox_observable_helpers.py

Lines changed: 60 additions & 14 deletions
@@ -912,7 +912,7 @@ def _create_list_data_to_adjust_binary(
     return list_data_to_adjust
 
 
-def create_list_data_to_adjust(
+def _create_list_data_to_adjust(
     aj_data_combinations: pl.DataFrame,
     probs_dict: Dict[str, np.ndarray],
     reals_dict: Union[np.ndarray, Dict[str, np.ndarray]],
@@ -922,25 +922,70 @@ def create_list_data_to_adjust(
 ) -> Dict[str, pl.DataFrame]:
     # reference_groups = list(probs_dict.keys())
     reference_group_labels = list(probs_dict.keys())
-    num_reals = len(reals_dict)
+
+    if isinstance(reals_dict, dict):
+        num_keys_reals = len(reals_dict)
+    else:
+        num_keys_reals = 1
+
+    # num_reals = len(reals_dict)
 
     reference_group_enum = pl.Enum(reference_group_labels)
 
     strata_enum_dtype = aj_data_combinations.schema["strata"]
 
-    # Flatten and ensure list format
-    data_to_adjust = pl.DataFrame(
-        {
-            "reference_group": np.repeat(reference_group_labels, num_reals),
-            "probs": np.concatenate(
-                [probs_dict[group] for group in reference_group_labels]
-            ),
-            "reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)),
-            "times": np.tile(np.asarray(times_dict), len(reference_group_labels)),
-        }
-    ).with_columns(pl.col("reference_group").cast(reference_group_enum))
+    if len(probs_dict) == 1:
+        probs_array = np.asarray(probs_dict[reference_group_labels[0]])
+
+        if isinstance(reals_dict, dict):
+            reals_array = np.asarray(reals_dict[0])
+        else:
+            reals_array = np.asarray(reals_dict)
+
+        if isinstance(times_dict, dict):
+            times_array = np.asarray(times_dict[0])
+        else:
+            times_array = np.asarray(times_dict)
+
+        data_to_adjust = pl.DataFrame(
+            {
+                "reference_group": np.repeat(reference_group_labels, len(probs_array)),
+                "probs": probs_array,
+                "reals": reals_array,
+                "times": times_array,
+            }
+        ).with_columns(pl.col("reference_group").cast(reference_group_enum))
+
+    elif num_keys_reals == 1:
+        reals_array = np.asarray(reals_dict)
+        times_array = np.asarray(times_dict)
+        n = len(reals_array)
+
+        data_to_adjust = pl.DataFrame(
+            {
+                "reference_group": np.repeat(reference_group_labels, n),
+                "probs": np.concatenate(
+                    [np.asarray(probs_dict[g]) for g in reference_group_labels]
+                ),
+                "reals": np.tile(reals_array, len(reference_group_labels)),
+                "times": np.tile(times_array, len(reference_group_labels)),
+            }
+        ).with_columns(pl.col("reference_group").cast(reference_group_enum))
+
+    elif isinstance(reals_dict, dict) and isinstance(times_dict, dict):
+        data_to_adjust = (
+            pl.DataFrame(
+                {
+                    "reference_group": reference_group_labels,
+                    "probs": list(probs_dict.values()),
+                    "reals": list(reals_dict.values()),
+                    "times": list(times_dict.values()),
+                }
+            )
+            .explode(["probs", "reals", "times"])
+            .with_columns(pl.col("reference_group").cast(reference_group_enum))
+        )
 
-    # Apply strata
     data_to_adjust = add_cutoff_strata(
         data_to_adjust, by=by, stratified_by=stratified_by
     )
@@ -1637,6 +1682,7 @@ def _calculate_cumulative_aj_data(aj_data: pl.DataFrame) -> pl.DataFrame:
         )
         .agg([pl.col("reals_estimate").sum()])
        .pivot(on="classification_outcome", values="reals_estimate")
+        .fill_null(0)
         .with_columns(
             (pl.col("true_positives") + pl.col("false_positives")).alias(
                 "predicted_positives"

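As a standalone illustration of the new dict-per-reference-group branch in _create_list_data_to_adjust, the snippet below reproduces only that reshaping step with invented arrays; it does not call the helper itself, which additionally expects aj_data_combinations and applies add_cutoff_strata afterwards.

import numpy as np
import polars as pl

# Hypothetical inputs: one entry per model, with per-model reals and times.
probs_dict = {"model_a": np.array([0.2, 0.7, 0.9]), "model_b": np.array([0.1, 0.5, 0.8])}
reals_dict = {"model_a": np.array([0, 1, 1]), "model_b": np.array([0, 0, 1])}
times_dict = {"model_a": np.array([5.0, 3.0, 1.0]), "model_b": np.array([4.0, 2.0, 1.5])}

reference_group_labels = list(probs_dict.keys())
reference_group_enum = pl.Enum(reference_group_labels)

# One row per reference group holding list columns, exploded into a long frame
# with one observation per row, exactly as in the third branch of the diff above.
data_to_adjust = (
    pl.DataFrame(
        {
            "reference_group": reference_group_labels,
            "probs": list(probs_dict.values()),
            "reals": list(reals_dict.values()),
            "times": list(times_dict.values()),
        }
    )
    .explode(["probs", "reals", "times"])
    .with_columns(pl.col("reference_group").cast(reference_group_enum))
)
print(data_to_adjust)  # 6 rows: 3 observations for each of the 2 reference groups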
src/rtichoke/performance_data/performance_data_times.py

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@
 from rtichoke.helpers.sandbox_observable_helpers import (
     create_breaks_values,
     create_aj_data_combinations,
-    create_list_data_to_adjust,
+    _create_list_data_to_adjust,
     create_adjusted_data,
     cast_and_join_adjusted_data,
     _calculate_cumulative_aj_data,
@@ -154,7 +154,7 @@ def prepare_binned_classification_data_times(
         risk_set_scope=risk_set_scope,
     )
 
-    list_data_to_adjust = create_list_data_to_adjust(
+    list_data_to_adjust = _create_list_data_to_adjust(
         aj_data_combinations,
         probs,
         reals,
@@ -175,6 +175,6 @@ def prepare_binned_classification_data_times(
     final_adjusted_data = cast_and_join_adjusted_data(
         aj_data_combinations,
         adjusted_data,
-    )
+    ).with_columns(pl.col("reals_estimate").fill_null(0.0))
 
     return final_adjusted_data

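The appended fill_null(0.0) guards against combinations that end up with no matching adjusted rows after cast_and_join_adjusted_data. A toy sketch of the effect, assuming the join leaves unmatched combinations null the way a left join would (the two frames below are invented):

import polars as pl

combinations = pl.DataFrame({"strata": ["a", "b", "c"]})
adjusted = pl.DataFrame({"strata": ["a", "b"], "reals_estimate": [2.0, 1.0]})

final_adjusted_data = combinations.join(adjusted, on="strata", how="left").with_columns(
    pl.col("reals_estimate").fill_null(0.0)
)
print(final_adjusted_data)  # strata "c" now carries 0.0 instead of a null estimate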
src/rtichoke/summary_report/summary_report.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 
 from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r
 from rtichoke.helpers.sandbox_observable_helpers import (
-    create_list_data_to_adjust,
+    _create_list_data_to_adjust,
 )
 import subprocess
 
@@ -67,8 +67,8 @@ def create_data_for_summary_report(probs, reals, times, fixed_time_horizons):
     stratified_by = ["probability_threshold", "ppcr"]
     by = 0.1
 
-    list_data_to_adjust_polars = create_list_data_to_adjust(
-        probs, reals, times, stratified_by=stratified_by, by=by
+    list_data_to_adjust_polars = _create_list_data_to_adjust(
+        probs, reals, times, stratified_by=stratified_by, by=by, times_dict={}
     )
 
     return list_data_to_adjust_polars

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.
