Skip to content

Commit 7a582ff

Browse files
authored
Merge pull request #219 from uriahf/218-update-create_decision_curve-for-the-new-rtichoke-api
feat: migrate decision curves to new api
2 parents 632cf97 + 67033e8 commit 7a582ff

File tree

4 files changed

+218
-70
lines changed

4 files changed

+218
-70
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ dependencies = [
1111
"typing>=3.7.4.3",
1212
"polarstate==0.1.8",
1313
"marimo>=0.17.0",
14+
"pyarrow>=21.0.0",
1415
]
1516
name = "rtichoke"
16-
version = "0.1.21"
17+
version = "0.1.22"
1718
description = "interactive visualizations for performance of predictive models"
1819
readme = "README.md"
1920

src/rtichoke/helpers/plotly_helper_functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,16 @@ def _plot_rtichoke_curve_binary(
5656
stratified_by: str = "probability_threshold",
5757
curve: str = "roc",
5858
size: int = 600,
59+
min_p_threshold: float = 0,
60+
max_p_threshold: float = 1,
5961
) -> go.Figure:
6062
rtichoke_curve_list = _create_rtichoke_curve_list_binary(
6163
performance_data=performance_data,
6264
stratified_by=stratified_by,
6365
curve=curve,
6466
size=size,
67+
min_p_threshold=min_p_threshold,
68+
max_p_threshold=max_p_threshold,
6569
)
6670

6771
fig = _create_plotly_curve_binary(rtichoke_curve_list)

src/rtichoke/utility/decision.py

Lines changed: 92 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
"""
2-
A module for Summary Report
2+
A module for Decision Curves using Plotly helpers
33
"""
44

5-
from typing import Dict, List, Optional
6-
from pandas.core.frame import DataFrame
5+
from typing import Dict, List, Sequence, Union
76
from plotly.graph_objs._figure import Figure
8-
from rtichoke.helpers.send_post_request_to_r_rtichoke import create_rtichoke_curve
9-
from rtichoke.helpers.send_post_request_to_r_rtichoke import plot_rtichoke_curve
7+
from rtichoke.helpers.plotly_helper_functions import (
8+
_create_rtichoke_plotly_curve_binary,
9+
_plot_rtichoke_curve_binary,
10+
)
11+
import numpy as np
12+
import polars as pl
1013

1114

