Skip to content

Commit 04f4627

Browse files
authored
Merge pull request #267 from uriahf/jules/ty-fixes-322255806096583053
Pass ty check and update AGENTS.md
2 parents 99c0487 + c5ad925 commit 04f4627

File tree

4 files changed

+24
-6
lines changed

4 files changed

+24
-6
lines changed

AGENTS.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,11 @@ Before committing, please ensure that the pre-commit hooks pass. You can run the
6767
## Documentation
6868

6969
The documentation for this project is built using `quartodoc`. The documentation is automatically built and deployed via GitHub Actions. There is no need to build the documentation manually.
70+
71+
## Type Checking
72+
73+
This project uses `ty` for type checking. To check for type errors, run the following command:
74+
75+
```bash
76+
uv run ty check src tests
77+
```

src/rtichoke/calibration/calibration.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,13 +206,14 @@ def _create_plotly_curve_from_calibration_curve_list_times(
206206
num_traces_per_horizon = 1 + 2 * len(calibration_curve_list["reference_group_keys"])
207207

208208
for i, horizon in enumerate(calibration_curve_list["fixed_time_horizons"]):
209+
visibility = [False] * (num_traces_per_horizon * len(calibration_curve_list["fixed_time_horizons"]))
210+
for j in range(num_traces_per_horizon):
211+
visibility[i * num_traces_per_horizon + j] = True
209212
step = dict(
210213
method="restyle",
211-
args=[{"visible": [False] * (num_traces_per_horizon * len(calibration_curve_list["fixed_time_horizons"]))}],
214+
args=[{"visible": visibility}],
212215
label=str(horizon),
213216
)
214-
for j in range(num_traces_per_horizon):
215-
step["args"][0]["visible"][i * num_traces_per_horizon + j] = True
216217
steps.append(step)
217218

218219
sliders = [dict(
@@ -726,7 +727,7 @@ def _add_hover_text_to_calibration_data(
726727
deciles_dat: pl.DataFrame,
727728
smooth_dat: pl.DataFrame,
728729
performance_type: str,
729-
) -> (pl.DataFrame, pl.DataFrame):
730+
) -> tuple[pl.DataFrame, pl.DataFrame]:
730731
"""Adds hover text to the deciles and smooth dataframes."""
731732
if performance_type != "one model":
732733
deciles_dat = deciles_dat.with_columns(

src/rtichoke/helpers/plotly_helper_functions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,15 @@ def _htext(title: pl.Expr) -> pl.Expr:
681681
(pl.col("x") >= min_p_threshold) & (pl.col("x") <= max_p_threshold)
682682
)
683683

684+
return pl.DataFrame(
685+
schema={
686+
"reference_group": pl.Utf8,
687+
"x": pl.Float64,
688+
"y": pl.Float64,
689+
"text": pl.Utf8,
690+
}
691+
)
692+
684693

685694
def create_non_interactive_curve_polars(
686695
performance_data_ready_for_curve, reference_group_color, reference_group
@@ -1157,7 +1166,7 @@ def _add_hover_text_to_performance_data(
11571166
)
11581167

11591168
return performance_data.with_columns(
1160-
[pl.col(pl.FLOAT_DTYPES).round(3), hover_text_expr.alias("text")]
1169+
[pl.col(pl.Float64).round(3), hover_text_expr.alias("text")]
11611170
)
11621171

11631172

src/rtichoke/helpers/sandbox_observable_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame:
152152

153153
labels = [f"{x:.{decimals}f}" for x in np.linspace(by, 1.0, q)]
154154

155-
strata_labels = np.array([labels[i] for i in bin_idx], dtype=object)
155+
strata_labels = np.array(labels)[bin_idx]
156156

157157
columns_to_add.append(
158158
pl.Series("strata_ppcr", strata_labels).cast(pl.Enum(labels))

0 commit comments

Comments
 (0)