Skip to content

Commit 5c2dbe3

Browse files
ENH speed up beam search (#198)
1 parent 509ee9c commit 5c2dbe3

File tree

3 files changed

+25
-33
lines changed

3 files changed

+25
-33
lines changed

fastcan/_beam.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@ def _beam_search(
4545

4646
for i in range(n_features_to_select - n_inclusions):
4747
if i == 0:
48-
X_support, X_selected = _prepare_candidates(
48+
mask, X_selected = _prepare_candidates(
4949
X, mask_exclude, indices_include
5050
)
5151
if X_selected.shape[1] == 0:
5252
beams_scores = np.sum((X.T @ V) ** 2, axis=1)
53-
beams_scores[~X_support] = 0
53+
beams_scores[mask] = 0
5454
else:
5555
W_selected = orth(X_selected)
5656
selected_score = np.sum((W_selected.T @ V) ** 2)
57-
beams_scores = _gram_schmidt(
58-
X, X_support, X_selected, selected_score, V, tol
57+
beams_scores = _mgs_ssc(
58+
X, V, W_selected, mask, selected_score, tol
5959
)
6060
beams_selected_ids = [indices_include for _ in range(beam_width)]
6161
beams_selected_ids, top_k_scores = _select_top_k(
@@ -66,11 +66,12 @@ def _beam_search(
6666
continue
6767
beams_scores = np.zeros((beam_width, n_features))
6868
for beam_idx in range(beam_width):
69-
X_support, X_selected = _prepare_candidates(
69+
mask, X_selected = _prepare_candidates(
7070
X, mask_exclude, beams_selected_ids[beam_idx]
7171
)
72-
beams_scores[beam_idx] = _gram_schmidt(
73-
X, X_support, X_selected, top_k_scores[beam_idx], V, tol
72+
W_selected = orth(X_selected)
73+
beams_scores[beam_idx] = _mgs_ssc(
74+
X, V, W_selected, mask, top_k_scores[beam_idx], tol
7475
)
7576
beams_selected_ids, top_k_scores = _select_top_k(
7677
beams_scores,
@@ -92,10 +93,10 @@ def _beam_search(
9293

9394

9495
def _prepare_candidates(X, mask_exclude, indices_selected):
95-
X_support = np.copy(~mask_exclude)
96-
X_support[indices_selected] = False
96+
mask = np.copy(mask_exclude)
97+
mask[indices_selected] = True
9798
X_selected = X[:, indices_selected]
98-
return X_support, X_selected
99+
return mask, X_selected
99100

100101

101102
def _select_top_k(
@@ -127,31 +128,22 @@ def _select_top_k(
127128
return new_ids_selected, top_k_scores
128129

129130

130-
def _gram_schmidt(X, X_support, X_selected, selected_score, V, tol):
131+
def _mgs_ssc(X, V, W_selected, mask, selected_score, tol):
131132
X = np.copy(X)
132-
W_selected = orth(X_selected) # Change to Modified Gram-Schmidt
133-
scores = np.zeros(X.shape[1])
134-
for i, support in enumerate(X_support):
135-
if not support:
136-
continue
137-
xi = X[:, i : i + 1]
138-
for j in range(W_selected.shape[1]):
139-
proj = W_selected[:, j : j + 1].T @ xi
140-
xi -= proj * W_selected[:, j : j + 1]
141-
wi, X_support[i] = _safe_normalize(xi)
142-
if not X_support[i]:
143-
continue
144-
if np.any(np.abs(wi.T @ W_selected) > tol):
145-
X_support[i] = False
146-
continue
147-
scores[i] = np.sum((wi.T @ V) ** 2)
133+
proj = W_selected.T @ X
134+
X -= W_selected @ proj
135+
W, non_zero_mask = _safe_normalize(X)
136+
mask |= non_zero_mask
137+
linear_independent_mask = np.any(np.abs(W.T @ W_selected) > tol, axis=1)
138+
mask |= linear_independent_mask
139+
scores = np.sum((W.T @ V) ** 2, axis=1)
148140
scores += selected_score
149-
scores[~X_support] = 0
141+
scores[mask] = 0
150142
return scores
151143

152144

153145
def _safe_normalize(X):
154146
norm = np.linalg.norm(X, axis=0)
155-
non_zero_support = norm != 0
156-
norm[~non_zero_support] = 1
157-
return X / norm, non_zero_support
147+
non_zero_mask = norm == 0
148+
norm[non_zero_mask] = 1
149+
return X / norm, non_zero_mask

fastcan/_cancorr_fast.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ cdef int _iamax(
4747

4848

4949
@final
50-
cdef bint _normv(
50+
cdef uint8_t _normv(
5151
const floating* x, # IN/OUT
5252
int n_samples, # IN
5353
) noexcept nogil:

pixi.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ time-eta = "python -m timeit -n 5 -s 'import numpy as np; from fastcan import Fa
8989
profile-minibatch = { cmd = '''python -c "import cProfile; import numpy as np; from fastcan import minibatch; X = np.random.rand(100, 3000); y = np.random.rand(100, 20); cProfile.run('minibatch(X, y, 1000, 10, verbose=0)', sort='{{ SORT }}')"''', args = [{ arg = "SORT", default = "cumtime" }] }
9090
time-narx = '''python -m timeit -n 1 -s "import numpy as np; from fastcan.narx import make_narx; rng = np.random.default_rng(5); X = rng.random((1000, 10)); y = rng.random((1000, 2)); m = make_narx(X, y, 10, max_delay=2, poly_degree=2, verbose=0)" "m.fit(X, y, coef_init='one_step_ahead', verbose=1)"'''
9191
profile-narx = { cmd = '''python -c "import cProfile; import numpy as np; from fastcan.narx import make_narx; rng = np.random.default_rng(8); X = rng.random((3000, 3)); y = rng.random((3000, 3)); m = make_narx(X, y, 10, max_delay=10, poly_degree=2, verbose=0); cProfile.run('m.fit(X, y, coef_init=[0]*33)', sort='{{ SORT }}')"''', args = [{ arg = "SORT", default = "tottime" }] }
92-
time-beam = "python -m timeit -n 5 -s 'import numpy as np; from fastcan import FastCan; X = np.random.rand(3000, 100); y = np.random.rand(3000, 20)' 's = FastCan(20, beam_width=3, verbose=0).fit(X, y)'"
92+
time-beam = "python -m timeit -n 5 -s 'import numpy as np; from fastcan import FastCan; X = np.random.rand(3000, 100); y = np.random.rand(3000, 20)' 's = FastCan(20, beam_width=10, verbose=0).fit(X, y)'"
9393

9494
[feature.asv.tasks]
9595
asv-build = { cmd = "python -m asv machine --machine {{ MACHINE }} --yes && python -m asv run --show-stderr -v --machine {{ MACHINE }} {{ EXTRA_ARGS }}", cwd = "asv_benchmarks", args = [{ arg = "MACHINE", default = "MacOS-M1" }, { arg = "EXTRA_ARGS", default = "" }] }

0 commit comments

Comments
 (0)