Skip to content

Commit a4c79e8

Browse files
Refactor parametrized tests (#99)
1 parent 86a1296 commit a4c79e8

File tree

4 files changed

+121
-297
lines changed

4 files changed

+121
-297
lines changed

tests/test_LocalClassifierPerLevel.py

Lines changed: 0 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,9 @@
77
from scipy.sparse import csr_matrix
88
from sklearn.exceptions import NotFittedError
99
from sklearn.linear_model import LogisticRegression
10-
from sklearn.neighbors import KNeighborsClassifier
1110
from sklearn.utils.estimator_checks import parametrize_with_checks
1211
from sklearn.utils.validation import check_is_fitted
13-
1412
from hiclass import LocalClassifierPerLevel
15-
from hiclass.ConstantClassifier import ConstantClassifier
1613

1714

1815
@parametrize_with_checks([LocalClassifierPerLevel()])
@@ -79,16 +76,6 @@ def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression):
7976
assert 1
8077

8178

82-
def test_fit_1_class():
83-
lcpl = LocalClassifierPerLevel(local_classifier=LogisticRegression(), n_jobs=2)
84-
y = np.array([["1", "2"]])
85-
X = np.array([[1, 2]])
86-
ground_truth = np.array([["1", "2"]])
87-
lcpl.fit(X, y)
88-
prediction = lcpl.predict(X)
89-
assert_array_equal(ground_truth, prediction)
90-
91-
9279
@pytest.fixture
9380
def fitted_logistic_regression():
9481
digraph = LocalClassifierPerLevel(local_classifier=LogisticRegression())
@@ -146,89 +133,3 @@ def test_fit_predict():
146133
pytest.fail(repr(e))
147134
predictions = lcpl.predict(x)
148135
assert_array_equal(y, predictions)
149-
150-
151-
@pytest.fixture
152-
def empty_levels():
153-
X = [
154-
[1],
155-
[2],
156-
[3],
157-
]
158-
y = np.array(
159-
[
160-
["1"],
161-
["2", "2.1"],
162-
["3", "3.1", "3.1.2"],
163-
],
164-
dtype=object,
165-
)
166-
return X, y
167-
168-
169-
def test_empty_levels(empty_levels):
170-
lcppn = LocalClassifierPerLevel()
171-
X, y = empty_levels
172-
lcppn.fit(X, y)
173-
predictions = lcppn.predict(X)
174-
ground_truth = [
175-
["1", "", ""],
176-
["2", "2.1", ""],
177-
["3", "3.1", "3.1.2"],
178-
]
179-
assert list(lcppn.hierarchy_.nodes) == [
180-
"1",
181-
"2",
182-
"2" + lcppn.separator_ + "2.1",
183-
"3",
184-
"3" + lcppn.separator_ + "3.1",
185-
"3" + lcppn.separator_ + "3.1" + lcppn.separator_ + "3.1.2",
186-
lcppn.root_,
187-
]
188-
assert_array_equal(ground_truth, predictions)
189-
190-
191-
def test_fit_bert():
192-
bert = ConstantClassifier()
193-
lcpl = LocalClassifierPerLevel(
194-
local_classifier=bert,
195-
bert=True,
196-
)
197-
X = ["Text 1", "Text 2"]
198-
y = ["a", "a"]
199-
lcpl.fit(X, y)
200-
check_is_fitted(lcpl)
201-
predictions = lcpl.predict(X)
202-
assert_array_equal(y, predictions)
203-
204-
205-
def test_knn():
206-
knn = KNeighborsClassifier(
207-
n_neighbors=2,
208-
)
209-
lcpl = LocalClassifierPerLevel(
210-
local_classifier=knn,
211-
)
212-
y = np.array([["a", "b"], ["a", "c"]])
213-
X = np.array([[1, 2], [3, 4]])
214-
lcpl.fit(X, y)
215-
check_is_fitted(lcpl)
216-
# predictions = lcpl.predict(X)
217-
# assert_array_equal(y, predictions)
218-
219-
220-
def test_fit_multiple_dim_input():
221-
lcpl = LocalClassifierPerLevel()
222-
X = np.random.rand(1, 275, 3)
223-
y = np.array([["a", "b", "c"]])
224-
lcpl.fit(X, y)
225-
check_is_fitted(lcpl)
226-
227-
228-
def test_predict_multiple_dim_input():
229-
lcpl = LocalClassifierPerLevel()
230-
X = np.random.rand(1, 275, 3)
231-
y = np.array([["a", "b", "c"]])
232-
lcpl.fit(X, y)
233-
predictions = lcpl.predict(X)
234-
assert predictions is not None

