Skip to content

Commit bdfdb24

Browse files
authored
Assigned n_basis to SCML when needed for tests. Catch warn when needed as well. (#341)
1 parent 6a4aaea commit bdfdb24

File tree

5 files changed

+34
-35
lines changed

5 files changed

+34
-35
lines changed

test/metric_learn_test.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_big_n_features(self):
9595
n_informative=60, n_redundant=0, n_repeated=0,
9696
random_state=42)
9797
X = StandardScaler().fit_transform(X)
98-
scml = SCML_Supervised(random_state=42)
98+
scml = SCML_Supervised(random_state=42, n_basis=399)
9999
scml.fit(X, y)
100100
csep = class_separation(scml.transform(X), y)
101101
assert csep < 0.7
@@ -106,7 +106,7 @@ def test_big_n_features(self):
106106
[2, 0], [2, 1]]),
107107
np.array([1, 0, 1, 0])))])
108108
def test_bad_basis(self, estimator, data):
109-
model = estimator(basis='bad_basis')
109+
model = estimator(basis='bad_basis', n_basis=33) # n_basis doesn't matter
110110
msg = ("`basis` must be one of the options '{}' or an array of shape "
111111
"(n_basis, n_features)."
112112
.format("', '".join(model._authorized_basis)))
@@ -238,16 +238,23 @@ def test_lda_toy(self):
238238
@pytest.mark.parametrize('n_features', [10, 50, 100])
239239
@pytest.mark.parametrize('n_classes', [5, 10, 15])
240240
def test_triplet_diffs(self, n_samples, n_features, n_classes):
241+
"""
242+
Test that the correct value of n_basis is being generated with
243+
different triplet constraints.
244+
"""
241245
X, y = make_classification(n_samples=n_samples, n_classes=n_classes,
242246
n_features=n_features, n_informative=n_features,
243247
n_redundant=0, n_repeated=0)
244248
X = StandardScaler().fit_transform(X)
245-
246-
model = SCML_Supervised()
249+
model = SCML_Supervised(n_basis=None) # Explicit n_basis=None
247250
constraints = Constraints(y)
248251
triplets = constraints.generate_knntriplets(X, model.k_genuine,
249252
model.k_impostor)
250-
basis, n_basis = model._generate_bases_dist_diff(triplets, X)
253+
254+
msg = "As no value for `n_basis` was selected, "
255+
with pytest.warns(UserWarning) as raised_warning:
256+
basis, n_basis = model._generate_bases_dist_diff(triplets, X)
257+
assert msg in str(raised_warning[0].message)
251258

252259
expected_n_basis = n_features * 80
253260
assert n_basis == expected_n_basis
@@ -257,13 +264,21 @@ def test_triplet_diffs(self, n_samples, n_features, n_classes):
257264
@pytest.mark.parametrize('n_features', [10, 50, 100])
258265
@pytest.mark.parametrize('n_classes', [5, 10, 15])
259266
def test_lda(self, n_samples, n_features, n_classes):
267+
"""
268+
Test that when n_basis=None, the correct n_basis is generated,
269+
for SCML_Supervised and different values of n_samples, n_features
270+
and n_classes.
271+
"""
260272
X, y = make_classification(n_samples=n_samples, n_classes=n_classes,
261273
n_features=n_features, n_informative=n_features,
262274
n_redundant=0, n_repeated=0)
263275
X = StandardScaler().fit_transform(X)
264276

265-
model = SCML_Supervised()
266-
basis, n_basis = model._generate_bases_LDA(X, y)
277+
msg = "As no value for `n_basis` was selected, "
278+
with pytest.warns(UserWarning) as raised_warning:
279+
model = SCML_Supervised(n_basis=None) # Explicit n_basis=None
280+
basis, n_basis = model._generate_bases_LDA(X, y)
281+
assert msg in str(raised_warning[0].message)
267282

268283
num_eig = min(n_classes - 1, n_features)
269284
expected_n_basis = min(20 * n_features, n_samples * 2 * num_eig - 1)
@@ -299,7 +314,7 @@ def test_int_inputs_supervised(self, name):
299314
assert msg == raised_error.value.args[0]
300315

301316
def test_large_output_iter(self):
302-
scml = SCML(max_iter=1, output_iter=2)
317+
scml = SCML(max_iter=1, output_iter=2, n_basis=33) # n_basis doesn't matter
303318
triplets = np.array([[[0, 1], [2, 1], [0, 0]]])
304319
msg = ("The value of output_iter must be equal or smaller than"
305320
" max_iter.")

