Skip to content

Commit d8665e4

Browse files
REFACTO: in split setting, remove checking NaNs and irrelevant aggregation to avoid triggering unwanted warnings (#586)
* REFACTO: in split setting, remove checking NaNs and irrelevant aggregation to avoid triggering unwanted warnings
1 parent abfc309 commit d8665e4

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

mapie/estimator/regressor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,12 @@ def predict_calib(
402402
predictions[i], dtype=float
403403
)
404404
self.k_[ind, i] = 1
405-
check_nan_in_aposteriori_prediction(pred_matrix)
406405

407-
y_pred = aggregate_all(self.agg_function, pred_matrix)
406+
if self.use_split_method_:
407+
y_pred = pred_matrix.flatten()
408+
else:
409+
check_nan_in_aposteriori_prediction(pred_matrix)
410+
y_pred = aggregate_all(self.agg_function, pred_matrix)
408411

409412
return y_pred
410413

mapie/tests/test_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ def test_not_enough_resamplings() -> None:
701701
"""
702702
with pytest.warns(UserWarning, match=r"WARNING: at least one point of*"):
703703
mapie_reg = MapieRegressor(
704-
cv=Subsample(n_resamplings=1), agg_function="mean"
704+
cv=Subsample(n_resamplings=2, random_state=0), agg_function="mean"
705705
)
706706
mapie_reg.fit(X, y)
707707

mapie/tests/test_time_series_regression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,8 @@ def test_not_enough_resamplings() -> None:
318318
match=r"WARNING: at least one point of*"
319319
):
320320
mapie_ts_reg = MapieTimeSeriesRegressor(
321-
cv=BlockBootstrap(n_resamplings=1, n_blocks=1), agg_function="mean"
321+
cv=BlockBootstrap(n_resamplings=2, n_blocks=1, random_state=0),
322+
agg_function="mean"
322323
)
323324
mapie_ts_reg.fit(X, y)
324325

0 commit comments

Comments
 (0)