
Commit f622549

Merge pull request #137 from scikit-learn-contrib/include-crossval-for-cumulatedscore
Include crossval for cumulatedscore
2 parents 5e2b04f + 8c60eea commit f622549
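
In short, this change extends the cross-conformal classification example so that the "cumulated_score" method is also evaluated with cross-validated scores aggregated via agg_scores="crossval", alongside the existing "mean" aggregation. A minimal sketch of the call pattern exercised by the updated example follows; the dataset, estimator, alpha grid, and the mapie.classification import path are illustrative assumptions, not part of this commit:

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.model_selection import KFold, train_test_split
from sklearn.naive_bayes import GaussianNB
from mapie.classification import MapieClassifier  # assumed import path

# Placeholder data and miscoverage levels, for illustration only.
X, y = make_blobs(n_samples=500, centers=3, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
alpha = np.arange(0.02, 0.98, 0.02)

# Cross-conformal "cumulated_score" method, fitted over 5 folds.
mapie_clf = MapieClassifier(
    estimator=GaussianNB(),
    method="cumulated_score",
    cv=KFold(n_splits=5, shuffle=True, random_state=42),
    random_state=42
)
mapie_clf.fit(X_train, y_train)

# The updated example compares two ways of aggregating the cross-validated
# conformity scores at prediction time: agg_scores="mean" and "crossval".
y_pred, y_ps = mapie_clf.predict(
    X_test, alpha=alpha, include_last_label="randomized", agg_scores="crossval"
)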

File tree

5 files changed: +549 -322 lines


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -79,3 +79,6 @@ target/
 # Images
 *.png
 *.jpeg
+
+# ZIP files
+*.zip

doc/images/quickstart_1.png

0 Bytes

examples/classification/2-advanced-analysis/plot_crossconformal.py

Lines changed: 129 additions & 89 deletions
@@ -26,8 +26,10 @@
 """


-from typing import Dict, Any
+from typing import Dict, Any, Optional, Union
+from typing_extensions import TypedDict
 import numpy as np
+import pandas as pd
 import matplotlib.pyplot as plt
 from sklearn.naive_bayes import GaussianNB
 from sklearn.model_selection import KFold
@@ -138,8 +140,7 @@
     axs[i].set_title(f"split={key}\nquantile={mapie.quantiles_[9]:.3f}")
 plt.suptitle(
     "Distribution of scores on each calibration fold for the "
-    f"{methods[0]} method",
-    y=1.04
+    f"{methods[0]} method"
 )
 plt.show()

@@ -186,8 +187,9 @@ def plot_results(
         axs[i].set_title(f"coverage = {coverage:.3f}")
     plt.suptitle(
         "Number of labels in prediction sets "
-        f"for the {method} method", y=1.04
+        f"for the {method} method"
     )
+    plt.show()


 ##############################################################################
@@ -224,27 +226,33 @@ def plot_coverage_width(
     alpha: float,
     coverages: ArrayLike,
     widths: ArrayLike,
-    method: str
+    method: str,
+    comp: str = "split"
 ) -> None:
+    if comp == "split":
+        legends = [f"Split {i + 1}" for i, _ in enumerate(coverages)]
+    else:
+        legends = ["Mean", "Crossval"]
     _, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
     axes[0].set_xlabel("1 - alpha")
     axes[0].set_ylabel("Effective coverage")
     for i, coverage in enumerate(coverages):
-        axes[0].plot(1 - alpha, coverage, label=f"Split {i + 1}")
+        axes[0].plot(1 - alpha, coverage, label=legends[i])
     axes[0].plot([0, 1], [0, 1], ls="--", color="k")
     axes[0].legend()
     axes[1].set_xlabel("1 - alpha")
     axes[1].set_ylabel("Average of prediction set sizes")
     for i, width in enumerate(widths):
-        axes[1].plot(1 - alpha, width, label=f"Split {i+1}")
+        axes[1].plot(1 - alpha, width, label=legends[i])
     axes[1].legend()
     plt.suptitle(
         "Effective coverage and prediction set size "
-        f"for the {method} method", y=1.04
+        f"for the {method} method"
     )
+    plt.show()


-coverages = np.array(
+split_coverages = np.array(
     [
         [
             [
@@ -256,7 +264,7 @@ def plot_coverage_width(
     ]
 )

-widths = np.array(
+split_widths = np.array(
     [
         [
             [
@@ -268,9 +276,13 @@ def plot_coverage_width(
     ]
 )

-plot_coverage_width(alpha, coverages[0], widths[0], "score")
+plot_coverage_width(
+    alpha, split_coverages[0], split_widths[0], "score"
+)

-plot_coverage_width(alpha, coverages[1], widths[1], "cumulated_score")
+plot_coverage_width(
+    alpha, split_coverages[1], split_widths[1], "cumulated_score"
+)


 ##############################################################################
@@ -306,82 +318,101 @@ def plot_coverage_width(
 # When estimating the prediction sets, we define how the scores are aggregated
 # with the ``agg_scores`` attribute.

+Params = TypedDict(
+    "Params",
+    {
+        "method": str,
+        "cv": Optional[Union[int, str]],
+        "random_state": Optional[int]
+    }
+)
+ParamsPredict = TypedDict(
+    "ParamsPredict",
+    {
+        "include_last_label": Union[bool, str],
+        "agg_scores": str
+    }
+)

 kf = KFold(n_splits=5, shuffle=True)
-mapie_clf = MapieClassifier(estimator=clf, cv=kf, method="score")
-mapie_clf.fit(X_train, y_train)

-_, y_ps_score_mean = mapie_clf.predict(
-    X_test_distrib,
-    alpha=alpha,
-    agg_scores="mean"
-)
-_, y_ps_score_crossval = mapie_clf.predict(
-    X_test_distrib,
-    alpha=alpha,
-    agg_scores="crossval"
-)
+STRATEGIES = {
+    "score_cv_mean": (
+        Params(method="score", cv=kf, random_state=42),
+        ParamsPredict(include_last_label=False, agg_scores="mean")
+    ),
+    "score_cv_crossval": (
+        Params(method="score", cv=kf, random_state=42),
+        ParamsPredict(include_last_label=False, agg_scores="crossval")
+    ),
+    "cum_score_cv_mean": (
+        Params(method="cumulated_score", cv=kf, random_state=42),
+        ParamsPredict(include_last_label="randomized", agg_scores="mean")
+    ),
+    "cum_score_cv_crossval": (
+        Params(method="cumulated_score", cv=kf, random_state=42),
+        ParamsPredict(include_last_label='randomized', agg_scores="crossval")
+    )
+}
+
+y_preds, y_ps = {}, {}
+for strategy, params in STRATEGIES.items():
+    args_init, args_predict = STRATEGIES[strategy]
+    mapie_clf = MapieClassifier(**args_init)
+    mapie_clf.fit(X_train, y_train)
+    y_preds[strategy], y_ps[strategy] = mapie_clf.predict(
+        X_test_distrib,
+        alpha=alpha,
+        **args_predict
+    )


 ##############################################################################
 # Next, we estimate the coverages and widths of prediction sets for both
-# aggregation methods.
+# aggregation strategies and both methods.
+# We also estimate the "violation" score defined as the absolute difference
+# between the effective coverage and the target coverage averaged over all
+# alpha values.

+coverages, widths, violations = {}, {}, {}

-coverages_score_mean = np.array(
-    [
-        classification_coverage_score(
-            y_test_distrib,
-            y_ps_score_mean[:, :, ia]
-        ) for ia, _ in enumerate(alpha)
-    ]
-)
-
-widths_score_mean = np.array(
-    [
-        classification_mean_width_score(y_ps_score_mean[:, :, ia])
-        for ia, _ in enumerate(alpha)
-    ]
-)
-
-coverages_score_crossval = np.array(
-    [
-        classification_coverage_score(
-            y_test_distrib,
-            y_ps_score_crossval[:, :, ia]
-        ) for ia, _ in enumerate(alpha)
-    ]
-)
-
-widths_score_crossval = np.array(
-    [
-        classification_mean_width_score(y_ps_score_crossval[:, :, ia])
-        for ia, _ in enumerate(alpha)
-    ]
-)
+for strategy, y_ps_ in y_ps.items():
+    coverages[strategy] = np.array(
+        [
+            classification_coverage_score(
+                y_test_distrib,
+                y_ps_[:, :, ia]
+            ) for ia, _ in enumerate(alpha)
+        ]
+    )
+    widths[strategy] = np.array(
+        [
+            classification_mean_width_score(y_ps_[:, :, ia])
+            for ia, _ in enumerate(alpha)
+        ]
+    )
+    violations[strategy] = np.abs(coverages[strategy] - (1 - alpha)).mean()


 ##############################################################################
 # Next, we visualize their coverages and prediction set sizes as function of
 # the `alpha` parameter.

+plot_coverage_width(
+    alpha,
+    [coverages["score_cv_mean"], coverages["score_cv_crossval"]],
+    [widths["score_cv_mean"], widths["score_cv_crossval"]],
+    "score",
+    comp="mean"
+)

-fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
-axes[0].set_xlabel("1 - alpha")
-axes[0].set_ylabel("Effective coverage")
-for i, coverage in enumerate([coverages_score_mean, coverages_score_crossval]):
-    axes[0].plot(1 - alpha, coverage)
-axes[0].plot([0, 1], [0, 1], ls="--", color="k")
-axes[1].set_xlabel("1 - alpha")
-axes[1].set_ylabel("Average of prediction set sizes")
-for i, widths in enumerate([widths_score_mean, widths_score_crossval]):
-    axes[1].plot(1 - alpha, widths)
-axes[1].legend(["mean", "crossval"], loc=[1, 0])
-plt.suptitle(
-    "Effective coverage and prediction set sizes for ``mean`` "
-    "aggregation method"
+plot_coverage_width(
+    alpha,
+    [coverages["cum_score_cv_mean"], coverages["cum_score_cv_crossval"]],
+    [widths["cum_score_cv_mean"], widths["cum_score_cv_crossval"]],
+    "cumulated_score",
+    comp="mean"
 )
-plt.show()


 ##############################################################################
@@ -391,23 +422,32 @@ def plot_coverage_width(
 #
 # The calibration plots obtained with the cross-conformal methods seem to be
 # more robust than with the split-conformal used above. Let's check this first
-# impression by estimating the deviation from the "perfect" coverage as
-# function of the `alpha` parameter.
-
+# impression by comparing the violation of the effective coverage from the
+# target coverage between the cross-conformal and split-conformal methods.

-plt.figure(figsize=(10, 5))
-label = f"Cross-conf: {np.abs(coverages_score_mean - (1 - alpha)).mean(): .3f}"
-plt.plot(
-    1 - alpha,
-    coverages_score_mean - (1 - alpha),
-    color="k",
-    label=label
+violations_df = pd.DataFrame(
+    index=["score", "cumulated_score"],
+    columns=["cv_mean", "cv_crossval", "splits"]
 )
-for i, coverage in enumerate(coverages[0]):
-    label = f"Split {i + 1}: {np.abs(coverage - (1 - alpha)).mean(): .3f}"
-    plt.plot(1 - alpha, coverage - (1 - alpha), label=label)
-plt.axhline(0, color="k", ls=":")
-plt.xlabel("1 - alpha")
-plt.ylabel("Deviation from perfect calibration")
-plt.legend(loc=[1, 0])
-plt.show()
+violations_df.loc["score", "cv_mean"] = violations["score_cv_mean"]
+violations_df.loc["score", "cv_crossval"] = violations["score_cv_crossval"]
+violations_df.loc["score", "splits"] = np.stack(
+    [
+        np.abs(cov - (1 - alpha)).mean()
+        for cov in split_coverages[0]
+    ]
+).mean()
+violations_df.loc["cumulated_score", "cv_mean"] = (
+    violations["cum_score_cv_mean"]
+)
+violations_df.loc["cumulated_score", "cv_crossval"] = (
+    violations["cum_score_cv_crossval"]
+)
+violations_df.loc["cumulated_score", "splits"] = np.stack(
+    [
+        np.abs(cov - (1 - alpha)).mean()
+        for cov in split_coverages[1]
+    ]
+).mean()
+
+print(violations_df)
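
As an optional follow-up (not part of this commit), the violations_df table printed at the end of the updated example could also be compared visually. This short sketch reuses violations_df from the example script above and only assumes standard pandas/matplotlib plotting:

import matplotlib.pyplot as plt

# Bar-plot the violation table: one group per method, one bar per
# aggregation strategy ("cv_mean", "cv_crossval") and the split average.
violations_df.astype(float).plot.bar(rot=0)
plt.ylabel("Mean |effective coverage - (1 - alpha)|")
plt.title("Coverage violation per method and aggregation strategy")
plt.show()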