test/test_mahalanobis_mixin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,8 @@ def test_components_is_2D(estimator, build_dataset):
291291
model.fit(*remove_y(estimator, input_data, labels))
292292
assert model.components_.shape == (X.shape[1], X.shape[1])
293293

294-
# test that it works for 1 feature
295-
trunc_data = input_data[..., :1]
294+
# test that it works for 1 feature. Use the 2nd dimension to avoid border cases
295+
trunc_data = input_data[..., 1:2]
296296
# we drop duplicates that might have been formed, i.e. of the form
297297
# aabc or abcc or aabb for quadruplets, and aa for pairs.
298298

test/test_sklearn_compat.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@ def test_rca(self):
7979
check_estimator(Stable_RCA_Supervised())
8080

8181
def test_scml(self):
82-
check_estimator(SCML_Supervised())
82+
msg = "As no value for `n_basis` was selected, "
83+
with pytest.warns(UserWarning) as raised_warning:
84+
check_estimator(SCML_Supervised())
85+
assert msg in str(raised_warning[0].message)
8386

8487

8588
RNG = check_random_state(0)

test/test_triplets_classifiers.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22
from sklearn.exceptions import NotFittedError
33
from sklearn.model_selection import train_test_split
4-
import metric_learn
54

65
from test.test_utils import triplets_learners, ids_triplets_learners
76
from metric_learn.sklearn_shims import set_random_state
@@ -21,13 +20,7 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset,
2120
estimator.set_params(preprocessor=preprocessor)
2221
set_random_state(estimator)
2322
triplets_train, triplets_test = train_test_split(input_data)
24-
if isinstance(estimator, metric_learn.SCML):
25-
msg = "As no value for `n_basis` was selected, "
26-
with pytest.warns(UserWarning) as raised_warning:
27-
estimator.fit(triplets_train)
28-
assert msg in str(raised_warning[0].message)
29-
else:
30-
estimator.fit(triplets_train)
23+
estimator.fit(triplets_train)
3124
predictions = estimator.predict(triplets_test)
3225

3326
not_valid = [e for e in predictions if e not in [-1, 1]]
@@ -49,13 +42,7 @@ def test_no_zero_prediction(estimator, build_dataset):
4942
# Dummy fit
5043
estimator = clone(estimator)
5144
set_random_state(estimator)
52-
if isinstance(estimator, metric_learn.SCML):
53-
msg = "As no value for `n_basis` was selected, "
54-
with pytest.warns(UserWarning) as raised_warning:
55-
estimator.fit(triplets)
56-
assert msg in str(raised_warning[0].message)
57-
else:
58-
estimator.fit(triplets)
45+
estimator.fit(triplets)
5946
# We force the transformation to be identity, to force euclidean distance
6047
estimator.components_ = np.eye(X.shape[1])
6148

@@ -106,13 +93,7 @@ def test_accuracy_toy_example(estimator, build_dataset):
10693
triplets, _, _, X = build_dataset(with_preprocessor=False)
10794
estimator = clone(estimator)
10895
set_random_state(estimator)
109-
if isinstance(estimator, metric_learn.SCML):
110-
msg = "As no value for `n_basis` was selected, "
111-
with pytest.warns(UserWarning) as raised_warning:
112-
estimator.fit(triplets)
113-
assert msg in str(raised_warning[0].message)
114-
else:
115-
estimator.fit(triplets)
96+
estimator.fit(triplets)
11697
# We take the two first points and we build 4 regularly spaced points on the
11798
# line they define, so that it's easy to build triplets of different
11899
# similarities.

test/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def build_quadruplets(with_preprocessor=False):
117117
[learner for (learner, _) in
118118
quadruplets_learners]))
119119

120-
triplets_learners = [(SCML(), build_triplets)]
120+
triplets_learners = [(SCML(n_basis=320), build_triplets)]
121121
ids_triplets_learners = list(map(lambda x: x.__class__.__name__,
122122
[learner for (learner, _) in
123123
triplets_learners]))
@@ -140,7 +140,7 @@ def build_quadruplets(with_preprocessor=False):
140140
(RCA_Supervised(num_chunks=5), build_classification),
141141
(SDML_Supervised(prior='identity', balance_param=1e-5),
142142
build_classification),
143-
(SCML_Supervised(), build_classification)]
143+
(SCML_Supervised(n_basis=80), build_classification)]
144144
ids_classifiers = list(map(lambda x: x.__class__.__name__,
145145
[learner for (learner, _) in
146146
classifiers]))

0 commit comments

Comments
 (0)