@@ -824,21 +824,50 @@ def _create_list_data_to_adjust_binary(
824824 by ,
825825) -> Dict [str , pl .DataFrame ]:
826826 reference_group_labels = list (probs_dict .keys ())
827- num_reals = len (reals_dict )
827+
828+ if isinstance (reals_dict , dict ):
829+ num_keys_reals = len (reals_dict )
830+ else :
831+ num_keys_reals = 1
828832
829833 reference_group_enum = pl .Enum (reference_group_labels )
830834
831835 strata_enum_dtype = aj_data_combinations .schema ["strata" ]
832836
833- data_to_adjust = pl .DataFrame (
834- {
835- "reference_group" : np .repeat (reference_group_labels , num_reals ),
836- "probs" : np .concatenate (
837- [probs_dict [group ] for group in reference_group_labels ]
838- ),
839- "reals" : np .tile (np .asarray (reals_dict ), len (reference_group_labels )),
840- }
841- ).with_columns (pl .col ("reference_group" ).cast (reference_group_enum ))
837+ if len (probs_dict ) == 1 :
838+ probs_array = np .asarray (probs_dict [reference_group_labels [0 ]])
839+
840+ data_to_adjust = pl .DataFrame (
841+ {
842+ "reference_group" : np .repeat (reference_group_labels , len (probs_array )),
843+ "probs" : probs_array ,
844+ "reals" : reals_dict ,
845+ }
846+ ).with_columns (pl .col ("reference_group" ).cast (reference_group_enum ))
847+
848+ elif num_keys_reals == 1 :
849+ data_to_adjust = pl .DataFrame (
850+ {
851+ "reference_group" : np .repeat (reference_group_labels , len (reals_dict )),
852+ "probs" : np .concatenate (
853+ [probs_dict [group ] for group in reference_group_labels ]
854+ ),
855+ "reals" : np .tile (np .asarray (reals_dict ), len (reference_group_labels )),
856+ }
857+ ).with_columns (pl .col ("reference_group" ).cast (reference_group_enum ))
858+
859+ elif isinstance (reals_dict , dict ):
860+ data_to_adjust = (
861+ pl .DataFrame (
862+ {
863+ "reference_group" : list (probs_dict .keys ()),
864+ "probs" : list (probs_dict .values ()),
865+ "reals" : list (reals_dict .values ()),
866+ }
867+ )
868+ .explode (["probs" , "reals" ])
869+ .with_columns (pl .col ("reference_group" ).cast (reference_group_enum ))
870+ )
842871
843872 data_to_adjust = add_cutoff_strata (
844873 data_to_adjust , by = by , stratified_by = stratified_by
@@ -873,7 +902,6 @@ def _create_list_data_to_adjust_binary(
873902 .alias ("reals_labels" )
874903 )
875904
876- # Partition by reference_group
877905 list_data_to_adjust = {
878906 group [0 ]: df
879907 for group , df in data_to_adjust .partition_by (
@@ -1029,7 +1057,7 @@ def _create_adjusted_data_binary(
10291057
10301058 adjusted_data_binary = (
10311059 long_df .group_by (["strata" , "stratified_by" , "reference_group" , "reals_labels" ])
1032- .agg (pl .sum ( "reals" ).alias ("reals_estimate" ))
1060+ .agg (pl .count ( ).alias ("reals_estimate" ))
10331061 .join (pl .DataFrame ({"chosen_cutoff" : breaks }), how = "cross" )
10341062 )
10351063
0 commit comments