@@ -95,7 +95,7 @@ def test_big_n_features(self):
9595 n_informative = 60 , n_redundant = 0 , n_repeated = 0 ,
9696 random_state = 42 )
9797 X = StandardScaler ().fit_transform (X )
98- scml = SCML_Supervised (random_state = 42 )
98+ scml = SCML_Supervised (random_state = 42 , n_basis = 399 )
9999 scml .fit (X , y )
100100 csep = class_separation (scml .transform (X ), y )
101101 assert csep < 0.7
@@ -106,7 +106,7 @@ def test_big_n_features(self):
106106 [2 , 0 ], [2 , 1 ]]),
107107 np .array ([1 , 0 , 1 , 0 ])))])
108108 def test_bad_basis (self , estimator , data ):
109- model = estimator (basis = 'bad_basis' )
109+ model = estimator (basis = 'bad_basis' , n_basis = 33 ) # n_basis doesn't matter
110110 msg = ("`basis` must be one of the options '{}' or an array of shape "
111111 "(n_basis, n_features)."
112112 .format ("', '" .join (model ._authorized_basis )))
@@ -238,16 +238,23 @@ def test_lda_toy(self):
238238 @pytest .mark .parametrize ('n_features' , [10 , 50 , 100 ])
239239 @pytest .mark .parametrize ('n_classes' , [5 , 10 , 15 ])
240240 def test_triplet_diffs (self , n_samples , n_features , n_classes ):
241+ """
242+ Test that the correct value of n_basis is being generated with
243+ different triplet constraints.
244+ """
241245 X , y = make_classification (n_samples = n_samples , n_classes = n_classes ,
242246 n_features = n_features , n_informative = n_features ,
243247 n_redundant = 0 , n_repeated = 0 )
244248 X = StandardScaler ().fit_transform (X )
245-
246- model = SCML_Supervised ()
249+ model = SCML_Supervised (n_basis = None ) # Explicit n_basis=None
247250 constraints = Constraints (y )
248251 triplets = constraints .generate_knntriplets (X , model .k_genuine ,
249252 model .k_impostor )
250- basis , n_basis = model ._generate_bases_dist_diff (triplets , X )
253+
254+ msg = "As no value for `n_basis` was selected, "
255+ with pytest .warns (UserWarning ) as raised_warning :
256+ basis , n_basis = model ._generate_bases_dist_diff (triplets , X )
257+ assert msg in str (raised_warning [0 ].message )
251258
252259 expected_n_basis = n_features * 80
253260 assert n_basis == expected_n_basis
@@ -257,13 +264,21 @@ def test_triplet_diffs(self, n_samples, n_features, n_classes):
257264 @pytest .mark .parametrize ('n_features' , [10 , 50 , 100 ])
258265 @pytest .mark .parametrize ('n_classes' , [5 , 10 , 15 ])
259266 def test_lda (self , n_samples , n_features , n_classes ):
267+ """
268+ Test that when n_basis=None, the correct n_basis is generated,
269+ for SCML_Supervised and different values of n_samples, n_features
270+ and n_classes.
271+ """
260272 X , y = make_classification (n_samples = n_samples , n_classes = n_classes ,
261273 n_features = n_features , n_informative = n_features ,
262274 n_redundant = 0 , n_repeated = 0 )
263275 X = StandardScaler ().fit_transform (X )
264276
265- model = SCML_Supervised ()
266- basis , n_basis = model ._generate_bases_LDA (X , y )
277+ msg = "As no value for `n_basis` was selected, "
278+ with pytest .warns (UserWarning ) as raised_warning :
279+ model = SCML_Supervised (n_basis = None ) # Explicit n_basis=None
280+ basis , n_basis = model ._generate_bases_LDA (X , y )
281+ assert msg in str (raised_warning [0 ].message )
267282
268283 num_eig = min (n_classes - 1 , n_features )
269284 expected_n_basis = min (20 * n_features , n_samples * 2 * num_eig - 1 )
@@ -299,7 +314,7 @@ def test_int_inputs_supervised(self, name):
299314 assert msg == raised_error .value .args [0 ]
300315
301316 def test_large_output_iter (self ):
302- scml = SCML (max_iter = 1 , output_iter = 2 )
317+ scml = SCML (max_iter = 1 , output_iter = 2 , n_basis = 33 ) # n_basis doesn't matter
303318 triplets = np .array ([[[0 , 1 ], [2 , 1 ], [0 , 0 ]]])
304319 msg = ("The value of output_iter must be equal or smaller than"
305320 " max_iter." )
0 commit comments