Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:

steps:
- uses: actions/checkout@v4
- uses: prefix-dev/[email protected].2
- uses: prefix-dev/[email protected].3
with:
environments: default
cache: true
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _fastcan_pruning(
).fit(X)
atoms = kmeans.cluster_centers_
ids_fastcan = minibatch(
X.T, atoms.T, n_samples_to_select, batch_size=batch_size, verbose=0
X.T, atoms.T, n_samples_to_select, batch_size=batch_size, tol=1e-6, verbose=0
)
pruned_lr = LogisticRegression(max_iter=110).fit(X[ids_fastcan], y[ids_fastcan])
return pruned_lr.coef_, pruned_lr.intercept_
Expand Down
17 changes: 13 additions & 4 deletions examples/plot_redundancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,11 @@ def get_n_missed(dep_info_ids, indep_info_ids, redundant_ids, selected_ids):
# Run test
# --------

import time

N_SELECTORS = len(selector_dict)
n_missed = np.zeros((N_REPEATED, N_SELECTORS), dtype=int)
elapsed_time = np.zeros((N_REPEATED, N_SELECTORS), dtype=float)

for i in range(N_REPEATED):
data, target = make_redundant(
Expand All @@ -186,7 +189,9 @@ def get_n_missed(dep_info_ids, indep_info_ids, redundant_ids, selected_ids):
random_seed=i,
)
for j, selector in enumerate(selector_dict.values()):
start_time = time.time()
result_ids = selector.fit(data, target).get_support(indices=True)
elapsed_time[i, j] = time.time() - start_time
n_missed[i, j] = get_n_missed(
dep_info_ids=DEP_INFO_IDS,
indep_info_ids=INDEP_INFO_IDS,
Expand All @@ -202,10 +207,14 @@ def get_n_missed(dep_info_ids, indep_info_ids, redundant_ids, selected_ids):

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 5))
rects = ax.bar(selector_dict.keys(), n_missed.sum(0), width=0.5)
ax.bar_label(rects, n_missed.sum(0), padding=3)
fig = plt.figure(figsize=(8, 5))
ax1 = fig.add_subplot()
ax2 = ax1.twinx()
ax1.set_ylabel("No. of missed features")
ax2.set_ylabel("Elapsed time (s)")
rects = ax1.bar(selector_dict.keys(), n_missed.sum(0), width=0.5)
ax1.bar_label(rects, n_missed.sum(0), padding=3)
ax2.semilogy(selector_dict.keys(), elapsed_time.mean(0), marker="o", color="tab:orange")
plt.xlabel("Selector")
plt.ylabel("No. of missed features")
plt.title("Performance of selectors on datasets with linearly redundant features")
plt.show()
10 changes: 7 additions & 3 deletions fastcan/_minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# SPDX-License-Identifier: MIT

from copy import deepcopy
from numbers import Integral
from numbers import Integral, Real

import numpy as np
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
Expand All @@ -27,11 +27,12 @@
"batch_size": [
Interval(Integral, 1, None, closed="left"),
],
"tol": [Interval(Real, 0, None, closed="neither")],
"verbose": ["verbose"],
},
prefer_skip_nested_validation=True,
)
def minibatch(X, y, n_features_to_select=1, batch_size=1, verbose=1):
def minibatch(X, y, n_features_to_select=1, batch_size=1, tol=0.01, verbose=1):
"""Feature selection using :class:`fastcan.FastCan` with mini batches.

It is suitable for selecting a very large number of features
Expand Down Expand Up @@ -60,6 +61,9 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, verbose=1):
The upper bound of the number of features in a mini-batch.
It is recommended that batch_size be less than n_samples.

tol : float, default=0.01
Tolerance for linear dependence check.

verbose : int, default=1
The verbosity level.

Expand Down Expand Up @@ -118,7 +122,7 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, verbose=1):
X=deepcopy(X_transformed_),
V=y_i,
t=batch_size_temp,
tol=0.01,
tol=tol,
num_threads=n_threads,
verbose=0,
mask=mask,
Expand Down
Loading
Loading