Skip to content

Commit 838c033

Browse files
committed
fix: close #198
1 parent e40f815 commit 838c033

File tree

1 file changed

+40
-12
lines changed

1 file changed

+40
-12
lines changed

src/rtichoke/helpers/sandbox_observable_helpers.py

Lines changed: 40 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -824,21 +824,50 @@ def _create_list_data_to_adjust_binary(
824824
by,
825825
) -> Dict[str, pl.DataFrame]:
826826
reference_group_labels = list(probs_dict.keys())
827-
num_reals = len(reals_dict)
827+
828+
if isinstance(reals_dict, dict):
829+
num_keys_reals = len(reals_dict)
830+
else:
831+
num_keys_reals = 1
828832

829833
reference_group_enum = pl.Enum(reference_group_labels)
830834

831835
strata_enum_dtype = aj_data_combinations.schema["strata"]
832836

833-
data_to_adjust = pl.DataFrame(
834-
{
835-
"reference_group": np.repeat(reference_group_labels, num_reals),
836-
"probs": np.concatenate(
837-
[probs_dict[group] for group in reference_group_labels]
838-
),
839-
"reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)),
840-
}
841-
).with_columns(pl.col("reference_group").cast(reference_group_enum))
837+
if len(probs_dict) == 1:
838+
probs_array = np.asarray(probs_dict[reference_group_labels[0]])
839+
840+
data_to_adjust = pl.DataFrame(
841+
{
842+
"reference_group": np.repeat(reference_group_labels, len(probs_array)),
843+
"probs": probs_array,
844+
"reals": reals_dict,
845+
}
846+
).with_columns(pl.col("reference_group").cast(reference_group_enum))
847+
848+
elif num_keys_reals == 1:
849+
data_to_adjust = pl.DataFrame(
850+
{
851+
"reference_group": np.repeat(reference_group_labels, len(reals_dict)),
852+
"probs": np.concatenate(
853+
[probs_dict[group] for group in reference_group_labels]
854+
),
855+
"reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)),
856+
}
857+
).with_columns(pl.col("reference_group").cast(reference_group_enum))
858+
859+
elif isinstance(reals_dict, dict):
860+
data_to_adjust = (
861+
pl.DataFrame(
862+
{
863+
"reference_group": list(probs_dict.keys()),
864+
"probs": list(probs_dict.values()),
865+
"reals": list(reals_dict.values()),
866+
}
867+
)
868+
.explode(["probs", "reals"])
869+
.with_columns(pl.col("reference_group").cast(reference_group_enum))
870+
)
842871

843872
data_to_adjust = add_cutoff_strata(
844873
data_to_adjust, by=by, stratified_by=stratified_by
@@ -873,7 +902,6 @@ def _create_list_data_to_adjust_binary(
873902
.alias("reals_labels")
874903
)
875904

876-
# Partition by reference_group
877905
list_data_to_adjust = {
878906
group[0]: df
879907
for group, df in data_to_adjust.partition_by(
@@ -1029,7 +1057,7 @@ def _create_adjusted_data_binary(
10291057

10301058
adjusted_data_binary = (
10311059
long_df.group_by(["strata", "stratified_by", "reference_group", "reals_labels"])
1032-
.agg(pl.sum("reals").alias("reals_estimate"))
1060+
.agg(pl.count().alias("reals_estimate"))
10331061
.join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross")
10341062
)
10351063

0 commit comments

Comments
 (0)