tests/test_LocalClassifierPerNode.py

Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,11 @@
77
from scipy.sparse import csr_matrix
88
from sklearn.exceptions import NotFittedError
99
from sklearn.linear_model import LogisticRegression
10-
from sklearn.neighbors import KNeighborsClassifier
1110
from sklearn.utils.estimator_checks import parametrize_with_checks
1211
from sklearn.utils.validation import check_is_fitted
1312

1413
from hiclass import LocalClassifierPerNode
1514
from hiclass.BinaryPolicy import ExclusivePolicy
16-
from hiclass.ConstantClassifier import ConstantClassifier
1715

1816

1917
@parametrize_with_checks([LocalClassifierPerNode()])
@@ -135,16 +133,6 @@ def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression):
135133
assert 1
136134

137135

138-
def test_fit_1_class():
139-
lcpn = LocalClassifierPerNode(local_classifier=LogisticRegression(), n_jobs=2)
140-
y = np.array([["1", "2"]])
141-
X = np.array([[1, 2]])
142-
ground_truth = np.array([["1", "2"]])
143-
lcpn.fit(X, y)
144-
prediction = lcpn.predict(X)
145-
assert_array_equal(ground_truth, prediction)
146-
147-
148136
def test_clean_up(digraph_logistic_regression):
149137
digraph_logistic_regression._clean_up()
150138
with pytest.raises(AttributeError):
@@ -207,89 +195,3 @@ def test_fit_predict():
207195
lcpn.fit(x, y)
208196
predictions = lcpn.predict(x)
209197
assert_array_equal(y, predictions)
210-
211-
212-
@pytest.fixture
213-
def empty_levels():
214-
X = [
215-
[1],
216-
[2],
217-
[3],
218-
]
219-
y = np.array(
220-
[
221-
["1"],
222-
["2", "2.1"],
223-
["3", "3.1", "3.1.2"],
224-
],
225-
dtype=object,
226-
)
227-
return X, y
228-
229-
230-
def test_empty_levels(empty_levels):
231-
lcppn = LocalClassifierPerNode()
232-
X, y = empty_levels
233-
lcppn.fit(X, y)
234-
predictions = lcppn.predict(X)
235-
ground_truth = [
236-
["1", "", ""],
237-
["2", "2.1", ""],
238-
["3", "3.1", "3.1.2"],
239-
]
240-
assert list(lcppn.hierarchy_.nodes) == [
241-
"1",
242-
"2",
243-
"2" + lcppn.separator_ + "2.1",
244-
"3",
245-
"3" + lcppn.separator_ + "3.1",
246-
"3" + lcppn.separator_ + "3.1" + lcppn.separator_ + "3.1.2",
247-
lcppn.root_,
248-
]
249-
assert_array_equal(ground_truth, predictions)
250-
251-
252-
def test_fit_bert():
253-
bert = ConstantClassifier()
254-
lcpn = LocalClassifierPerNode(
255-
local_classifier=bert,
256-
bert=True,
257-
)
258-
X = ["Text 1", "Text 2"]
259-
y = ["a", "a"]
260-
lcpn.fit(X, y)
261-
check_is_fitted(lcpn)
262-
predictions = lcpn.predict(X)
263-
assert_array_equal(y, predictions)
264-
265-
266-
def test_knn():
267-
knn = KNeighborsClassifier(
268-
n_neighbors=2,
269-
)
270-
lcpn = LocalClassifierPerNode(
271-
local_classifier=knn,
272-
)
273-
y = np.array([["a", "b"], ["a", "c"]])
274-
X = np.array([[1, 2], [3, 4]])
275-
lcpn.fit(X, y)
276-
check_is_fitted(lcpn)
277-
# predictions = lcpn.predict(X)
278-
# assert_array_equal(y, predictions)
279-
280-
281-
def test_fit_multiple_dim_input():
282-
lcpn = LocalClassifierPerNode()
283-
X = np.random.rand(1, 275, 3)
284-
y = np.array([["a", "b", "c"]])
285-
lcpn.fit(X, y)
286-
check_is_fitted(lcpn)
287-
288-
289-
def test_predict_multiple_dim_input():
290-
lcpn = LocalClassifierPerNode()
291-
X = np.random.rand(1, 275, 3)
292-
y = np.array([["a", "b", "c"]])
293-
lcpn.fit(X, y)
294-
predictions = lcpn.predict(X)
295-
assert predictions is not None

