Skip to content

Commit a4a006e

Browse files
committed
Fix unleved bug with bert
1 parent a344014 commit a4a006e

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

hiclass/HierarchicalClassifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def _pre_fit(self, X, y, sample_weight):
161161
)
162162
else:
163163
self.X_ = np.array(X)
164-
self.y_ = np.array(y)
164+
self.y_ = np.array(make_leveled(y))
165165

166166
if sample_weight is not None:
167167
self.sample_weight_ = _check_sample_weight(sample_weight, X)

tests/test_LocalClassifiers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,14 @@ def test_tmp_dir(classifier):
143143
assert expected_name == name
144144
check_is_fitted(classifier)
145145
clf.fit(x, y)
146+
147+
148+
@pytest.mark.parametrize("classifier", classifiers)
149+
def test_bert_unleveled(classifier):
150+
clf = classifier(
151+
local_classifier=LogisticRegression(),
152+
bert=True,
153+
)
154+
x = [[0, 1], [2, 3]]
155+
y = [["a"], ["b", "c"]]
156+
clf.fit(x, y)

0 commit comments

Comments
 (0)