Skip to content

Commit 86a1296

Browse files
Add support for allowing multi-dimensional inputs (ndim > 2) #minor (#97)
1 parent 0f9f955 commit 86a1296

7 files changed

+55
-4
lines changed

hiclass/HierarchicalClassifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def _pre_fit(self, X, y, sample_weight):
136136

137137
if not self.bert:
138138
self.X_, self.y_ = self._validate_data(
139-
X, y, multi_output=True, accept_sparse="csr"
139+
X, y, multi_output=True, accept_sparse="csr", allow_nd=True
140140
)
141141
else:
142142
self.X_ = np.array(X)

hiclass/LocalClassifierPerLevel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def predict(self, X):
140140

141141
# Input validation
142142
if not self.bert:
143-
X = check_array(X, accept_sparse="csr")
143+
X = check_array(X, accept_sparse="csr", allow_nd=True, ensure_2d=False)
144144
else:
145145
X = np.array(X)
146146

hiclass/LocalClassifierPerNode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def predict(self, X):
150150

151151
# Input validation
152152
if not self.bert:
153-
X = check_array(X, accept_sparse="csr")
153+
X = check_array(X, accept_sparse="csr", allow_nd=True, ensure_2d=False)
154154
else:
155155
X = np.array(X)
156156

hiclass/LocalClassifierPerParentNode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def predict(self, X):
133133

134134
# Input validation
135135
if not self.bert:
136-
X = check_array(X, accept_sparse="csr")
136+
X = check_array(X, accept_sparse="csr", allow_nd=True, ensure_2d=False)
137137
else:
138138
X = np.array(X)
139139

tests/test_LocalClassifierPerLevel.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,20 @@ def test_knn():
215215
check_is_fitted(lcpl)
216216
# predictions = lcpl.predict(X)
217217
# assert_array_equal(y, predictions)
218+
219+
220+
def test_fit_multiple_dim_input():
221+
lcpl = LocalClassifierPerLevel()
222+
X = np.random.rand(1, 275, 3)
223+
y = np.array([["a", "b", "c"]])
224+
lcpl.fit(X, y)
225+
check_is_fitted(lcpl)
226+
227+
228+
def test_predict_multiple_dim_input():
229+
lcpl = LocalClassifierPerLevel()
230+
X = np.random.rand(1, 275, 3)
231+
y = np.array([["a", "b", "c"]])
232+
lcpl.fit(X, y)
233+
predictions = lcpl.predict(X)
234+
assert predictions is not None

tests/test_LocalClassifierPerNode.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,20 @@ def test_knn():
276276
check_is_fitted(lcpn)
277277
# predictions = lcpn.predict(X)
278278
# assert_array_equal(y, predictions)
279+
280+
281+
def test_fit_multiple_dim_input():
282+
lcpn = LocalClassifierPerNode()
283+
X = np.random.rand(1, 275, 3)
284+
y = np.array([["a", "b", "c"]])
285+
lcpn.fit(X, y)
286+
check_is_fitted(lcpn)
287+
288+
289+
def test_predict_multiple_dim_input():
290+
lcpn = LocalClassifierPerNode()
291+
X = np.random.rand(1, 275, 3)
292+
y = np.array([["a", "b", "c"]])
293+
lcpn.fit(X, y)
294+
predictions = lcpn.predict(X)
295+
assert predictions is not None

tests/test_LocalClassifierPerParentNode.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,20 @@ def test_knn():
267267
check_is_fitted(lcppn)
268268
# predictions = lcppn.predict(X)
269269
# assert_array_equal(y, predictions)
270+
271+
272+
def test_fit_multiple_dim_input():
273+
lcppn = LocalClassifierPerParentNode()
274+
X = np.random.rand(1, 275, 3)
275+
y = np.array([["a", "b", "c"]])
276+
lcppn.fit(X, y)
277+
check_is_fitted(lcppn)
278+
279+
280+
def test_predict_multiple_dim_input():
281+
lcppn = LocalClassifierPerParentNode()
282+
X = np.random.rand(1, 275, 3)
283+
y = np.array([["a", "b", "c"]])
284+
lcppn.fit(X, y)
285+
predictions = lcppn.predict(X)
286+
assert predictions is not None

0 commit comments

Comments
 (0)