Skip to content

Commit 8bd57cb

Browse files
authored
Merge pull request #278 from uriahf/273-small-changes-for-better-calibration-plots
273 small changes for better calibration plots
2 parents 725e7d7 + 1c45c8a commit 8bd57cb

File tree

3 files changed

+3341
-1935
lines changed

3 files changed

+3341
-1935
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@ license = {text = "MIT"}
66
requires-python = ">=3.9"
77
dependencies = [
88
"plotly<6.0.0,>=5.13.1",
9-
"polars>=1.28.0",
109
"pandas>=2.2.3",
1110
"typing>=3.7.4.3",
1211
"polarstate==0.1.8",
1312
"marimo>=0.17.0",
1413
"pyarrow>=21.0.0",
1514
"statsmodels>=0.14.0",
15+
"polars>=1.31.0",
1616
]
1717
name = "rtichoke"
18-
version = "0.1.26"
18+
version = "0.1.27"
1919
description = "interactive visualizations for performance of predictive models"
2020
readme = "README.md"
2121

src/rtichoke/calibration/calibration.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,6 @@ def _create_plotly_curve_from_calibration_curve_list(
329329
pl.col("reference_group") == reference_group
330330
)
331331

332-
print(dec_sub)
333-
334332
calibration_curve.add_trace(
335333
go.Scatter(
336334
x=dec_sub.get_column("x").to_list(),
@@ -438,8 +436,6 @@ def _create_plotly_curve_from_calibration_curve_list(
438436
col=1,
439437
)
440438

441-
print(calibration_curve_list["axes_ranges"]["xaxis"])
442-
443439
calibration_curve.update_xaxes(
444440
zeroline=True,
445441
range=calibration_curve_list["axes_ranges"]["xaxis"],
@@ -458,10 +454,6 @@ def _create_plotly_curve_from_calibration_curve_list(
458454
)
459455
calibration_curve.update_yaxes(title="Observed", row=1, col=1)
460456

461-
print("size")
462-
print(calibration_curve_list["size"])
463-
print(calibration_curve_list["size"][0])
464-
465457
calibration_curve.update_layout(
466458
width=calibration_curve_list["size"][0][0],
467459
height=calibration_curve_list["size"][0][0],
@@ -534,18 +526,18 @@ def _make_deciles_dat_binary(
534526

535527
df = pl.concat(frames, how="vertical")
536528

537-
labels = [str(i) for i in range(1, n_bins + 1)]
538-
539529
df = df.with_columns(
540530
[
541531
pl.col("prob").cast(pl.Float64),
542532
pl.col("real").cast(pl.Float64),
543-
pl.col("prob")
544-
.qcut(n_bins, labels=labels, allow_duplicates=True)
545-
.over(["reference_group", "model"])
546-
.alias("decile"),
533+
(
534+
(pl.col("prob").rank("ordinal").over(["reference_group", "model"]) - 1)
535+
* n_bins
536+
// pl.count().over(["reference_group", "model"])
537+
+ 1
538+
).alias("decile"),
547539
]
548-
).with_columns(pl.col("decile").cast(pl.Int32))
540+
)
549541

550542
deciles_data = (
551543
df.group_by(["reference_group", "model", "decile"])
@@ -616,12 +608,8 @@ def _create_calibration_curve_list(
616608
reference_groups, color_values, performance_type
617609
)
618610

619-
print("histogram for calibration")
620-
621611
histogram_for_calibration = _create_histogram_for_calibration(probs)
622612

623-
print(histogram_for_calibration)
624-
625613
limits = _define_limits_for_calibration_plot(deciles_data)
626614
axes_ranges = {"xaxis": limits, "yaxis": limits}
627615

0 commit comments

Comments
 (0)