Skip to content

Commit a92b6a8

Browse files
committed
Fix unleveled bug with bert
1 parent a4a006e commit a92b6a8

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

hiclass/HierarchicalClassifier.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,9 @@ def _pre_fit(self, X, y, sample_weight):
161161
)
162162
else:
163163
self.X_ = np.array(X)
164-
self.y_ = np.array(make_leveled(y))
164+
self.y_ = check_array(
165+
make_leveled(y), dtype=None, ensure_2d=False, allow_nd=True
166+
)
165167

166168
if sample_weight is not None:
167169
self.sample_weight_ = _check_sample_weight(sample_weight, X)

tests/test_LocalClassifiers.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33

44
import numpy as np
55
import pytest
6+
from bert_sklearn import BertClassifier
67
from numpy.testing import assert_array_equal
78
from pyfakefs.fake_filesystem_unittest import Patcher
89
from sklearn.linear_model import LogisticRegression
910
from sklearn.neighbors import KNeighborsClassifier
1011
from sklearn.utils.validation import check_is_fitted
1112

1213
from hiclass import (
13-
LocalClassifierPerNode,
1414
LocalClassifierPerLevel,
15+
LocalClassifierPerNode,
1516
LocalClassifierPerParentNode,
1617
)
1718
from hiclass.ConstantClassifier import ConstantClassifier
@@ -77,16 +78,20 @@ def test_empty_levels(empty_levels, classifier):
7778

7879
@pytest.mark.parametrize("classifier", classifiers)
7980
def test_fit_bert(classifier):
80-
bert = ConstantClassifier()
81+
bert = BertClassifier()
8182
clf = classifier(
8283
local_classifier=bert,
8384
bert=True,
8485
)
85-
X = ["Text 1", "Text 2"]
86-
y = ["a", "a"]
87-
clf.fit(X, y)
86+
x = ["Batman", "Joker", "Rorschach"]
87+
y = [
88+
["Action", "The Dark Night"],
89+
["Action", "The Dark Night"],
90+
["Action", "Watchmen"],
91+
]
92+
clf.fit(x, y)
8893
check_is_fitted(clf)
89-
predictions = clf.predict(X)
94+
predictions = clf.predict(x)
9095
assert_array_equal(y, predictions)
9196

9297

@@ -148,9 +153,13 @@ def test_tmp_dir(classifier):
148153
@pytest.mark.parametrize("classifier", classifiers)
149154
def test_bert_unleveled(classifier):
150155
clf = classifier(
151-
local_classifier=LogisticRegression(),
156+
local_classifier=BertClassifier(),
152157
bert=True,
153158
)
154-
x = [[0, 1], [2, 3]]
155-
y = [["a"], ["b", "c"]]
159+
x = ["Batman", "Joker"]
160+
y = [["Action", "The Dark Night"], ["Action"]]
161+
ground_truth = [["Action", "The Dark Night"], ["Action", "The Dark Night"]]
156162
clf.fit(x, y)
163+
check_is_fitted(clf)
164+
predictions = clf.predict(x)
165+
assert_array_equal(ground_truth, predictions)

0 commit comments

Comments
 (0)