Skip to content

Commit ddaec0c

Browse files
author
mata
committed
adding tests for the mondrian method.
1 parent 5c9992f commit ddaec0c

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
.PHONY: tests doc build
22

33
lint:
4-
flake8 . --exclude=doc mapieenv
4+
flake8 . --exclude=doc,mapieenv
55

66
type-check:
77
mypy mapie

mapie/tests/test_classification.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@
293293
"naive": 5 / 9,
294294
"top_k": 1,
295295
"raps": 1,
296-
"raps_randomized": 8/9
296+
"raps_randomized": 8/9,
297+
"mondrian": 1
297298
}
298299

299300
X_toy = np.arange(9).reshape(-1, 1)
@@ -476,6 +477,17 @@
476477
[False, True, True],
477478
[False, False, True],
478479
],
480+
"mondrian": [
481+
[True, False, False],
482+
[True, False, False],
483+
[True, True, False],
484+
[True, True, True],
485+
[True, True, True],
486+
[True, True, True],
487+
[False, True, True],
488+
[False, True, True],
489+
[False, False, True],
490+
],
479491
}
480492

481493
REGULARIZATION_PARAMETERS = [
@@ -876,7 +888,8 @@ def test_toy_dataset_predictions(strategy: str) -> None:
876888
alpha=0.5,
877889
include_last_label=args_predict["include_last_label"],
878890
agg_scores=args_predict["agg_scores"]
879-
)
891+
)
892+
880893
np.testing.assert_allclose(y_ps[:, :, 0], y_toy_mapie[strategy])
881894
np.testing.assert_allclose(
882895
classification_coverage_score(y_toy, y_ps[:, :, 0]),

0 commit comments

Comments
 (0)