Skip to content

Commit b9c9d8a

Browse files
FIX minibatch ssc score not match gtruth (#200)
1 parent 5c2dbe3 commit b9c9d8a

File tree

5 files changed

+55
-17
lines changed

5 files changed

+55
-17
lines changed

asv_benchmarks/benchmarks/fastcan.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,7 @@ def setup_cache(self):
3737
else:
3838
eta = False
3939
beam_width = 10
40-
estimator = FastCan(
41-
n_features_to_select=20,
42-
eta=eta,
43-
beam_width=beam_width
44-
)
40+
estimator = FastCan(n_features_to_select=20, eta=eta, beam_width=beam_width)
4541
estimator.fit(X, y)
4642

4743
est_path = get_estimator_path(self, params)

fastcan/_beam.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,14 @@ def _beam_search(
4545

4646
for i in range(n_features_to_select - n_inclusions):
4747
if i == 0:
48-
mask, X_selected = _prepare_candidates(
49-
X, mask_exclude, indices_include
50-
)
48+
mask, X_selected = _prepare_candidates(X, mask_exclude, indices_include)
5149
if X_selected.shape[1] == 0:
5250
beams_scores = np.sum((X.T @ V) ** 2, axis=1)
5351
beams_scores[mask] = 0
5452
else:
5553
W_selected = orth(X_selected)
5654
selected_score = np.sum((W_selected.T @ V) ** 2)
57-
beams_scores = _mgs_ssc(
58-
X, V, W_selected, mask, selected_score, tol
59-
)
55+
beams_scores = _mgs_ssc(X, V, W_selected, mask, selected_score, tol)
6056
beams_selected_ids = [indices_include for _ in range(beam_width)]
6157
beams_selected_ids, top_k_scores = _select_top_k(
6258
beams_scores[None, :],

fastcan/_minibatch.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
# Authors: The fastcan developers
66
# SPDX-License-Identifier: MIT
77

8+
import warnings
89
from numbers import Integral, Real
910

1011
import numpy as np
1112
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
1213
from sklearn.utils._param_validation import Interval, validate_params
1314
from sklearn.utils.validation import check_X_y
1415

16+
from ._beam import _safe_normalize
1517
from ._cancorr_fast import _greedy_search # type: ignore[attr-defined]
1618
from ._fastcan import _prepare_search
1719

@@ -101,11 +103,18 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, tol=0.01, verbose=1):
101103
)
102104
)
103105
X_transformed_ = X - X.mean(0)
104-
y_transformed_ = y - y.mean(0)
106+
y_transformed_, const_mask = _safe_normalize(y - y.mean(0))
107+
if const_mask.any():
108+
warnings.warn(
109+
f"Contain constant targets, whose indices are {np.where(const_mask)[0]}.",
110+
UserWarning,
111+
)
105112
indices_include = np.zeros(0, dtype=int) # just an empty array
106113
indices_select = np.zeros(0, dtype=int)
107114

108115
for i in range(n_outputs):
116+
if const_mask[i]:
117+
continue
109118
y_i = y_transformed_[:, [i]]
110119
n_selected_i = 0
111120
while n_to_select_split[i] > n_selected_i:
@@ -137,7 +146,9 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, tol=0.01, verbose=1):
137146
n_selected_i += batch_size_temp
138147
if verbose == 1:
139148
print(
140-
f"Progress: {indices_select.size}/{n_features_to_select}", end="\r"
149+
f"Progress: {indices_select.size}/{n_features_to_select}, "
150+
f"Batch SSC: {scores.sum():.5f}",
151+
end="\r",
141152
)
142153
if verbose == 1:
143154
print()

pixi.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ scikit-learn = ">=1.6.0"
8484
fastcan = { path = ".", editable = true }
8585

