Skip to content

Commit 1b0e00d

Browse files
author
mata
committed
adding tests for the mondrian method.
1 parent 53ec51d commit 1b0e00d

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
.PHONY: tests doc build
22

33
lint:
4-
flake8 . --exclude=doc mapieenv
4+
flake8 . --exclude=doc,mapieenv
55

66
type-check:
77
mypy mapie

mapie/tests/test_classification.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@
292292
"naive": 5 / 9,
293293
"top_k": 1,
294294
"raps": 1,
295-
"raps_randomized": 8/9
295+
"raps_randomized": 8/9,
296+
"mondrian": 1
296297
}
297298

298299
X_toy = np.arange(9).reshape(-1, 1)
@@ -475,6 +476,17 @@
475476
[False, True, True],
476477
[False, False, True],
477478
],
479+
"mondrian": [
480+
[True, False, False],
481+
[True, False, False],
482+
[True, True, False],
483+
[True, True, True],
484+
[True, True, True],
485+
[True, True, True],
486+
[False, True, True],
487+
[False, True, True],
488+
[False, False, True],
489+
],
478490
}
479491

480492
REGULARIZATION_PARAMETERS = [
@@ -875,7 +887,8 @@ def test_toy_dataset_predictions(strategy: str) -> None:
875887
alpha=0.5,
876888
include_last_label=args_predict["include_last_label"],
877889
agg_scores=args_predict["agg_scores"]
878-
)
890+
)
891+
879892
np.testing.assert_allclose(y_ps[:, :, 0], y_toy_mapie[strategy])
880893
np.testing.assert_allclose(
881894
classification_coverage_score(y_toy, y_ps[:, :, 0]),

0 commit comments

Comments
 (0)