33import numpy as np
44import pytest
55from sklearn .cluster import KMeans
6- from sklearn .datasets import load_iris , make_classification
6+ from sklearn .datasets import load_iris , make_classification , make_regression
77from sklearn .preprocessing import OneHotEncoder
88
99from fastcan import minibatch
10+ from fastcan .utils import ssc
1011
1112
1213def test_data_pruning ():
@@ -60,7 +61,7 @@ def test_select_minibatch_cls():
6061 assert indices .size == n_to_select
6162
6263
63- def test_minibatch_error ():
64+ def test_minibatch_error_warning ():
6465 # Test refine raise error.
6566 n_samples = 200
6667 n_features = 20
@@ -83,3 +84,37 @@ def test_minibatch_error():
8384
8485 with pytest .raises (ValueError , match = r"n_features_to_select .*" ):
8586 _ = minibatch (X , y , n_features + 1 , batch_size = 3 )
87+
88+ Y = OneHotEncoder (sparse_output = False ).fit_transform (y .reshape (- 1 , 1 ))
89+ Y [:, 0 ] = 1
90+ with pytest .warns (
91+ UserWarning , match = r"Contain constant targets, whose indices are .*"
92+ ):
93+ _ = minibatch (X , Y , 5 , batch_size = 3 )
94+
95+
96+ def test_minibatch_ssc_aligned (capsys ):
97+ # Test whether ssc of minibatch aligns with the true ssc score
98+ n_features = 20
99+ n_targets = 5
100+ n_to_select = 10
101+ X , y = make_regression (
102+ n_samples = 100 ,
103+ n_features = n_features ,
104+ n_informative = 10 ,
105+ n_targets = n_targets ,
106+ noise = 0.1 ,
107+ random_state = 0 ,
108+ )
109+
110+ # The last batch of features are selected for the last target.
111+ # The number of features selected per target is n_to_select // n_targets
112+ n_features_per_target = n_to_select // n_targets
113+ indices = minibatch (X , y , n_to_select , batch_size = n_features_per_target + 1 )
114+ captured = capsys .readouterr ()
115+
116+ gtruth_ssc = ssc (X [:, indices [- n_features_per_target :]], y [:, [- 1 ]])
117+ assert (
118+ f"Progress: { n_to_select } /{ n_to_select } , "
119+ f"Batch SSC: { gtruth_ssc :.5f} " in captured .out
120+ )
0 commit comments