8686
[tasks]
87-
time-h = "python -m timeit -n 5 -s 'import numpy as np; from fastcan import FastCan; X = np.random.rand(3000, 100); y = np.random.rand(3000, 20)' 's = FastCan(100, verbose=0).fit(X, y)'"
88-
time-eta = "python -m timeit -n 5 -s 'import numpy as np; from fastcan import FastCan; X = np.random.rand(3000, 100); y = np.random.rand(3000, 20)' 's = FastCan(100, eta=True, verbose=0).fit(X, y)'"
87+
time-h = "python -m timeit -n 5 -s 'import numpy as np; from fastcan import FastCan; X = np.random.rand(30000, 100); y = np.random.rand(30000, 20)' 's = FastCan(100, verbose=0).fit(X, y)'"
88+
time-eta = "python -m timeit -n 5 -s 'import numpy as np; from fastcan import FastCan; X = np.random.rand(30000, 100); y = np.random.rand(30000, 20)' 's = FastCan(100, eta=True, verbose=0).fit(X, y)'"
8989
profile-minibatch = { cmd = '''python -c "import cProfile; import numpy as np; from fastcan import minibatch; X = np.random.rand(100, 3000); y = np.random.rand(100, 20); cProfile.run('minibatch(X, y, 1000, 10, verbose=0)', sort='{{ SORT }}')"''', args = [{ arg = "SORT", default = "cumtime" }] }
9090
time-narx = '''python -m timeit -n 1 -s "import numpy as np; from fastcan.narx import make_narx; rng = np.random.default_rng(5); X = rng.random((1000, 10)); y = rng.random((1000, 2)); m = make_narx(X, y, 10, max_delay=2, poly_degree=2, verbose=0)" "m.fit(X, y, coef_init='one_step_ahead', verbose=1)"'''
9191
profile-narx = { cmd = '''python -c "import cProfile; import numpy as np; from fastcan.narx import make_narx; rng = np.random.default_rng(8); X = rng.random((3000, 3)); y = rng.random((3000, 3)); m = make_narx(X, y, 10, max_delay=10, poly_degree=2, verbose=0); cProfile.run('m.fit(X, y, coef_init=[0]*33)', sort='{{ SORT }}')"''', args = [{ arg = "SORT", default = "tottime" }] }

tests/test_minibatch.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import numpy as np
44
import pytest
55
from sklearn.cluster import KMeans
6-
from sklearn.datasets import load_iris, make_classification
6+
from sklearn.datasets import load_iris, make_classification, make_regression
77
from sklearn.preprocessing import OneHotEncoder
88

99
from fastcan import minibatch
10+
from fastcan.utils import ssc
1011

1112

1213
def test_data_pruning():
@@ -60,7 +61,7 @@ def test_select_minibatch_cls():
6061
assert indices.size == n_to_select
6162

6263

63-
def test_minibatch_error():
64+
def test_minibatch_error_warning():
6465
# Test that minibatch raises the expected errors and warnings.
6566
n_samples = 200
6667
n_features = 20
@@ -83,3 +84,37 @@ def test_minibatch_error():
8384

8485
with pytest.raises(ValueError, match=r"n_features_to_select .*"):
8586
_ = minibatch(X, y, n_features + 1, batch_size=3)
87+
88+
Y = OneHotEncoder(sparse_output=False).fit_transform(y.reshape(-1, 1))
89+
Y[:, 0] = 1
90+
with pytest.warns(
91+
UserWarning, match=r"Contain constant targets, whose indices are .*"
92+
):
93+
_ = minibatch(X, Y, 5, batch_size=3)
94+
95+
96+
def test_minibatch_ssc_aligned(capsys):
97+
# Test whether ssc of minibatch aligns with the true ssc score
98+
n_features = 20
99+
n_targets = 5
100+
n_to_select = 10
101+
X, y = make_regression(
102+
n_samples=100,
103+
n_features=n_features,
104+
n_informative=10,
105+
n_targets=n_targets,
106+
noise=0.1,
107+
random_state=0,
108+
)
109+
110+
# The last batch of features are selected for the last target.
111+
# The number of features selected per target is n_to_select // n_targets
112+
n_features_per_target = n_to_select // n_targets
113+
indices = minibatch(X, y, n_to_select, batch_size=n_features_per_target + 1)
114+
captured = capsys.readouterr()
115+
116+
gtruth_ssc = ssc(X[:, indices[-n_features_per_target:]], y[:, [-1]])
117+
assert (
118+
f"Progress: {n_to_select}/{n_to_select}, "
119+
f"Batch SSC: {gtruth_ssc:.5f}" in captured.out
120+
)

0 commit comments

Comments
 (0)