Skip to content

Commit 0731efe

Browse files
authored
Merge pull request #281 from uriahf/273-small-changes-for-better-calibration-plots
273 small changes for better calibration plots
2 parents 8bd57cb + 5686075 commit 0731efe

File tree

6 files changed: 53 additions and 33 deletions.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ dependencies = [
1515
"polars>=1.31.0",
1616
]
1717
name = "rtichoke"
18-
version = "0.1.27"
18+
version = "0.1.28"
1919
description = "interactive visualizations for performance of predictive models"
2020
readme = "README.md"
2121

src/rtichoke/calibration/calibration.py

Lines changed: 25 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
A module for Calibration Curves
33
"""
44

5-
from typing import Any, Dict, List, Union
5+
from typing import Any, Dict, List, Union, cast
66

77
# import pandas as pd
88
import plotly.graph_objects as go
@@ -247,6 +247,7 @@ def _create_plotly_curve_from_calibration_curve_list_times(
247247
},
248248
barmode="overlay",
249249
plot_bgcolor="rgba(0, 0, 0, 0)",
250+
paper_bgcolor="rgba(0, 0, 0, 0)",
250251
legend={
251252
"orientation": "h",
252253
"xanchor": "center",
@@ -285,6 +286,7 @@ def _create_plotly_curve_from_calibration_curve_list(
285286
"yaxis": {"showgrid": False},
286287
"barmode": "overlay",
287288
"plot_bgcolor": "rgba(0, 0, 0, 0)",
289+
"paper_bgcolor": "rgba(0, 0, 0, 0)",
288290
"legend": {
289291
"orientation": "h",
290292
"xanchor": "center",
@@ -470,7 +472,7 @@ def _make_deciles_dat_binary(
470472
if isinstance(reals, dict):
471473
reference_groups_keys = list(reals.keys())
472474
y_list = [
473-
np.asarray(reals[reference_group]).ravel()
475+
np.asarray(reals[str(reference_group)]).ravel()
474476
for reference_group in reference_groups_keys
475477
]
476478
lengths = np.array([len(y) for y in y_list], dtype=np.int64)
@@ -533,7 +535,7 @@ def _make_deciles_dat_binary(
533535
(
534536
(pl.col("prob").rank("ordinal").over(["reference_group", "model"]) - 1)
535537
* n_bins
536-
// pl.count().over(["reference_group", "model"])
538+
// pl.len().over(["reference_group", "model"])
537539
+ 1
538540
).alias("decile"),
539541
]
@@ -602,7 +604,7 @@ def _create_calibration_curve_list(
602604

603605
reference_data = _create_reference_data_for_calibration_curve()
604606

605-
reference_groups = deciles_data["reference_group"].unique().to_list()
607+
reference_groups = list(probs.keys())
606608

607609
colors_dictionary = _create_colors_dictionary_for_calibration(
608610
reference_groups, color_values, performance_type
@@ -689,7 +691,9 @@ def process_single_array(p, r, group_name):
689691
for group_name in reals.keys():
690692
if group_name in probs:
691693
frame = process_single_array(
692-
probs[group_name], reals[group_name], group_name
694+
probs[str(group_name)],
695+
reals[str(group_name)],
696+
str(group_name),
693697
)
694698
smooth_frames.append(frame)
695699

@@ -856,8 +860,21 @@ def _define_limits_for_calibration_plot(deciles_dat: pl.DataFrame) -> List[float
856860
if deciles_dat.height == 1:
857861
lower_bound, upper_bound = 0.0, 1.0
858862
else:
859-
lower_bound = float(max(0, min(deciles_dat["x"].min(), deciles_dat["y"].min())))
860-
upper_bound = float(max(deciles_dat["x"].max(), deciles_dat["y"].max()))
863+
lower_bound = float(
864+
max(
865+
0,
866+
min(
867+
cast(float, deciles_dat["x"].min()),
868+
cast(float, deciles_dat["y"].min()),
869+
),
870+
)
871+
)
872+
upper_bound = float(
873+
max(
874+
cast(float, deciles_dat["x"].max()),
875+
cast(float, deciles_dat["y"].max()),
876+
)
877+
)
861878

862879
return [
863880
lower_bound - (upper_bound - lower_bound) * 0.05,
@@ -1101,7 +1118,7 @@ def _create_calibration_curve_list_times(
11011118
)
11021119

11031120
reference_data = _create_reference_data_for_calibration_curve()
1104-
reference_groups = deciles_dat_final["reference_group"].unique().to_list()
1121+
reference_groups = list(probs.keys())
11051122
colors_dictionary = _create_colors_dictionary_for_calibration(
11061123
reference_groups, color_values, performance_type
11071124
)

src/rtichoke/processing/exported_functions.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -148,6 +148,7 @@ def create_plotly_curve(rtichoke_curve_dict):
148148
"y": 0,
149149
"steps": [],
150150
}
151+
sliders_dict["steps"] = []
151152

152153
for k in range(
153154
len(

src/rtichoke/processing/plotly_helper_functions.py

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import plotly.graph_objects as go
66
import polars as pl
77
import math
8-
from typing import Any, Dict, Union, Sequence
8+
from typing import Any, Dict, Union, Sequence, cast
99
import numpy as np
1010
from rtichoke.performance_data.performance_data import prepare_performance_data
1111
from rtichoke.performance_data.performance_data_times import (
@@ -329,8 +329,8 @@ def _create_reference_lines_data(
329329
# random-guess (y=1 unless all p==0 -> NaN)
330330
all_zero = (
331331
aj_df["p"].len() > 0
332-
and float(aj_df["p"].max()) == 0.0
333-
and float(aj_df["p"].min()) == 0.0
332+
and float(cast(float, aj_df["p"].max())) == 0.0
333+
and float(cast(float, aj_df["p"].min())) == 0.0
334334
)
335335
rand_y = pl.Series(
336336
np.full(len(x_s), np.nan) if all_zero else np.ones(len(x_s)),
@@ -992,7 +992,7 @@ def _check_if_multiple_populations_are_being_validated_times(
992992
]
993993
.max()
994994
)
995-
return max_val is not None and max_val > 1
995+
return max_val is not None and float(cast(float, max_val)) > 1
996996

997997

998998
def _check_if_multiple_populations_are_being_validated(
@@ -1977,10 +1977,21 @@ def _create_curve_layout(
19771977
"b": max(80, base_pad.get("b", 0)),
19781978
**base_pad,
19791979
}
1980+
xaxis: dict[str, Any] = {"showgrid": False}
1981+
yaxis: dict[str, Any] = {"showgrid": False}
1982+
1983+
if axes_ranges is not None:
1984+
xaxis["range"] = axes_ranges["xaxis"]
1985+
yaxis["range"] = axes_ranges["yaxis"]
1986+
1987+
if x_label:
1988+
xaxis["title"] = {"text": x_label}
1989+
if y_label:
1990+
yaxis["title"] = {"text": y_label}
19801991

19811992
curve_layout = {
1982-
"xaxis": {"showgrid": False},
1983-
"yaxis": {"showgrid": False},
1993+
"xaxis": xaxis,
1994+
"yaxis": yaxis,
19841995
"template": "plotly",
19851996
"plot_bgcolor": "rgba(0, 0, 0, 0)",
19861997
"paper_bgcolor": "rgba(0, 0, 0, 0)",
@@ -2014,15 +2025,6 @@ def _create_curve_layout(
20142025
"modebar": {"remove": list(DEFAULT_MODEBAR_BUTTONS_TO_REMOVE)},
20152026
}
20162027

2017-
if axes_ranges is not None:
2018-
curve_layout["xaxis"]["range"] = axes_ranges["xaxis"]
2019-
curve_layout["yaxis"]["range"] = axes_ranges["yaxis"]
2020-
2021-
if x_label:
2022-
curve_layout["xaxis"]["title"] = {"text": x_label}
2023-
if y_label:
2024-
curve_layout["yaxis"]["title"] = {"text": y_label}
2025-
20262028
return curve_layout
20272029

20282030

src/rtichoke/processing/transforms.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -67,13 +67,13 @@ def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame:
6767

6868
def pivot_longer_strata(data: pl.DataFrame) -> pl.DataFrame:
6969
# Identify id_vars and value_vars
70-
id_vars = [col for col in data.columns if not col.startswith("strata_")]
71-
value_vars = [col for col in data.columns if col.startswith("strata_")]
70+
index_cols = [col for col in data.columns if not col.startswith("strata_")]
71+
on_cols = [col for col in data.columns if col.startswith("strata_")]
7272

73-
# Perform the melt (equivalent to pandas.melt)
74-
data_long = data.melt(
75-
id_vars=id_vars,
76-
value_vars=value_vars,
73+
# Perform the unpivot (equivalent to pandas.melt)
74+
data_long = data.unpivot(
75+
index=index_cols,
76+
on=on_cols,
7777
variable_name="stratified_by",
7878
value_name="strata",
7979
)
@@ -257,12 +257,12 @@ def _create_list_data_to_adjust(
257257
probs_array = np.asarray(probs_dict[reference_group_labels[0]])
258258

259259
if isinstance(reals_dict, dict):
260-
reals_array = np.asarray(reals_dict[0])
260+
reals_array = np.asarray(reals_dict[reference_group_labels[0]])
261261
else:
262262
reals_array = np.asarray(reals_dict)
263263

264264
if isinstance(times_dict, dict):
265-
times_array = np.asarray(times_dict[0])
265+
times_array = np.asarray(times_dict[reference_group_labels[0]])
266266
else:
267267
times_array = np.asarray(times_dict)
268268

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)