Skip to content

Commit 22b48f4

Browse files
authored
Merge pull request #208 from uriahf/195-update-functions-for-discrimination-plots
fix: add fill_null with 0 values
2 parents be0c3aa + 10a4d3a commit 22b48f4

File tree

3 files changed

+150
-27
lines changed

3 files changed

+150
-27
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ dependencies = [
1919
"pandas>=2.2.3",
2020
"typing>=3.7.4.3",
2121
"polarstate==0.1.8",
22+
"marimo>=0.17.0",
2223
]
2324
name = "rtichoke"
24-
version = "0.1.16"
25+
version = "0.1.17"
2526
description = "interactive visualizations for performance of predictive models"
2627
readme = "README.md"
2728

src/rtichoke/helpers/sandbox_observable_helpers.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,7 +1204,8 @@ def _cast_and_join_adjusted_data_binary(
12041204
)
12051205
).alias("classification_outcome")
12061206
)
1207-
)
1207+
).with_columns(pl.col("reals_estimate").fill_null(0))
1208+
12081209
return final_adjusted_data_polars
12091210

12101211

@@ -1562,6 +1563,17 @@ def _calculate_cumulative_aj_data_binary(aj_data: pl.DataFrame) -> pl.DataFrame:
15621563
)
15631564
.agg([pl.col("reals_estimate").sum()])
15641565
.pivot(on="classification_outcome", values="reals_estimate")
1566+
.with_columns(
1567+
[
1568+
pl.col(col).fill_null(0)
1569+
for col in [
1570+
"true_positives",
1571+
"true_negatives",
1572+
"false_positives",
1573+
"false_negatives",
1574+
]
1575+
]
1576+
)
15651577
.with_columns(
15661578
(pl.col("true_positives") + pl.col("false_positives")).alias(
15671579
"predicted_positives"
@@ -1678,8 +1690,11 @@ def _turn_cumulative_aj_to_performance_data(
16781690
(pl.col("true_negatives") / pl.col("real_negatives")).alias("specificity"),
16791691
(pl.col("true_positives") / pl.col("predicted_positives")).alias("ppv"),
16801692
(pl.col("true_negatives") / pl.col("predicted_negatives")).alias("npv"),
1693+
(pl.col("false_positives") / pl.col("real_negatives")).alias(
1694+
"false_positive_rate"
1695+
),
16811696
(
1682-
(pl.col("true_positives") / pl.col("real_positives"))
1697+
(pl.col("true_positives") / pl.col("predicted_positives"))
16831698
/ (pl.col("real_positives") / pl.col("n"))
16841699
).alias("lift"),
16851700
pl.when(pl.col("stratified_by") == "probability_threshold")
@@ -1692,6 +1707,15 @@ def _turn_cumulative_aj_to_performance_data(
16921707
.otherwise(None)
16931708
.alias("net_benefit"),
16941709
pl.when(pl.col("stratified_by") == "probability_threshold")
1710+
.then(
1711+
100 * (pl.col("true_negatives") / pl.col("n"))
1712+
- (pl.col("false_negatives") / pl.col("n"))
1713+
* (1 - pl.col("chosen_cutoff"))
1714+
/ pl.col("chosen_cutoff")
1715+
)
1716+
.otherwise(None)
1717+
.alias("net_benefit_interventions_avoided"),
1718+
pl.when(pl.col("stratified_by") == "probability_threshold")
16951719
.then(pl.col("predicted_positives") / pl.col("n"))
16961720
.otherwise(pl.col("chosen_cutoff"))
16971721
.alias("ppcr"),

0 commit comments

Comments
 (0)