@@ -900,12 +900,23 @@ def _check_if_multiple_populations_are_being_validated(
900900 return aj_estimates ["aj_estimate" ].unique ().len () > 1
901901
902902
903- def _infer_performance_data_type (performance_data : pl .DataFrame ) -> str :
904- if "model" in performance_data .columns :
905- return "several models"
906- if "population" in performance_data .columns :
903+ def _check_if_multiple_models_are_being_validated (aj_estimates : pl .DataFrame ) -> bool :
904+ return aj_estimates ["reference_group" ].unique ().len () > 1
905+
906+
907+ def _infer_performance_data_type (
908+ aj_estimates_from_performance_data : pl .DataFrame , multiple_populations : bool
909+ ) -> str :
910+ multiple_models = _check_if_multiple_models_are_being_validated (
911+ aj_estimates_from_performance_data
912+ )
913+
914+ if multiple_populations :
907915 return "several populations"
908- return "single model"
916+ elif multiple_models :
917+ return "several models"
918+ else :
919+ return "single model"
909920
910921
911922def _bold_hover_metrics (text : str , metrics : Sequence [str ]) -> str :
@@ -1060,7 +1071,23 @@ def _create_rtichoke_curve_list_binary(
10601071
10611072 x_metric , y_metric , x_label , y_label = _CURVE_CONFIG [curve ]
10621073
1063- perf_dat_type = _infer_performance_data_type (performance_data )
1074+ aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data (
1075+ performance_data
1076+ )
1077+
1078+ multiple_populations = _check_if_multiple_populations_are_being_validated (
1079+ aj_estimates_from_performance_data
1080+ )
1081+
1082+ multiple_models = _check_if_multiple_models_are_being_validated (
1083+ aj_estimates_from_performance_data
1084+ )
1085+
1086+ perf_dat_type = _infer_performance_data_type (
1087+ aj_estimates_from_performance_data , multiple_populations
1088+ )
1089+
1090+ multiple_reference_groups = multiple_populations or multiple_models
10641091
10651092 performance_data_with_hover_text = _add_hover_text_to_performance_data (
10661093 performance_data .sort ("chosen_cutoff" ),
@@ -1074,14 +1101,6 @@ def _create_rtichoke_curve_list_binary(
10741101 performance_data_with_hover_text , x_metric , y_metric
10751102 )
10761103
1077- aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data (
1078- performance_data
1079- )
1080-
1081- multiple_populations = _check_if_multiple_populations_are_being_validated (
1082- aj_estimates_from_performance_data
1083- )
1084-
10851104 reference_data = _create_reference_lines_data (
10861105 curve = curve ,
10871106 aj_estimates_from_performance_data = aj_estimates_from_performance_data ,
@@ -1142,7 +1161,9 @@ def _create_rtichoke_curve_list_binary(
11421161 ]
11431162 },
11441163 ** {
1145- variant_key : (palette [group_index ] if multiple_populations else "#000000" )
1164+ variant_key : (
1165+ palette [group_index ] if multiple_reference_groups else "#000000"
1166+ )
11461167 for group_index , reference_group in enumerate (reference_group_keys )
11471168 for variant_key in [
11481169 reference_group ,
@@ -1165,7 +1186,7 @@ def _create_rtichoke_curve_list_binary(
11651186 "reference_data" : reference_data ,
11661187 "cutoffs" : cutoffs ,
11671188 "colors_dictionary" : colors_dictionary ,
1168- "multiple_populations " : multiple_populations ,
1189+ "multiple_reference_groups " : multiple_reference_groups ,
11691190 }
11701191
11711192 return rtichoke_curve_list
@@ -1244,7 +1265,7 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
12441265 "size" : 12 ,
12451266 "color" : (
12461267 rtichoke_curve_list ["colors_dictionary" ].get (group )
1247- if rtichoke_curve_list ["multiple_populations " ]
1268+ if rtichoke_curve_list ["multiple_reference_groups " ]
12481269 else "#f6e3be"
12491270 ),
12501271 "line" : {"width" : 3 , "color" : "black" },
0 commit comments