tests/test_LocalClassifierPerParentNode.py

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@
88
from scipy.sparse import csr_matrix
99
from sklearn.exceptions import NotFittedError
1010
from sklearn.linear_model import LogisticRegression
11-
from sklearn.neighbors import KNeighborsClassifier
1211
from sklearn.utils.estimator_checks import parametrize_with_checks
1312
from sklearn.utils.validation import check_is_fitted
1413

1514
from hiclass import LocalClassifierPerParentNode
16-
from hiclass.ConstantClassifier import ConstantClassifier
1715

1816

1917
@parametrize_with_checks([LocalClassifierPerParentNode()])
@@ -88,18 +86,6 @@ def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression):
8886
assert 1
8987

9088

91-
def test_fit_1_class():
92-
lcppn = LocalClassifierPerParentNode(
93-
local_classifier=LogisticRegression(), n_jobs=2
94-
)
95-
y = np.array([["1", "2"]])
96-
X = np.array([[1, 2]])
97-
ground_truth = np.array([["1", "2"]])
98-
lcppn.fit(X, y)
99-
prediction = lcppn.predict(X)
100-
assert_array_equal(ground_truth, prediction)
101-
102-
10389
@pytest.fixture
10490
def digraph_2d():
10591
classifier = LocalClassifierPerParentNode()
@@ -198,89 +184,3 @@ def test_fit_predict():
198184
lcppn.fit(x, y)
199185
predictions = lcppn.predict(x)
200186
assert_array_equal(y, predictions)
201-
202-
203-
@pytest.fixture
204-
def empty_levels():
205-
X = [
206-
[1],
207-
[2],
208-
[3],
209-
]
210-
y = np.array(
211-
[
212-
["1"],
213-
["2", "2.1"],
214-
["3", "3.1", "3.1.2"],
215-
],
216-
dtype=object,
217-
)
218-
return X, y
219-
220-
221-
def test_empty_levels(empty_levels):
222-
lcppn = LocalClassifierPerParentNode()
223-
X, y = empty_levels
224-
lcppn.fit(X, y)
225-
predictions = lcppn.predict(X)
226-
ground_truth = [
227-
["1", "", ""],
228-
["2", "2.1", ""],
229-
["3", "3.1", "3.1.2"],
230-
]
231-
assert list(lcppn.hierarchy_.nodes) == [
232-
"1",
233-
"2",
234-
"2" + lcppn.separator_ + "2.1",
235-
"3",
236-
"3" + lcppn.separator_ + "3.1",
237-
"3" + lcppn.separator_ + "3.1" + lcppn.separator_ + "3.1.2",
238-
lcppn.root_,
239-
]
240-
assert_array_equal(ground_truth, predictions)
241-
242-
243-
def test_bert():
244-
bert = ConstantClassifier()
245-
lcpn = LocalClassifierPerParentNode(
246-
local_classifier=bert,
247-
bert=True,
248-
)
249-
X = ["Text 1", "Text 2"]
250-
y = ["a", "a"]
251-
lcpn.fit(X, y)
252-
check_is_fitted(lcpn)
253-
predictions = lcpn.predict(X)
254-
assert_array_equal(y, predictions)
255-
256-
257-
def test_knn():
258-
knn = KNeighborsClassifier(
259-
n_neighbors=2,
260-
)
261-
lcppn = LocalClassifierPerParentNode(
262-
local_classifier=knn,
263-
)
264-
y = np.array([["a", "b"], ["a", "c"]])
265-
X = np.array([[1, 2], [3, 4]])
266-
lcppn.fit(X, y)
267-
check_is_fitted(lcppn)
268-
# predictions = lcppn.predict(X)
269-
# assert_array_equal(y, predictions)
270-
271-
272-
def test_fit_multiple_dim_input():
273-
lcppn = LocalClassifierPerParentNode()
274-
X = np.random.rand(1, 275, 3)
275-
y = np.array([["a", "b", "c"]])
276-
lcppn.fit(X, y)
277-
check_is_fitted(lcppn)
278-
279-
280-
def test_predict_multiple_dim_input():
281-
lcppn = LocalClassifierPerParentNode()
282-
X = np.random.rand(1, 275, 3)
283-
y = np.array([["a", "b", "c"]])
284-
lcppn.fit(X, y)
285-
predictions = lcppn.predict(X)
286-
assert predictions is not None

0 commit comments

Comments
 (0)