Skip to content

Commit 909e8d4

Browse files
committed
Fix tests bert
1 parent 9b5ffea commit 909e8d4

File tree

2 files changed

+37
-36
lines changed

2 files changed

+37
-36
lines changed

tests/test_LocalClassifierPerParentNode.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
import networkx as nx
55
import numpy as np
66
import pytest
7-
from numpy.testing import assert_array_equal, assert_array_almost_equal
7+
from bert_sklearn import BertClassifier
8+
from numpy.testing import assert_array_almost_equal, assert_array_equal
89
from scipy.sparse import csr_matrix
910
from sklearn.exceptions import NotFittedError
1011
from sklearn.linear_model import LogisticRegression
1112
from sklearn.utils.estimator_checks import parametrize_with_checks
1213
from sklearn.utils.validation import check_is_fitted
14+
1315
from hiclass import LocalClassifierPerParentNode
1416
from hiclass._calibration.Calibrator import _Calibrator
1517
from hiclass.HierarchicalClassifier import make_leveled
@@ -393,3 +395,37 @@ def test_fit_calibrate_predict_predict_proba_bert():
393395
classifier.calibrate(x, y)
394396
classifier.predict(x)
395397
classifier.predict_proba(x)
398+
399+
400+
# Note: bert only works with the local classifier per parent node
401+
# It does not have the attribute classes_, which are necessary
402+
# for the local classifiers per level and per node
403+
def test_fit_bert():
404+
bert = BertClassifier()
405+
clf = LocalClassifierPerParentNode(
406+
local_classifier=bert,
407+
bert=True,
408+
)
409+
x = ["Batman", "rorschach"]
410+
y = [
411+
["Action", "The Dark Night"],
412+
["Action", "Watchmen"],
413+
]
414+
clf.fit(x, y)
415+
check_is_fitted(clf)
416+
predictions = clf.predict(x)
417+
assert_array_equal(y, predictions)
418+
419+
420+
def test_bert_unleveled():
421+
clf = LocalClassifierPerParentNode(
422+
local_classifier=BertClassifier(),
423+
bert=True,
424+
)
425+
x = ["Batman", "Jaws"]
426+
y = [["Action", "The Dark Night"], ["Thriller"]]
427+
ground_truth = [["Action", "The Dark Night"], ["Action", "The Dark Night"]]
428+
clf.fit(x, y)
429+
check_is_fitted(clf)
430+
predictions = clf.predict(x)
431+
assert_array_equal(ground_truth, predictions)

tests/test_LocalClassifiers.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import numpy as np
55
import pytest
6-
from bert_sklearn import BertClassifier
76
from numpy.testing import assert_array_equal
87
from pyfakefs.fake_filesystem_unittest import Patcher
98
from sklearn.linear_model import LogisticRegression
@@ -76,25 +75,6 @@ def test_empty_levels(empty_levels, classifier):
7675
assert_array_equal(ground_truth, predictions)
7776

7877

79-
@pytest.mark.parametrize("classifier", classifiers)
80-
def test_fit_bert(classifier):
81-
bert = BertClassifier()
82-
clf = classifier(
83-
local_classifier=bert,
84-
bert=True,
85-
)
86-
x = ["Batman", "Joker", "Rorschach"]
87-
y = [
88-
["Action", "The Dark Night"],
89-
["Action", "The Dark Night"],
90-
["Action", "Watchmen"],
91-
]
92-
clf.fit(x, y)
93-
check_is_fitted(clf)
94-
predictions = clf.predict(x)
95-
assert_array_equal(y, predictions)
96-
97-
9878
@pytest.mark.parametrize("classifier", classifiers)
9979
def test_knn(classifier):
10080
knn = KNeighborsClassifier(
@@ -148,18 +128,3 @@ def test_tmp_dir(classifier):
148128
assert expected_name == name
149129
check_is_fitted(classifier)
150130
clf.fit(x, y)
151-
152-
153-
@pytest.mark.parametrize("classifier", classifiers)
154-
def test_bert_unleveled(classifier):
155-
clf = classifier(
156-
local_classifier=BertClassifier(),
157-
bert=True,
158-
)
159-
x = ["Batman", "Joker"]
160-
y = [["Action", "The Dark Night"], ["Action"]]
161-
ground_truth = [["Action", "The Dark Night"], ["Action", "The Dark Night"]]
162-
clf.fit(x, y)
163-
check_is_fitted(clf)
164-
predictions = clf.predict(x)
165-
assert_array_equal(ground_truth, predictions)

0 commit comments

Comments
 (0)