Skip to content

Commit c20262c

Browse files
author
Christian Jorgensen
committed
New tests
1 parent c308bf2 commit c20262c

File tree

4 files changed

+105
-29
lines changed

4 files changed

+105
-29
lines changed

src/skmatter/decomposition/_kernel_pcovc.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -298,26 +298,37 @@ def fit(self, X, Y, W=None):
298298
self.centerer_ = KernelNormalizer()
299299
K = self.centerer_.fit_transform(K)
300300

301-
compatible_classifiers = (
301+
compatible_clfs = (
302302
LogisticRegression,
303303
LogisticRegressionCV,
304304
LinearSVC,
305305
LinearDiscriminantAnalysis,
306-
MultiOutputClassifier,
307306
RidgeClassifier,
308307
RidgeClassifierCV,
309308
SGDClassifier,
310309
Perceptron,
310+
MultiOutputClassifier,
311311
)
312312

313-
if self.classifier not in ["precomputed", None] and not isinstance(
314-
self.classifier, compatible_classifiers
315-
):
316-
raise ValueError(
317-
"Classifier must be an instance of `"
318-
f"{'`, `'.join(c.__name__ for c in compatible_classifiers)}`"
319-
", or `precomputed`"
320-
)
313+
if self.classifier not in ["precomputed", None]:
314+
if not isinstance(self.classifier, compatible_clfs):
315+
raise ValueError(
316+
"Classifier must be an instance of `"
317+
f"{'`, `'.join(c.__name__ for c in compatible_clfs)}`"
318+
", or `precomputed`."
319+
)
320+
321+
if isinstance(self.classifier, MultiOutputClassifier):
322+
if not isinstance(self.classifier.estimator, compatible_clfs):
323+
name = type(self.classifier.estimator).__name__
324+
raise ValueError(
325+
"The instance of MultiOutputClassifier passed as the "
326+
f"KernelPCovC classifier contains `{name}`, "
327+
"which is not supported. The MultiOutputClassifier "
328+
"must contain an instance of `"
329+
f"{'`, `'.join(c.__name__ for c in compatible_clfs[:-1])}"
330+
"`, or `precomputed`."
331+
)
321332

322333
multioutput = self.n_outputs_ != 1
323334
precomputed = self.classifier == "precomputed"

src/skmatter/decomposition/_pcovc.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -278,26 +278,37 @@ def fit(self, X, Y, W=None):
278278

279279
super()._set_fit_params(X)
280280

281-
compatible_classifiers = (
281+
compatible_clfs = (
282282
LogisticRegression,
283283
LogisticRegressionCV,
284284
LinearSVC,
285285
LinearDiscriminantAnalysis,
286-
MultiOutputClassifier,
287286
RidgeClassifier,
288287
RidgeClassifierCV,
289288
SGDClassifier,
290289
Perceptron,
290+
MultiOutputClassifier,
291291
)
292292

293-
if self.classifier not in ["precomputed", None] and not isinstance(
294-
self.classifier, compatible_classifiers
295-
):
296-
raise ValueError(
297-
"Classifier must be an instance of `"
298-
f"{'`, `'.join(c.__name__ for c in compatible_classifiers)}`"
299-
", or `precomputed`"
300-
)
293+
if self.classifier not in ["precomputed", None]:
294+
if not isinstance(self.classifier, compatible_clfs):
295+
raise ValueError(
296+
"Classifier must be an instance of `"
297+
f"{'`, `'.join(c.__name__ for c in compatible_clfs)}`"
298+
", or `precomputed`."
299+
)
300+
301+
if isinstance(self.classifier, MultiOutputClassifier):
302+
if not isinstance(self.classifier.estimator, compatible_clfs):
303+
name = type(self.classifier.estimator).__name__
304+
raise ValueError(
305+
"The instance of MultiOutputClassifier passed as the "
306+
f"PCovC classifier contains `{name}`, "
307+
"which is not supported. The MultiOutputClassifier "
308+
"must contain an instance of `"
309+
f"{'`, `'.join(c.__name__ for c in compatible_clfs[:-1])}"
310+
"`, or `precomputed`."
311+
)
301312

302313
multioutput = self.n_outputs_ != 1
303314
precomputed = self.classifier == "precomputed"

tests/test_kernel_pcovc.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,9 @@ def test_incompatible_classifier(self):
260260
str(cm.exception),
261261
"Classifier must be an instance of "
262262
"`LogisticRegression`, `LogisticRegressionCV`, `LinearSVC`, "
263-
"`LinearDiscriminantAnalysis`, `MultiOutputClassifier`, `RidgeClassifier`, "
264-
"`RidgeClassifierCV`, `SGDClassifier`, `Perceptron`, "
265-
"or `precomputed`",
263+
"`LinearDiscriminantAnalysis`, `RidgeClassifier`, `RidgeClassifierCV`, "
264+
"`SGDClassifier`, `Perceptron`, `MultiOutputClassifier`, "
265+
"or `precomputed`.",
266266
)
267267

268268
def test_none_classifier(self):
@@ -590,7 +590,33 @@ def test_decision_function_multioutput(self):
590590
T = kpcovc.transform(self.X)
591591
_ = kpcovc.decision_function(T=T)
592592

593-
# TODO: Add tests for addition of score function to pcovc.py
593+
def test_score(self):
594+
"""Check that KernelPCovC's score behaves properly with multiple labels."""
595+
kpcovc_multi = self.model(
596+
classifier=MultiOutputClassifier(estimator=LogisticRegression())
597+
)
598+
kpcovc_multi.fit(self.X, np.column_stack((self.Y, self.Y)))
599+
score_multi = kpcovc_multi.score(self.X, np.column_stack((self.Y, self.Y)))
600+
601+
kpcovc_single = self.model().fit(self.X, self.Y)
602+
score_single = kpcovc_single.score(self.X, self.Y)
603+
self.assertEqual(score_single, score_multi)
604+
605+
def test_bad_multioutput_estimator(self):
606+
"""Check that KernelPCovC returns an error when a MultiOutputClassifier
607+
is improperly constructed.
608+
"""
609+
with self.assertRaises(ValueError) as cm:
610+
pcovc = self.model(classifier=MultiOutputClassifier(estimator=GaussianNB()))
611+
pcovc.fit(self.X, np.column_stack((self.Y, self.Y)))
612+
self.assertEqual(
613+
str(cm.exception),
614+
"The instance of MultiOutputClassifier passed as the KernelPCovC classifier"
615+
" contains `GaussianNB`, which is not supported. The MultiOutputClassifier "
616+
"must contain an instance of `LogisticRegression`, `LogisticRegressionCV`, "
617+
"`LinearSVC`, `LinearDiscriminantAnalysis`, `RidgeClassifier`, "
618+
"`RidgeClassifierCV`, `SGDClassifier`, `Perceptron`, or `precomputed`.",
619+
)
594620

595621

596622
if __name__ == "__main__":

tests/test_pcovc.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -536,9 +536,9 @@ def test_incompatible_classifier(self):
536536
str(cm.exception),
537537
"Classifier must be an instance of "
538538
"`LogisticRegression`, `LogisticRegressionCV`, `LinearSVC`, "
539-
"`LinearDiscriminantAnalysis`, `MultiOutputClassifier`, `RidgeClassifier`, "
540-
"`RidgeClassifierCV`, `SGDClassifier`, `Perceptron`, "
541-
"or `precomputed`",
539+
"`LinearDiscriminantAnalysis`, `RidgeClassifier`, `RidgeClassifierCV`, "
540+
"`SGDClassifier`, `Perceptron`, `MultiOutputClassifier`, "
541+
"or `precomputed`.",
542542
)
543543

544544
def test_none_classifier(self):
@@ -648,7 +648,9 @@ def test_decision_function_multioutput(self):
648648
"""Check that PCovC's decision_function works in edge
649649
cases when `n_outputs_ > 1`.
650650
"""
651-
pcovc = self.model(classifier=MultiOutputClassifier(estimator=LinearSVC()))
651+
pcovc = self.model(
652+
classifier=MultiOutputClassifier(estimator=LogisticRegression())
653+
)
652654
pcovc.fit(self.X, np.column_stack((self.Y, self.Y)))
653655
with self.assertRaises(ValueError) as cm:
654656
_ = pcovc.decision_function()
@@ -660,7 +662,33 @@ def test_decision_function_multioutput(self):
660662
T = pcovc.transform(self.X)
661663
_ = pcovc.decision_function(T=T)
662664

663-
# TODO: Add tests for addition of score function to pcovc.py
665+
def test_score(self):
666+
"""Check that PCovC's score behaves properly with multiple labels."""
667+
pcovc_multi = self.model(
668+
classifier=MultiOutputClassifier(estimator=LogisticRegression())
669+
)
670+
pcovc_multi.fit(self.X, np.column_stack((self.Y, self.Y)))
671+
score_multi = pcovc_multi.score(self.X, np.column_stack((self.Y, self.Y)))
672+
673+
pcovc_single = self.model().fit(self.X, self.Y)
674+
score_single = pcovc_single.score(self.X, self.Y)
675+
self.assertEqual(score_single, score_multi)
676+
677+
def test_bad_multioutput_estimator(self):
678+
"""Check that PCovC returns an error when a MultiOutputClassifier
679+
is improperly constructed.
680+
"""
681+
with self.assertRaises(ValueError) as cm:
682+
pcovc = self.model(classifier=MultiOutputClassifier(estimator=GaussianNB()))
683+
pcovc.fit(self.X, np.column_stack((self.Y, self.Y)))
684+
self.assertEqual(
685+
str(cm.exception),
686+
"The instance of MultiOutputClassifier passed as the PCovC classifier "
687+
"contains `GaussianNB`, which is not supported. The MultiOutputClassifier "
688+
"must contain an instance of `LogisticRegression`, `LogisticRegressionCV`, "
689+
"`LinearSVC`, `LinearDiscriminantAnalysis`, `RidgeClassifier`, "
690+
"`RidgeClassifierCV`, `SGDClassifier`, `Perceptron`, or `precomputed`.",
691+
)
664692

665693

666694
if __name__ == "__main__":

0 commit comments

Comments
 (0)