Skip to content

Commit 085a225

Browse files
FEAT add beam search (#187)
1 parent e47b6fa commit 085a225

File tree

11 files changed

+285
-41
lines changed

11 files changed

+285
-41
lines changed

.github/workflows/asv.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242
pixi run asv-build ${{ matrix.os }}
4343
4444
- name: Upload benchmark results
45-
uses: actions/upload-artifact@v4
45+
uses: actions/upload-artifact@v5
4646
with:
4747
name: asv-results-${{ matrix.os }}
4848
path: asv_benchmarks/results
@@ -82,7 +82,7 @@ jobs:
8282
cp -r gh-pages/results/* asv_benchmarks/results/ 2>/dev/null || true
8383
8484
- name: Download all benchmark results
85-
uses: actions/download-artifact@v5
85+
uses: actions/download-artifact@v6
8686
with:
8787
pattern: asv-results-*
8888

.github/workflows/emscripten.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
env:
1414
CIBW_PLATFORM: pyodide
1515
- name: Upload package
16-
uses: actions/upload-artifact@v4
16+
uses: actions/upload-artifact@v5
1717
with:
1818
name: wasm_wheel
1919
path: ./wheelhouse/*_wasm32.whl

.github/workflows/publish-pypi.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
id-token: write
1919
steps:
2020
- name: Download artifacts
21-
uses: actions/download-artifact@v5
21+
uses: actions/download-artifact@v6
2222
with:
2323
path: dist/
2424
merge-multiple: true
@@ -29,7 +29,7 @@ jobs:
2929
uses: pypa/gh-action-pypi-publish@release/v1
3030

3131
- name: get wasm dist artifacts
32-
uses: actions/download-artifact@v5
32+
uses: actions/download-artifact@v6
3333
with:
3434
name: wasm_wheel
3535
path: wasm/

.github/workflows/wheel.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
run: |
2020
pixi run build-sdist
2121
- name: Store artifacts
22-
uses: actions/upload-artifact@v4
22+
uses: actions/upload-artifact@v5
2323
with:
2424
name: cibw-sdist
2525
path: dist/*.tar.gz
@@ -43,7 +43,7 @@ jobs:
4343
# Include free-threaded support
4444
CIBW_ENABLE: cpython-freethreading
4545
- name: Upload package
46-
uses: actions/upload-artifact@v4
46+
uses: actions/upload-artifact@v5
4747
with:
4848
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
4949
path: ./wheelhouse/*.whl

fastcan/_beam.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
"""
2+
Beam search.
3+
"""
4+
5+
# Authors: The fastcan developers
6+
# SPDX-License-Identifier: MIT
7+
8+
import numpy as np
9+
from scipy.linalg import orth
10+
11+
12+
def _beam_search(
    X, V, n_features_to_select, beam_width, indices_include, mask_exclude, tol, verbose
):
    """
    Perform beam search to find the best subset of features.

    Parameters:
    X : np.ndarray
        The transformed input feature matrix.
    V : np.ndarray
        The transformed target variable.
    n_features_to_select : int
        The total number of features to select.
    beam_width : int
        The number of top candidates to keep at each step.
    indices_include : list
        The indices of features that must be included in the selection.
    mask_exclude : np.ndarray, dtype=bool
        A boolean mask indicating which features to exclude.
    tol : float
        Tolerance for numerical stability in Gram-Schmidt process.
    verbose : bool
        If True, print progress information.

    Returns:
    indices : np.ndarray, dtype=np.int32
        The indices of the selected features. When no feature remains to be
        searched (`n_features_to_select` does not exceed the number of
        inclusions), the inclusions themselves are returned.
    """
    # Beams are grown with list concatenation (`ids + [new_id]`); taking a list
    # copy guards against an ndarray turning that into element-wise addition.
    indices_include = list(indices_include)
    n_inclusions = len(indices_include)

    if n_features_to_select <= n_inclusions:
        # Nothing to search: without this guard the loop below never runs and
        # the final argmax would hit unbound `top_k_scores`.
        return np.array(indices_include, dtype=np.int32, order="F")

    n_features = X.shape[1]

    X, _ = _safe_normalize(X)

    for i in range(n_features_to_select - n_inclusions):
        if i == 0:
            # First step: every beam starts from the same forced inclusions,
            # so a single scoring pass seeds all `beam_width` beams.
            X_support, X_selected = _prepare_candidates(
                X, mask_exclude, indices_include
            )
            beams_selected_ids = [indices_include for _ in range(beam_width)]
            W_selected = orth(X_selected)
            selected_score = np.sum((W_selected.T @ V) ** 2)
            beams_scores = _gram_schmidt(
                X, X_support, X_selected, selected_score, V, tol
            )
            beams_selected_ids, top_k_scores = _select_top_k(
                beams_scores[None, :],
                beams_selected_ids,
                beam_width,
            )
            continue
        # Later steps: score every remaining candidate separately for each beam.
        beams_scores = np.zeros((beam_width, n_features))
        for beam_idx in range(beam_width):
            X_support, X_selected = _prepare_candidates(
                X, mask_exclude, beams_selected_ids[beam_idx]
            )
            beams_scores[beam_idx] = _gram_schmidt(
                X, X_support, X_selected, top_k_scores[beam_idx], V, tol
            )
        beams_selected_ids, top_k_scores = _select_top_k(
            beams_scores,
            beams_selected_ids,
            beam_width,
        )
        if verbose:
            print(
                f"Beam Search: {i + 1 + n_inclusions}/{n_features_to_select}, "
                f"Best Beam: {np.argmax(top_k_scores)}, "
                f"Beam SSC: {top_k_scores.max():.5f}",
                end="\r",
            )
    if verbose:
        # Terminate the carriage-return progress line.
        print()
    best_beam = np.argmax(top_k_scores)
    indices = beams_selected_ids[best_beam]
    return np.array(indices, dtype=np.int32, order="F")
88+
89+
90+
def _prepare_candidates(X, mask_exclude, indices_selected):
    """
    Build the candidate mask and gather the already-selected columns.

    Returns a fresh boolean array that is True for every feature that is neither
    excluded nor already selected, together with `X[:, indices_selected]`.
    `mask_exclude` itself is left unmodified.
    """
    candidate_mask = np.copy(~mask_exclude)
    # Already-selected features can no longer be candidates.
    candidate_mask[indices_selected] = False
    return candidate_mask, X[:, indices_selected]
95+
96+
97+
def _select_top_k(
    beams_scores,
    ids_selected,
    beam_width,
):
    """
    Keep the `beam_width` best (beam, feature) extensions.

    `beams_scores` has one row per beam and one column per feature. Returns the
    extended index lists of the surviving beams and their scores.

    Raises ValueError when fewer than `beam_width` features have a non-zero
    score in at least one beam.
    """
    # To explore wider: within one selection step each feature can be picked
    # only once; it is credited to the beam in which it scores highest.
    # To explore deeper: different beams may still pick the same feature at
    # different selection steps.
    owner_beam = beams_scores.argmax(axis=0)
    best_scores = beams_scores.max(axis=0)
    n_valid = np.sum(beams_scores.any(axis=0))
    n_selected = len(ids_selected[0])
    if n_valid < beam_width:
        raise ValueError(
            "Beam Search: Not enough valid candidates to select "
            f"beam width number of features, when selecting feature {n_selected + 1}. "
            "Please decrease beam_width or n_features_to_select."
        )

    # Unordered indices of the `beam_width` highest-scoring features.
    winners = np.argpartition(best_scores, -beam_width)[-beam_width:]
    extended_ids = [
        ids_selected[beam] + [feature]
        for feature, beam in zip(winners, owner_beam[winners])
    ]
    return extended_ids, best_scores[winners]
124+
125+
126+
def _gram_schmidt(X, X_support, X_selected, selected_score, V, tol, modified=True):
    """
    Score each candidate column of `X` given the already-selected columns.

    Every supported column is orthogonalised against an orthonormal basis of
    `X_selected`, normalised, and scored by the sum of squared inner products
    with the columns of `V`, plus `selected_score`.

    Parameters:
    X : np.ndarray
        Feature matrix; it is copied, so the caller's array is not modified.
    X_support : np.ndarray, dtype=bool
        Candidate mask. Updated in place: candidates that vanish after
        projection, or that remain correlated with the basis beyond `tol`,
        are switched off.
    X_selected : np.ndarray
        Columns already selected.
    selected_score : float
        Score of the already-selected columns, added to every candidate score.
    V : np.ndarray
        The transformed target variable.
    tol : float
        Largest tolerated absolute inner product between a normalised
        candidate and any basis vector.
    modified : bool, default=True
        If True, subtract the projections one basis vector at a time (modified
        Gram-Schmidt); otherwise subtract them all at once (classical).

    Returns:
    scores : np.ndarray
        One score per column of `X`; zero for unsupported columns.
    """
    # Column views below are updated in place, so work on a private copy.
    X = np.copy(X)
    # The basis is needed by both variants; building it only when `modified`
    # was set left it undefined for `modified=False`.
    W_selected = orth(X_selected)
    scores = np.zeros(X.shape[1])
    for i, support in enumerate(X_support):
        if not support:
            continue
        xi = X[:, i : i + 1]
        if modified:
            for j in range(W_selected.shape[1]):
                wj = W_selected[:, j : j + 1]
                xi -= (wj.T @ xi) * wj
        else:
            xi -= W_selected @ (W_selected.T @ xi)
        norm = np.linalg.norm(xi)
        if norm == 0:
            # Nothing is left of the candidate after projection.
            X_support[i] = False
            continue
        wi = xi / norm
        if np.any(np.abs(wi.T @ W_selected) > tol):
            # Residual is not orthogonal enough to the basis to be trusted.
            X_support[i] = False
            continue
        scores[i] = np.sum((wi.T @ V) ** 2)
    scores += selected_score
    scores[~X_support] = 0
    return scores
149+
150+
151+
def _safe_normalize(X):
    """
    Scale the columns of `X` to unit Euclidean norm.

    Columns whose norm is exactly zero are divided by 1, so they stay zero.
    Returns the scaled array and a boolean array that is True for the columns
    with a non-zero norm.
    """
    col_norms = np.linalg.norm(X, axis=0)
    is_nonzero = col_norms != 0
    # Substitute a unit divisor for zero-norm columns to avoid 0/0.
    divisors = np.where(is_nonzero, col_norms, 1)
    return X / divisors, is_nonzero

fastcan/_cancorr_fast.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ cdef void _mgsvv(
125125

126126

127127
@final
128-
cpdef int _forward_search(
128+
cpdef int _greedy_search(
129129
floating[::1, :] X, # IN/OUT
130130
floating[::1, :] V, # IN
131131
int t, # IN

fastcan/_fastcan.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# Authors: The fastcan developers
66
# SPDX-License-Identifier: MIT
77

8-
from copy import deepcopy
98
from numbers import Integral, Real
109

1110
import numpy as np
@@ -17,7 +16,8 @@
1716
from sklearn.utils._param_validation import Interval
1817
from sklearn.utils.validation import check_is_fitted, validate_data
1918

20-
from ._cancorr_fast import _forward_search # type: ignore[attr-defined]
19+
from ._beam import _beam_search
20+
from ._cancorr_fast import _greedy_search # type: ignore[attr-defined]
2121

2222

2323
class FastCan(SelectorMixin, BaseEstimator):
@@ -46,6 +46,13 @@ class FastCan(SelectorMixin, BaseEstimator):
4646
the feature `x` is linear dependent to the selected features,
4747
and `mask` for that feature will be True.
4848
49+
beam_width : int, default=1
50+
The beam width for beam search.
51+
When `beam_width` = 1, use greedy search.
52+
When `beam_width` > 1, use beam search.
53+
54+
.. versionadded:: 0.5
55+
4956
verbose : int, default=1
5057
The verbosity level.
5158
@@ -114,6 +121,9 @@ class FastCan(SelectorMixin, BaseEstimator):
114121
"indices_exclude": [None, "array-like"],
115122
"eta": ["boolean"],
116123
"tol": [Interval(Real, 0, None, closed="neither")],
124+
"beam_width": [
125+
Interval(Integral, 1, None, closed="left"),
126+
],
117127
"verbose": ["verbose"],
118128
}
119129

@@ -125,13 +135,15 @@ def __init__(
125135
indices_exclude=None,
126136
eta=False,
127137
tol=0.01,
138+
beam_width=1,
128139
verbose=1,
129140
):
130141
self.n_features_to_select = n_features_to_select
131142
self.indices_include = indices_include
132143
self.indices_exclude = indices_exclude
133144
self.eta = eta
134145
self.tol = tol
146+
self.beam_width = beam_width
135147
self.verbose = verbose
136148

137149
def fit(self, X, y):
@@ -204,15 +216,16 @@ def fit(self, X, y):
204216
"`indices_include` and `indices_exclude` should not have intersection."
205217
)
206218

207-
n_candidates = (
208-
n_features - self.indices_exclude_.size - self.n_features_to_select
209-
)
210-
if n_candidates < 0:
219+
if (
220+
n_features - self.indices_exclude_.size
221+
< self.n_features_to_select + self.beam_width - 1
222+
):
211223
raise ValueError(
212-
"n_features - n_features_to_select - n_exclusions should >= 0."
224+
"n_features - n_exclusions should >= "
225+
"n_features_to_select + beam_width - 1."
213226
)
214-
if self.n_features_to_select - self.indices_include_.size < 0:
215-
raise ValueError("n_features_to_select - n_inclusions should >= 0.")
227+
if self.n_features_to_select < self.indices_include_.size:
228+
raise ValueError("n_features_to_select should >= n_inclusions.")
216229

217230
if self.eta:
218231
xy_hstack = np.hstack((X, y))
@@ -235,9 +248,28 @@ def fit(self, X, y):
235248
self.indices_exclude_,
236249
)
237250

251+
if self.beam_width > 1:
252+
indices = _beam_search(
253+
X=self.X_transformed_.copy(order="F"),
254+
V=self.y_transformed_,
255+
n_features_to_select=self.n_features_to_select,
256+
beam_width=self.beam_width,
257+
indices_include=list(self.indices_include_.copy()),
258+
mask_exclude=mask.astype(bool, copy=True),
259+
tol=self.tol,
260+
verbose=self.verbose,
261+
)
262+
263+
indices, scores, mask = _prepare_search(
264+
n_features,
265+
self.n_features_to_select,
266+
indices,
267+
self.indices_exclude_,
268+
)
269+
238270
n_threads = _openmp_effective_n_threads()
239-
_forward_search(
240-
X=deepcopy(self.X_transformed_),
271+
_greedy_search(
272+
X=self.X_transformed_.copy(order="F"),
241273
V=self.y_transformed_,
242274
t=self.n_features_to_select,
243275
tol=self.tol,

fastcan/_minibatch.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
# Authors: The fastcan developers
66
# SPDX-License-Identifier: MIT
77

8-
from copy import deepcopy
98
from numbers import Integral, Real
109

1110
import numpy as np
1211
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
1312
from sklearn.utils._param_validation import Interval, validate_params
1413
from sklearn.utils.validation import check_X_y
1514

16-
from ._cancorr_fast import _forward_search # type: ignore[attr-defined]
15+
from ._cancorr_fast import _greedy_search # type: ignore[attr-defined]
1716
from ._fastcan import _prepare_search
1817

1918

@@ -118,8 +117,8 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, tol=0.01, verbose=1):
118117
indices_select,
119118
)
120119
try:
121-
_forward_search(
122-
X=deepcopy(X_transformed_),
120+
_greedy_search(
121+
X=np.copy(X_transformed_, order="F"),
123122
V=y_i,
124123
t=batch_size_temp,
125124
tol=tol,
@@ -130,7 +129,7 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, tol=0.01, verbose=1):
130129
scores=scores,
131130
)
132131
except RuntimeError:
133-
# If the batch size is too large, _forward_search cannot find enough
132+
# If the batch size is too large, _greedy_search cannot find enough
134133
# samples to form a non-singular matrix. Then, reduce the batch size.
135134
indices = indices[indices != -1]
136135
batch_size_temp = indices.size

0 commit comments

Comments
 (0)