diff --git a/pyproject.toml b/pyproject.toml index 6a17fef..5645f67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "polars>=1.31.0", ] name = "rtichoke" -version = "0.1.27" +version = "0.1.28" description = "interactive visualizations for performance of predictive models" readme = "README.md" diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index 0e03df0..567c05c 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -2,7 +2,7 @@ A module for Calibration Curves """ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, cast # import pandas as pd import plotly.graph_objects as go @@ -247,6 +247,7 @@ def _create_plotly_curve_from_calibration_curve_list_times( }, barmode="overlay", plot_bgcolor="rgba(0, 0, 0, 0)", + paper_bgcolor="rgba(0, 0, 0, 0)", legend={ "orientation": "h", "xanchor": "center", @@ -285,6 +286,7 @@ def _create_plotly_curve_from_calibration_curve_list( "yaxis": {"showgrid": False}, "barmode": "overlay", "plot_bgcolor": "rgba(0, 0, 0, 0)", + "paper_bgcolor": "rgba(0, 0, 0, 0)", "legend": { "orientation": "h", "xanchor": "center", @@ -470,7 +472,7 @@ def _make_deciles_dat_binary( if isinstance(reals, dict): reference_groups_keys = list(reals.keys()) y_list = [ - np.asarray(reals[reference_group]).ravel() + np.asarray(reals[str(reference_group)]).ravel() for reference_group in reference_groups_keys ] lengths = np.array([len(y) for y in y_list], dtype=np.int64) @@ -533,7 +535,7 @@ def _make_deciles_dat_binary( ( (pl.col("prob").rank("ordinal").over(["reference_group", "model"]) - 1) * n_bins - // pl.count().over(["reference_group", "model"]) + // pl.len().over(["reference_group", "model"]) + 1 ).alias("decile"), ] @@ -602,7 +604,7 @@ def _create_calibration_curve_list( reference_data = _create_reference_data_for_calibration_curve() - reference_groups = deciles_data["reference_group"].unique().to_list() + reference_groups = list(probs.keys()) colors_dictionary = _create_colors_dictionary_for_calibration( reference_groups, color_values, performance_type @@ -689,7 +691,9 @@ def process_single_array(p, r, group_name): for group_name in reals.keys(): if group_name in probs: frame = process_single_array( - probs[group_name], reals[group_name], group_name + probs[str(group_name)], + reals[str(group_name)], + str(group_name), ) smooth_frames.append(frame) @@ -856,8 +860,21 @@ def _define_limits_for_calibration_plot(deciles_dat: pl.DataFrame) -> List[float if deciles_dat.height == 1: lower_bound, upper_bound = 0.0, 1.0 else: - lower_bound = float(max(0, min(deciles_dat["x"].min(), deciles_dat["y"].min()))) - upper_bound = float(max(deciles_dat["x"].max(), deciles_dat["y"].max())) + lower_bound = float( + max( + 0, + min( + cast(float, deciles_dat["x"].min()), + cast(float, deciles_dat["y"].min()), + ), + ) + ) + upper_bound = float( + max( + cast(float, deciles_dat["x"].max()), + cast(float, deciles_dat["y"].max()), + ) + ) return [ lower_bound - (upper_bound - lower_bound) * 0.05, @@ -1101,7 +1118,7 @@ def _create_calibration_curve_list_times( ) reference_data = _create_reference_data_for_calibration_curve() - reference_groups = deciles_dat_final["reference_group"].unique().to_list() + reference_groups = list(probs.keys()) colors_dictionary = _create_colors_dictionary_for_calibration( reference_groups, color_values, performance_type ) diff --git a/src/rtichoke/processing/exported_functions.py b/src/rtichoke/processing/exported_functions.py index 778ad91..f9aaddb 100644 --- a/src/rtichoke/processing/exported_functions.py +++ b/src/rtichoke/processing/exported_functions.py @@ -148,6 +148,7 @@ def create_plotly_curve(rtichoke_curve_dict): "y": 0, "steps": [], } + sliders_dict["steps"] = [] for k in range( len( diff --git a/src/rtichoke/processing/plotly_helper_functions.py b/src/rtichoke/processing/plotly_helper_functions.py index 074fc52..07b6bc2 100644 --- a/src/rtichoke/processing/plotly_helper_functions.py +++ b/src/rtichoke/processing/plotly_helper_functions.py @@ -5,7 +5,7 @@ import plotly.graph_objects as go import polars as pl import math -from typing import Any, Dict, Union, Sequence +from typing import Any, Dict, Union, Sequence, cast import numpy as np from rtichoke.performance_data.performance_data import prepare_performance_data from rtichoke.performance_data.performance_data_times import ( @@ -329,8 +329,8 @@ def _create_reference_lines_data( # random-guess (y=1 unless all p==0 -> NaN) all_zero = ( aj_df["p"].len() > 0 - and float(aj_df["p"].max()) == 0.0 - and float(aj_df["p"].min()) == 0.0 + and float(cast(float, aj_df["p"].max())) == 0.0 + and float(cast(float, aj_df["p"].min())) == 0.0 ) rand_y = pl.Series( np.full(len(x_s), np.nan) if all_zero else np.ones(len(x_s)), @@ -992,7 +992,7 @@ def _check_if_multiple_populations_are_being_validated_times( ] .max() ) - return max_val is not None and max_val > 1 + return max_val is not None and float(cast(float, max_val)) > 1 def _check_if_multiple_populations_are_being_validated( @@ -1977,10 +1977,21 @@ def _create_curve_layout( "b": max(80, base_pad.get("b", 0)), **base_pad, } + xaxis: dict[str, Any] = {"showgrid": False} + yaxis: dict[str, Any] = {"showgrid": False} + + if axes_ranges is not None: + xaxis["range"] = axes_ranges["xaxis"] + yaxis["range"] = axes_ranges["yaxis"] + + if x_label: + xaxis["title"] = {"text": x_label} + if y_label: + yaxis["title"] = {"text": y_label} curve_layout = { - "xaxis": {"showgrid": False}, - "yaxis": {"showgrid": False}, + "xaxis": xaxis, + "yaxis": yaxis, "template": "plotly", "plot_bgcolor": "rgba(0, 0, 0, 0)", "paper_bgcolor": "rgba(0, 0, 0, 0)", @@ -2014,15 +2025,6 @@ def _create_curve_layout( "modebar": {"remove": list(DEFAULT_MODEBAR_BUTTONS_TO_REMOVE)}, } - if axes_ranges is not None: - curve_layout["xaxis"]["range"] = axes_ranges["xaxis"] - curve_layout["yaxis"]["range"] = axes_ranges["yaxis"] - - if x_label: - curve_layout["xaxis"]["title"] = {"text": x_label} - if y_label: - curve_layout["yaxis"]["title"] = {"text": y_label} - return curve_layout diff --git a/src/rtichoke/processing/transforms.py b/src/rtichoke/processing/transforms.py index 4e4339a..ba162ac 100644 --- a/src/rtichoke/processing/transforms.py +++ b/src/rtichoke/processing/transforms.py @@ -67,13 +67,13 @@ def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame: def pivot_longer_strata(data: pl.DataFrame) -> pl.DataFrame: # Identify id_vars and value_vars - id_vars = [col for col in data.columns if not col.startswith("strata_")] - value_vars = [col for col in data.columns if col.startswith("strata_")] + index_cols = [col for col in data.columns if not col.startswith("strata_")] + on_cols = [col for col in data.columns if col.startswith("strata_")] - # Perform the melt (equivalent to pandas.melt) - data_long = data.melt( - id_vars=id_vars, - value_vars=value_vars, + # Perform the unpivot (equivalent to pandas.melt) + data_long = data.unpivot( + index=index_cols, + on=on_cols, variable_name="stratified_by", value_name="strata", ) @@ -257,12 +257,12 @@ def _create_list_data_to_adjust( probs_array = np.asarray(probs_dict[reference_group_labels[0]]) if isinstance(reals_dict, dict): - reals_array = np.asarray(reals_dict[0]) + reals_array = np.asarray(reals_dict[reference_group_labels[0]]) else: reals_array = np.asarray(reals_dict) if isinstance(times_dict, dict): - times_array = np.asarray(times_dict[0]) + times_array = np.asarray(times_dict[reference_group_labels[0]]) else: times_array = np.asarray(times_dict) diff --git a/uv.lock b/uv.lock index a345597..971a9da 100644 --- a/uv.lock +++ b/uv.lock @@ -5143,7 +5143,7 @@ wheels = [ [[package]] name = "rtichoke" -version = "0.1.27" +version = "0.1.28" source = { editable = "." } dependencies = [ { name = "marimo", version = "0.17.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },