@@ -898,6 +898,38 @@ def _get_aj_estimates_from_performance_data(
898898 )
899899
900900
901+ def _get_aj_estimates_from_performance_data_times (
902+ performance_data : pl .DataFrame ,
903+ ) -> pl .DataFrame :
904+ return (
905+ performance_data .filter (
906+ (pl .col ("chosen_cutoff" ) == 0 ) | (pl .col ("chosen_cutoff" ) == 1 )
907+ )
908+ .select ("reference_group" , "fixed_time_horizon" , "real_positives" , "n" )
909+ .unique ()
910+ .with_columns ((pl .col ("real_positives" ) / pl .col ("n" )).alias ("aj_estimate" ))
911+ .select (
912+ pl .col ("reference_group" ),
913+ pl .col ("fixed_time_horizon" ),
914+ pl .col ("aj_estimate" ),
915+ )
916+ .sort (by = ["reference_group" , "fixed_time_horizon" ])
917+ )
918+
919+
920+ def _check_if_multiple_populations_are_being_validated_times (
921+ aj_estimates : pl .DataFrame ,
922+ ) -> bool :
923+ max_val = (
924+ aj_estimates .group_by ("fixed_time_horizon" )
925+ .agg (pl .col ("aj_estimate" ).n_unique ().alias ("num_populations" ))[
926+ "num_populations"
927+ ]
928+ .max ()
929+ )
930+ return max_val is not None and max_val > 1
931+
932+
901933def _check_if_multiple_populations_are_being_validated (
902934 aj_estimates : pl .DataFrame ,
903935) -> bool :
@@ -908,6 +940,21 @@ def _check_if_multiple_models_are_being_validated(aj_estimates: pl.DataFrame) ->
908940 return aj_estimates ["reference_group" ].unique ().len () > 1
909941
910942
943+ def _infer_performance_data_type_times (
944+ aj_estimates_from_performance_data : pl .DataFrame , multiple_populations : bool
945+ ) -> str :
946+ multiple_models = _check_if_multiple_populations_are_being_validated_times (
947+ aj_estimates_from_performance_data
948+ )
949+
950+ if multiple_populations :
951+ return "several populations"
952+ elif multiple_models :
953+ return "several models"
954+ else :
955+ return "single model"
956+
957+
911958def _infer_performance_data_type (
912959 aj_estimates_from_performance_data : pl .DataFrame , multiple_populations : bool
913960) -> str :
@@ -1058,6 +1105,146 @@ def _add_hover_text_to_performance_data(
10581105 )
10591106
10601107
1108+ def _create_rtichoke_curve_list_times (
1109+ performance_data : pl .DataFrame ,
1110+ stratified_by : str ,
1111+ size : int = 500 ,
1112+ color_value = None ,
1113+ curve = "roc" ,
1114+ min_p_threshold = 0 ,
1115+ max_p_threshold = 1 ,
1116+ ) -> dict [str , Any ]:
1117+ animation_slider_cutoff_prefix = (
1118+ "Prob. Threshold: "
1119+ if stratified_by == "probability_threshold"
1120+ else "Predicted Positives (Rate):"
1121+ )
1122+
1123+ x_metric , y_metric , x_label , y_label = _CURVE_CONFIG [curve ]
1124+
1125+ aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data_times (
1126+ performance_data
1127+ )
1128+
1129+ print ("aj_estimates_from_performance_data" , aj_estimates_from_performance_data )
1130+
1131+ multiple_populations = _check_if_multiple_populations_are_being_validated_times (
1132+ aj_estimates_from_performance_data
1133+ )
1134+
1135+ multiple_models = _check_if_multiple_models_are_being_validated (
1136+ aj_estimates_from_performance_data
1137+ )
1138+
1139+ perf_dat_type = _infer_performance_data_type_times (
1140+ aj_estimates_from_performance_data , multiple_populations
1141+ )
1142+
1143+ multiple_reference_groups = multiple_populations or multiple_models
1144+
1145+ performance_data_with_hover_text = _add_hover_text_to_performance_data (
1146+ performance_data .sort ("chosen_cutoff" ),
1147+ performance_metric_x = x_metric ,
1148+ performance_metric_y = y_metric ,
1149+ stratified_by = stratified_by ,
1150+ perf_dat_type = perf_dat_type ,
1151+ )
1152+
1153+ performance_data_ready_for_curve = _select_and_rename_necessary_variables (
1154+ performance_data_with_hover_text , x_metric , y_metric
1155+ )
1156+
1157+ reference_data = _create_reference_lines_data (
1158+ curve = curve ,
1159+ aj_estimates_from_performance_data = aj_estimates_from_performance_data ,
1160+ multiple_populations = multiple_populations ,
1161+ min_p_threshold = min_p_threshold ,
1162+ max_p_threshold = max_p_threshold ,
1163+ )
1164+
1165+ axes_ranges = extract_axes_ranges (
1166+ performance_data_ready_for_curve ,
1167+ curve = curve ,
1168+ min_p_threshold = min_p_threshold ,
1169+ max_p_threshold = max_p_threshold ,
1170+ )
1171+
1172+ reference_group_keys = performance_data ["reference_group" ].unique ().to_list ()
1173+
1174+ cutoffs = (
1175+ performance_data_ready_for_curve .select (pl .col ("chosen_cutoff" ))
1176+ .drop_nulls ()
1177+ .unique ()
1178+ .sort ("chosen_cutoff" )
1179+ .to_series ()
1180+ .to_list ()
1181+ )
1182+
1183+ palette = [
1184+ "#1b9e77" ,
1185+ "#d95f02" ,
1186+ "#7570b3" ,
1187+ "#e7298a" ,
1188+ "#07004D" ,
1189+ "#E6AB02" ,
1190+ "#FE5F55" ,
1191+ "#54494B" ,
1192+ "#006E90" ,
1193+ "#BC96E6" ,
1194+ "#52050A" ,
1195+ "#1F271B" ,
1196+ "#BE7C4D" ,
1197+ "#63768D" ,
1198+ "#08A045" ,
1199+ "#320A28" ,
1200+ "#82FF9E" ,
1201+ "#2176FF" ,
1202+ "#D1603D" ,
1203+ "#585123" ,
1204+ ]
1205+
1206+ colors_dictionary = {
1207+ ** {
1208+ key : "#BEBEBE"
1209+ for key in [
1210+ "random_guess" ,
1211+ "perfect_model" ,
1212+ "treat_none" ,
1213+ "treat_all" ,
1214+ ]
1215+ },
1216+ ** {
1217+ variant_key : (
1218+ palette [group_index ] if multiple_reference_groups else "#000000"
1219+ )
1220+ for group_index , reference_group in enumerate (reference_group_keys )
1221+ for variant_key in [
1222+ reference_group ,
1223+ f"random_guess_{ reference_group } " ,
1224+ f"perfect_model_{ reference_group } " ,
1225+ f"treat_none_{ reference_group } " ,
1226+ f"treat_all_{ reference_group } " ,
1227+ ]
1228+ },
1229+ }
1230+
1231+ rtichoke_curve_list = {
1232+ "size" : size ,
1233+ "axes_ranges" : axes_ranges ,
1234+ "x_label" : x_label ,
1235+ "y_label" : y_label ,
1236+ "animation_slider_cutoff_prefix" : animation_slider_cutoff_prefix ,
1237+ "reference_group_keys" : reference_group_keys ,
1238+ "performance_data_ready_for_curve" : performance_data_ready_for_curve ,
1239+ "reference_data" : reference_data ,
1240+ "cutoffs" : cutoffs ,
1241+ "colors_dictionary" : colors_dictionary ,
1242+ "multiple_reference_groups" : multiple_reference_groups ,
1243+ }
1244+
1245+ return rtichoke_curve_list
1246+
1247+
10611248def _create_rtichoke_curve_list_binary (
10621249 performance_data : pl .DataFrame ,
10631250 stratified_by : str ,
0 commit comments