|
3 | 3 |
|
4 | 4 | import numpy as np |
5 | 5 | import pytest |
| 6 | +from bert_sklearn import BertClassifier |
6 | 7 | from numpy.testing import assert_array_equal |
7 | 8 | from pyfakefs.fake_filesystem_unittest import Patcher |
8 | 9 | from sklearn.linear_model import LogisticRegression |
9 | 10 | from sklearn.neighbors import KNeighborsClassifier |
10 | 11 | from sklearn.utils.validation import check_is_fitted |
11 | 12 |
|
12 | 13 | from hiclass import ( |
13 | | - LocalClassifierPerNode, |
14 | 14 | LocalClassifierPerLevel, |
| 15 | + LocalClassifierPerNode, |
15 | 16 | LocalClassifierPerParentNode, |
16 | 17 | ) |
17 | 18 | from hiclass.ConstantClassifier import ConstantClassifier |
@@ -77,16 +78,20 @@ def test_empty_levels(empty_levels, classifier): |
77 | 78 |
|
78 | 79 | @pytest.mark.parametrize("classifier", classifiers) |
79 | 80 | def test_fit_bert(classifier): |
80 | | - bert = ConstantClassifier() |
| 81 | + bert = BertClassifier() |
81 | 82 | clf = classifier( |
82 | 83 | local_classifier=bert, |
83 | 84 | bert=True, |
84 | 85 | ) |
85 | | - X = ["Text 1", "Text 2"] |
86 | | - y = ["a", "a"] |
87 | | - clf.fit(X, y) |
| 86 | + x = ["Batman", "Joker", "Rorschach"] |
| 87 | + y = [ |
| 88 | + ["Action", "The Dark Night"], |
| 89 | + ["Action", "The Dark Night"], |
| 90 | + ["Action", "Watchmen"], |
| 91 | + ] |
| 92 | + clf.fit(x, y) |
88 | 93 | check_is_fitted(clf) |
89 | | - predictions = clf.predict(X) |
| 94 | + predictions = clf.predict(x) |
90 | 95 | assert_array_equal(y, predictions) |
91 | 96 |
|
92 | 97 |
|
@@ -148,9 +153,13 @@ def test_tmp_dir(classifier): |
148 | 153 | @pytest.mark.parametrize("classifier", classifiers) |
149 | 154 | def test_bert_unleveled(classifier): |
150 | 155 | clf = classifier( |
151 | | - local_classifier=LogisticRegression(), |
| 156 | + local_classifier=BertClassifier(), |
152 | 157 | bert=True, |
153 | 158 | ) |
154 | | - x = [[0, 1], [2, 3]] |
155 | | - y = [["a"], ["b", "c"]] |
| 159 | + x = ["Batman", "Joker"] |
| 160 | + y = [["Action", "The Dark Night"], ["Action"]] |
| 161 | + ground_truth = [["Action", "The Dark Night"], ["Action", "The Dark Night"]] |
156 | 162 | clf.fit(x, y) |
| 163 | + check_is_fitted(clf) |
| 164 | + predictions = clf.predict(x) |
| 165 | + assert_array_equal(ground_truth, predictions) |
0 commit comments