1215
def create_decision_curve(
13-
probs: Dict[str, List[float]],
14-
reals: Dict[str, List[int]],
16+
probs: Dict[str, np.ndarray],
17+
reals: Union[np.ndarray, Dict[str, np.ndarray]],
1518
decision_type: str = "conventional",
1619
min_p_threshold: float = 0,
1720
max_p_threshold: float = 1,
1821
by: float = 0.01,
19-
stratified_by: str = "probability_threshold",
20-
size: Optional[int] = None,
22+
stratified_by: Sequence[str] = ["probability_threshold"],
23+
size: int = 600,
2124
color_values: List[str] = [
2225
"#1b9e77",
2326
"#d95f02",
@@ -40,38 +43,63 @@ def create_decision_curve(
4043
"#D1603D",
4144
"#585123",
4245
],
43-
url_api: str = "http://localhost:4242/",
4446
) -> Figure:
45-
"""Create Decision Curve
47+
"""Create Decision Curve.
4648
47-
Args:
48-
probs (Dict[str, List[float]]): _description_
49-
reals (Dict[str, List[int]]): _description_
50-
decision_type (str, optional): _description_. Defaults to "conventional".
51-
min_p_threshold (float, optional): _description_. Defaults to 0.
52-
max_p_threshold (float, optional): _description_. Defaults to 1.
53-
by (float, optional): _description_. Defaults to 0.01.
54-
stratified_by (str, optional): _description_. Defaults to "probability_threshold".
55-
size (Optional[int], optional): _description_. Defaults to None.
56-
color_values (List[str], optional): _description_. Defaults to None.
57-
url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/".
49+
Parameters
50+
----------
51+
probs : Dict[str, np.ndarray]
52+
Dictionary mapping a label or group name to an array of predicted
53+
probabilities for the positive class.
54+
reals : Union[np.ndarray, Dict[str, np.ndarray]]
55+
Ground-truth binary labels (0/1) as a single array, or a dictionary
56+
mapping the same label/group keys used in ``probs`` to arrays of
57+
ground-truth labels.
58+
decision_type : str, optional
59+
Either ``"conventional"`` (decision curve) or another value that
60+
implies the "interventions avoided" variant. Default is
61+
``"conventional"``.
62+
min_p_threshold : float, optional
63+
Minimum probability threshold to include in the curve. Default is 0.
64+
max_p_threshold : float, optional
65+
Maximum probability threshold to include in the curve. Default is 1.
66+
by : float, optional
67+
Resolution for probability thresholds when computing the curve
68+
(step size). Default is 0.01.
69+
stratified_by : Sequence[str], optional
70+
Sequence of column names to stratify the performance data by.
71+
Default is ["probability_threshold"].
72+
size : int, optional
73+
Plot size in pixels (width and height). Default is 600.
74+
color_values : List[str], optional
75+
List of color hex strings to use for the plotted lines. If not
76+
provided, a default palette is used.
5877
59-
Returns:
60-
Figure: _description_
78+
Returns
79+
-------
80+
Figure
81+
A Plotly ``Figure`` containing the Decision curve.
82+
83+
Notes
84+
-----
85+
The function selects the appropriate curve name based on
86+
``decision_type`` and delegates computation and plotting to
87+
``_create_rtichoke_plotly_curve_binary``. Additional keyword arguments
88+
(like ``min_p_threshold`` and ``max_p_threshold``) are forwarded to
89+
the helper.
6190
"""
6291
if decision_type == "conventional":
6392
curve = "decision"
6493
else:
6594
curve = "interventions avoided"
6695

67-
fig = create_rtichoke_curve(
96+
fig = _create_rtichoke_plotly_curve_binary(
6897
probs,
6998
reals,
7099
by=by,
71100
stratified_by=stratified_by,
72101
size=size,
73102
color_values=color_values,
74-
url_api=url_api,
75103
curve=curve,
76104
min_p_threshold=min_p_threshold,
77105
max_p_threshold=max_p_threshold,
@@ -80,59 +108,55 @@ def create_decision_curve(
80108

81109

82110
def plot_decision_curve(
83-
performance_data: DataFrame,
84-
decision_type: str,
85-
min_p_threshold: int = 0,
86-
max_p_threshold: int = 1,
87-
size: Optional[int] = None,
88-
color_values: List[str] = [
89-
"#1b9e77",
90-
"#d95f02",
91-
"#7570b3",
92-
"#e7298a",
93-
"#07004D",
94-
"#E6AB02",
95-
"#FE5F55",
96-
"#54494B",
97-
"#006E90",
98-
"#BC96E6",
99-
"#52050A",
100-
"#1F271B",
101-
"#BE7C4D",
102-
"#63768D",
103-
"#08A045",
104-
"#320A28",
105-
"#82FF9E",
106-
"#2176FF",
107-
"#D1603D",
108-
"#585123",
109-
],
110-
url_api: str = "http://localhost:4242/",
111+
performance_data: pl.DataFrame,
112+
decision_type: str = "conventional",
113+
min_p_threshold: float = 0,
114+
max_p_threshold: float = 1,
115+
stratified_by: Sequence[str] = ["probability_threshold"],
116+
size: int = 600,
111117
) -> Figure:
112-
"""Plot Decision Curve
118+
"""Plot Decision Curve from performance data.
119+
120+
Parameters
121+
----------
122+
performance_data : pl.DataFrame
123+
A Polars DataFrame containing performance metrics for the Decision
124+
curve. Expected columns include (but may not be limited to)
125+
``probability_threshold`` and decision-curve metrics, plus any
126+
stratification columns.
127+
decision_type : str
128+
``"conventional"`` for decision curves, otherwise the
129+
"interventions avoided" variant will be used.
130+
min_p_threshold : float, optional
131+
Minimum probability threshold to include in the curve. Default is 0.
132+
max_p_threshold : float, optional
133+
Maximum probability threshold to include in the curve. Default is 1.
134+
stratified_by : Sequence[str], optional
135+
Sequence of column names used for stratification in the
136+
``performance_data``. Default is ["probability_threshold"].
137+
size : int, optional
138+
Plot size in pixels (width and height). Default is 600.
113139
114-
Args:
115-
performance_data (DataFrame): _description_
116-
decision_type (str): _description_
117-
min_p_threshold (int, optional): _description_. Defaults to 0.
118-
max_p_threshold (int, optional): _description_. Defaults to 1.
119-
size (Optional[int], optional): _description_. Defaults to None.
120-
color_values (List[str], optional): _description_. Defaults to None.
121-
url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/".
140+
Returns
141+
-------
142+
Figure
143+
A Plotly ``Figure`` containing the Decision plot.
122144
123-
Returns:
124-
Figure: _description_
145+
Notes
146+
-----
147+
This function wraps ``_plot_rtichoke_curve_binary`` to produce a
148+
ready-to-render Plotly figure from precomputed performance data.
149+
Additional keyword arguments (``min_p_threshold``, ``max_p_threshold``)
150+
are forwarded to the helper.
125151
"""
126152
if decision_type == "conventional":
127153
curve = "decision"
128154
else:
129155
curve = "interventions avoided"
130156

131-
fig = plot_rtichoke_curve(
157+
fig = _plot_rtichoke_curve_binary(
132158
performance_data,
133159
size=size,
134-
color_values=color_values,
135-
url_api=url_api,
136160
curve=curve,
137161
min_p_threshold=min_p_threshold,
138162
max_p_threshold=max_p_threshold,

0 commit comments

Comments
 (0)