Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies = [
"polars>=1.31.0",
]
name = "rtichoke"
version = "0.1.27"
version = "0.1.28"
description = "interactive visualizations for performance of predictive models"
readme = "README.md"

Expand Down
33 changes: 25 additions & 8 deletions src/rtichoke/calibration/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
A module for Calibration Curves
"""

from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, cast

# import pandas as pd
import plotly.graph_objects as go
Expand Down Expand Up @@ -247,6 +247,7 @@ def _create_plotly_curve_from_calibration_curve_list_times(
},
barmode="overlay",
plot_bgcolor="rgba(0, 0, 0, 0)",
paper_bgcolor="rgba(0, 0, 0, 0)",
legend={
"orientation": "h",
"xanchor": "center",
Expand Down Expand Up @@ -285,6 +286,7 @@ def _create_plotly_curve_from_calibration_curve_list(
"yaxis": {"showgrid": False},
"barmode": "overlay",
"plot_bgcolor": "rgba(0, 0, 0, 0)",
"paper_bgcolor": "rgba(0, 0, 0, 0)",
"legend": {
"orientation": "h",
"xanchor": "center",
Expand Down Expand Up @@ -470,7 +472,7 @@ def _make_deciles_dat_binary(
if isinstance(reals, dict):
reference_groups_keys = list(reals.keys())
y_list = [
np.asarray(reals[reference_group]).ravel()
np.asarray(reals[str(reference_group)]).ravel()
for reference_group in reference_groups_keys
]
lengths = np.array([len(y) for y in y_list], dtype=np.int64)
Expand Down Expand Up @@ -533,7 +535,7 @@ def _make_deciles_dat_binary(
(
(pl.col("prob").rank("ordinal").over(["reference_group", "model"]) - 1)
* n_bins
// pl.count().over(["reference_group", "model"])
// pl.len().over(["reference_group", "model"])
+ 1
).alias("decile"),
]
Expand Down Expand Up @@ -602,7 +604,7 @@ def _create_calibration_curve_list(

reference_data = _create_reference_data_for_calibration_curve()

reference_groups = deciles_data["reference_group"].unique().to_list()
reference_groups = list(probs.keys())

colors_dictionary = _create_colors_dictionary_for_calibration(
reference_groups, color_values, performance_type
Expand Down Expand Up @@ -689,7 +691,9 @@ def process_single_array(p, r, group_name):
for group_name in reals.keys():
if group_name in probs:
frame = process_single_array(
probs[group_name], reals[group_name], group_name
probs[str(group_name)],
reals[str(group_name)],
str(group_name),
)
smooth_frames.append(frame)

Expand Down Expand Up @@ -856,8 +860,21 @@ def _define_limits_for_calibration_plot(deciles_dat: pl.DataFrame) -> List[float
if deciles_dat.height == 1:
lower_bound, upper_bound = 0.0, 1.0
else:
lower_bound = float(max(0, min(deciles_dat["x"].min(), deciles_dat["y"].min())))
upper_bound = float(max(deciles_dat["x"].max(), deciles_dat["y"].max()))
lower_bound = float(
max(
0,
min(
cast(float, deciles_dat["x"].min()),
cast(float, deciles_dat["y"].min()),
),
)
)
upper_bound = float(
max(
cast(float, deciles_dat["x"].max()),
cast(float, deciles_dat["y"].max()),
)
)

return [
lower_bound - (upper_bound - lower_bound) * 0.05,
Expand Down Expand Up @@ -1101,7 +1118,7 @@ def _create_calibration_curve_list_times(
)

reference_data = _create_reference_data_for_calibration_curve()
reference_groups = deciles_dat_final["reference_group"].unique().to_list()
reference_groups = list(probs.keys())
colors_dictionary = _create_colors_dictionary_for_calibration(
reference_groups, color_values, performance_type
)
Expand Down
1 change: 1 addition & 0 deletions src/rtichoke/processing/exported_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def create_plotly_curve(rtichoke_curve_dict):
"y": 0,
"steps": [],
}
sliders_dict["steps"] = []

for k in range(
len(
Expand Down
32 changes: 17 additions & 15 deletions src/rtichoke/processing/plotly_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import plotly.graph_objects as go
import polars as pl
import math
from typing import Any, Dict, Union, Sequence
from typing import Any, Dict, Union, Sequence, cast
import numpy as np
from rtichoke.performance_data.performance_data import prepare_performance_data
from rtichoke.performance_data.performance_data_times import (
Expand Down Expand Up @@ -329,8 +329,8 @@ def _create_reference_lines_data(
# random-guess (y=1 unless all p==0 -> NaN)
all_zero = (
aj_df["p"].len() > 0
and float(aj_df["p"].max()) == 0.0
and float(aj_df["p"].min()) == 0.0
and float(cast(float, aj_df["p"].max())) == 0.0
and float(cast(float, aj_df["p"].min())) == 0.0
)
rand_y = pl.Series(
np.full(len(x_s), np.nan) if all_zero else np.ones(len(x_s)),
Expand Down Expand Up @@ -992,7 +992,7 @@ def _check_if_multiple_populations_are_being_validated_times(
]
.max()
)
return max_val is not None and max_val > 1
return max_val is not None and float(cast(float, max_val)) > 1


def _check_if_multiple_populations_are_being_validated(
Expand Down Expand Up @@ -1977,10 +1977,21 @@ def _create_curve_layout(
"b": max(80, base_pad.get("b", 0)),
**base_pad,
}
xaxis: dict[str, Any] = {"showgrid": False}
yaxis: dict[str, Any] = {"showgrid": False}

if axes_ranges is not None:
xaxis["range"] = axes_ranges["xaxis"]
yaxis["range"] = axes_ranges["yaxis"]

if x_label:
xaxis["title"] = {"text": x_label}
if y_label:
yaxis["title"] = {"text": y_label}

curve_layout = {
"xaxis": {"showgrid": False},
"yaxis": {"showgrid": False},
"xaxis": xaxis,
"yaxis": yaxis,
"template": "plotly",
"plot_bgcolor": "rgba(0, 0, 0, 0)",
"paper_bgcolor": "rgba(0, 0, 0, 0)",
Expand Down Expand Up @@ -2014,15 +2025,6 @@ def _create_curve_layout(
"modebar": {"remove": list(DEFAULT_MODEBAR_BUTTONS_TO_REMOVE)},
}

if axes_ranges is not None:
curve_layout["xaxis"]["range"] = axes_ranges["xaxis"]
curve_layout["yaxis"]["range"] = axes_ranges["yaxis"]

if x_label:
curve_layout["xaxis"]["title"] = {"text": x_label}
if y_label:
curve_layout["yaxis"]["title"] = {"text": y_label}

return curve_layout


Expand Down
16 changes: 8 additions & 8 deletions src/rtichoke/processing/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame:

def pivot_longer_strata(data: pl.DataFrame) -> pl.DataFrame:
# Identify id_vars and value_vars
id_vars = [col for col in data.columns if not col.startswith("strata_")]
value_vars = [col for col in data.columns if col.startswith("strata_")]
index_cols = [col for col in data.columns if not col.startswith("strata_")]
on_cols = [col for col in data.columns if col.startswith("strata_")]

# Perform the melt (equivalent to pandas.melt)
data_long = data.melt(
id_vars=id_vars,
value_vars=value_vars,
# Perform the unpivot (equivalent to pandas.melt)
data_long = data.unpivot(
index=index_cols,
on=on_cols,
variable_name="stratified_by",
value_name="strata",
)
Expand Down Expand Up @@ -257,12 +257,12 @@ def _create_list_data_to_adjust(
probs_array = np.asarray(probs_dict[reference_group_labels[0]])

if isinstance(reals_dict, dict):
reals_array = np.asarray(reals_dict[0])
reals_array = np.asarray(reals_dict[reference_group_labels[0]])
else:
reals_array = np.asarray(reals_dict)

if isinstance(times_dict, dict):
times_array = np.asarray(times_dict[0])
times_array = np.asarray(times_dict[reference_group_labels[0]])
else:
times_array = np.asarray(times_dict)

Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.