Skip to content

Commit 80f08cd

Browse files
Merge pull request #43 from MatthewSZhang/beam
FEAT add tol to minibatch
2 parents 48df491 + 2a495e8 commit 80f08cd

File tree

5 files changed

+366
-440
lines changed

5 files changed

+366
-440
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313

1414
steps:
1515
- uses: actions/checkout@v4
16-
- uses: prefix-dev/setup-pixi@v0.8.2
16+
- uses: prefix-dev/setup-pixi@v0.8.3
1717
with:
1818
environments: default
1919
cache: true

examples/plot_pruning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _fastcan_pruning(
7272
).fit(X)
7373
atoms = kmeans.cluster_centers_
7474
ids_fastcan = minibatch(
75-
X.T, atoms.T, n_samples_to_select, batch_size=batch_size, verbose=0
75+
X.T, atoms.T, n_samples_to_select, batch_size=batch_size, tol=1e-6, verbose=0
7676
)
7777
pruned_lr = LogisticRegression(max_iter=110).fit(X[ids_fastcan], y[ids_fastcan])
7878
return pruned_lr.coef_, pruned_lr.intercept_

examples/plot_redundancy.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,11 @@ def get_n_missed(dep_info_ids, indep_info_ids, redundant_ids, selected_ids):
173173
# Run test
174174
# --------
175175

176+
import time
177+
176178
N_SELECTORS = len(selector_dict)
177179
n_missed = np.zeros((N_REPEATED, N_SELECTORS), dtype=int)
180+
elapsed_time = np.zeros((N_REPEATED, N_SELECTORS), dtype=float)
178181

179182
for i in range(N_REPEATED):
180183
data, target = make_redundant(
@@ -186,7 +189,9 @@ def get_n_missed(dep_info_ids, indep_info_ids, redundant_ids, selected_ids):
186189
random_seed=i,
187190
)
188191
for j, selector in enumerate(selector_dict.values()):
192+
start_time = time.time()
189193
result_ids = selector.fit(data, target).get_support(indices=True)
194+
elapsed_time[i, j] = time.time() - start_time
190195
n_missed[i, j] = get_n_missed(
191196
dep_info_ids=DEP_INFO_IDS,
192197
indep_info_ids=INDEP_INFO_IDS,
@@ -202,10 +207,14 @@ def get_n_missed(dep_info_ids, indep_info_ids, redundant_ids, selected_ids):
202207

203208
import matplotlib.pyplot as plt
204209

205-
fig, ax = plt.subplots(figsize=(8, 5))
206-
rects = ax.bar(selector_dict.keys(), n_missed.sum(0), width=0.5)
207-
ax.bar_label(rects, n_missed.sum(0), padding=3)
210+
fig = plt.figure(figsize=(8, 5))
211+
ax1 = fig.add_subplot()
212+
ax2 = ax1.twinx()
213+
ax1.set_ylabel("No. of missed features")
214+
ax2.set_ylabel("Elapsed time (s)")
215+
rects = ax1.bar(selector_dict.keys(), n_missed.sum(0), width=0.5)
216+
ax1.bar_label(rects, n_missed.sum(0), padding=3)
217+
ax2.semilogy(selector_dict.keys(), elapsed_time.mean(0), marker="o", color="tab:orange")
208218
plt.xlabel("Selector")
209-
plt.ylabel("No. of missed features")
210219
plt.title("Performance of selectors on datasets with linearly redundant features")
211220
plt.show()

fastcan/_minibatch.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# SPDX-License-Identifier: MIT
77

88
from copy import deepcopy
9-
from numbers import Integral
9+
from numbers import Integral, Real
1010

1111
import numpy as np
1212
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
@@ -27,11 +27,12 @@
2727
"batch_size": [
2828
Interval(Integral, 1, None, closed="left"),
2929
],
30+
"tol": [Interval(Real, 0, None, closed="neither")],
3031
"verbose": ["verbose"],
3132
},
3233
prefer_skip_nested_validation=True,
3334
)
34-
def minibatch(X, y, n_features_to_select=1, batch_size=1, verbose=1):
35+
def minibatch(X, y, n_features_to_select=1, batch_size=1, tol=0.01, verbose=1):
3536
"""Feature selection using :class:`fastcan.FastCan` with mini batches.
3637
3738
It is suitable for selecting a very large number of features
@@ -60,6 +61,9 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, verbose=1):
6061
The upper bound of the number of features in a mini-batch.
6162
It is recommended that batch_size be less than n_samples.
6263
64+
tol : float, default=0.01
65+
Tolerance for linear dependence check.
66+
6367
verbose : int, default=1
6468
The verbosity level.
6569
@@ -118,7 +122,7 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, verbose=1):
118122
X=deepcopy(X_transformed_),
119123
V=y_i,
120124
t=batch_size_temp,
121-
tol=0.01,
125+
tol=tol,
122126
num_threads=n_threads,
123127
verbose=0,
124128
mask=mask,

0 commit comments

Comments
 (0)