
Commit 5e2b04f

Merge pull request #140 from scikit-learn-contrib/refacto_examples
Refacto examples
2 parents 5084217 + 507d6b3 commit 5e2b04f


6 files changed: +37 −18 lines changed


HISTORY.rst
Lines changed: 1 addition & 2 deletions

@@ -10,8 +10,7 @@ History
 * Add MNIST example for classification
 * Add cross-conformal for classification
 * Add `notebooks` folder containing notebooks used for generating documentation tutorials
-* Uniformize the use of matrix k_ and add an argument "ensemble" to method
-  "predict" in regression.py
+* Uniformize the use of matrix k_ and add an argument "ensemble" to method "predict" in regression.py
 * Add replication of the Chen Xu's tutorial testing Jackknife+aB vs Jackknife+
 * Add Jackknife+-after-Bootstrap documentation
 * Improve scikit-learn pipelines compatibility

examples/classification/1-quickstart/plot_comp_methods_on_2d_dataset.py
Lines changed: 5 additions & 2 deletions

@@ -53,7 +53,10 @@
 import matplotlib.pyplot as plt
 
 from mapie.classification import MapieClassifier
-from mapie.metrics import classification_coverage_score
+from mapie.metrics import (
+    classification_coverage_score,
+    classification_mean_width_score
+)
 from mapie._typing import ArrayLike
 
 
@@ -261,7 +264,7 @@ def plot_results(
         for i, _ in enumerate(alpha_)
     ]
     mean_width[method] = [
-        y_ps_mapie[method][:, :, i].sum(axis=1).mean()
+        classification_mean_width_score(y_ps_mapie[method][:, :, i])
        for i, _ in enumerate(alpha_)
     ]
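The change here swaps the hand-rolled mean set size, `y_ps_mapie[method][:, :, i].sum(axis=1).mean()`, for `classification_mean_width_score`. A minimal sanity-check sketch of the presumed equivalence (the toy `y_ps` array below is hypothetical, not taken from the example):

import numpy as np

from mapie.metrics import classification_mean_width_score

# Hypothetical toy prediction sets: 4 samples, 3 classes
# (True means the class is kept in the prediction set).
y_ps = np.array([
    [True, False, False],
    [True, True, False],
    [False, True, True],
    [True, True, True],
])

# What the replaced expression computed: the average number of classes
# retained per prediction set.
manual_width = y_ps.sum(axis=1).mean()  # (1 + 2 + 2 + 3) / 4 = 2.0

# The refactored examples delegate this to the metric; the two values
# are expected to agree if the metric is the mean set size.
metric_width = classification_mean_width_score(y_ps)

print(manual_width, metric_width)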

examples/classification/2-advanced-analysis/plot_digits_classification.py
Lines changed: 16 additions & 7 deletions

@@ -18,7 +18,10 @@
 import matplotlib.pyplot as plt
 
 from mapie.classification import MapieClassifier
-from mapie.metrics import classification_coverage_score
+from mapie.metrics import (
+    classification_coverage_score,
+    classification_mean_width_score
+)
 from mapie._typing import ArrayLike
 
 
@@ -206,15 +209,21 @@ def get_datasets(dataset: Any) -> Tuple[
 # the "real" coverage obtained on the test set.
 
 coverages1 = [
-    classification_coverage_score(y_test1, y_ps1[:, :, ia])
-    for ia, _ in enumerate(alpha)
+    classification_coverage_score(y_test1, y_ps1[:, :, i])
+    for i, _ in enumerate(alpha)
 ]
 coverages2 = [
-    classification_coverage_score(y_test2, y_ps2[:, :, ia])
-    for ia, _ in enumerate(alpha)
+    classification_coverage_score(y_test2, y_ps2[:, :, i])
+    for i, _ in enumerate(alpha)
+]
+widths1 = [
+    classification_mean_width_score(y_ps1[:, :, i])
+    for i, _ in enumerate(alpha)
+]
+widths2 = [
+    classification_mean_width_score(y_ps2[:, :, i])
+    for i, _ in enumerate(alpha)
 ]
-widths1 = [y_ps1[:, :, ia].sum(axis=1).mean() for ia, _ in enumerate(alpha)]
-widths2 = [y_ps2[:, :, ia].sum(axis=1).mean() for ia, _ in enumerate(alpha)]
 
 _, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
 axes[0].set_xlabel("1 - alpha")

examples/regression/1-quickstart/plot_timeseries_example.py
Lines changed: 7 additions & 3 deletions

@@ -30,13 +30,17 @@
 Jackknife+ and CV+ methods.
 """
 import pandas as pd
-from mapie.metrics import regression_coverage_score
-from mapie.regression import MapieRegressor
 from matplotlib import pylab as plt
 from scipy.stats import randint
 from sklearn.ensemble import RandomForestRegressor
 from sklearn.model_selection import RandomizedSearchCV, TimeSeriesSplit
 
+from mapie.metrics import (
+    regression_coverage_score,
+    regression_mean_width_score
+)
+from mapie.regression import MapieRegressor
+
 # Load input data and feature engineering
 demand_df = pd.read_csv(
     "../../data/demand_temperature.csv", parse_dates=True, index_col=0
 
@@ -86,7 +90,7 @@
 mapie.fit(X_train, y_train)
 y_pred, y_pis = mapie.predict(X_test, alpha=alpha)
 coverage = regression_coverage_score(y_test, y_pis[:, 0, 0], y_pis[:, 1, 0])
-width = (y_pis[:, 1, 0] - y_pis[:, 0, 0]).mean()
+width = regression_mean_width_score(y_pis[:, 0, 0], y_pis[:, 1, 0])
 
 # Print results
 print(
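On the regression side, the interval width previously computed by hand, `(y_pis[:, 1, 0] - y_pis[:, 0, 0]).mean()`, is delegated to `regression_mean_width_score`. A minimal sketch of the presumed equivalence, using hypothetical bounds `y_low` and `y_up`:

import numpy as np

from mapie.metrics import regression_mean_width_score

# Hypothetical prediction intervals for three test points.
y_low = np.array([1.0, 2.0, 3.0])
y_up = np.array([2.0, 4.0, 3.5])

# What the replaced expression computed: the mean interval width.
manual_width = (y_up - y_low).mean()  # (1.0 + 2.0 + 0.5) / 3

# The refactored examples call the metric with the lower and upper bounds,
# as in regression_mean_width_score(y_pis[:, 0, 0], y_pis[:, 1, 0]);
# the two values are expected to agree.
metric_width = regression_mean_width_score(y_low, y_up)

print(manual_width, metric_width)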

examples/regression/3-scientific-articles/plot_barber2020_simulations.py
Lines changed: 8 additions & 3 deletions

@@ -34,7 +34,10 @@
 from sklearn.linear_model import LinearRegression
 from matplotlib import pyplot as plt
 
-from mapie.metrics import regression_coverage_score
+from mapie.metrics import (
+    regression_coverage_score,
+    regression_mean_width_score
+)
 from mapie.regression import MapieRegressor
 from mapie._typing import ArrayLike
 
@@ -116,12 +119,14 @@ def PIs_vs_dimensions(
                     **params
                 )
                 mapie.fit(X_train, y_train)
-                y_pred, y_pis = mapie.predict(X_test, alpha=alpha)
+                _, y_pis = mapie.predict(X_test, alpha=alpha)
                 coverage = regression_coverage_score(
                     y_test, y_pis[:, 0, 0], y_pis[:, 1, 0]
                 )
                 results[strategy][dimension]["coverage"][trial] = coverage
-                width_mean = (y_pis[:, 1, 0] - y_pis[:, 0, 0]).mean()
+                width_mean = regression_mean_width_score(
+                    y_pis[:, 0, 0], y_pis[:, 1, 0]
+                )
                 results[strategy][dimension]["width_mean"][trial] = width_mean
     return results

mapie/regression.py
Lines changed: 0 additions & 1 deletion

@@ -675,5 +675,4 @@ def predict(
            ).data
            if ensemble:
                y_pred = aggregate_all(self.agg_function, y_pred_multi)
-            np.stack([y_pred_low, y_pred_up], axis=1)
            return y_pred, np.stack([y_pred_low, y_pred_up], axis=1)
