Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastcan/_fastcan.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class FastCan(SelectorMixin, BaseEstimator):
def __init__(
self,
n_features_to_select=1,
*,
indices_include=None,
indices_exclude=None,
eta=False,
Expand Down
74 changes: 36 additions & 38 deletions fastcan/_minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sklearn.utils.validation import check_X_y

from ._cancorr_fast import _forward_search # type: ignore
from ._fastcan import FastCan, _prepare_search
from ._fastcan import _prepare_search


@validate_params(
Expand All @@ -26,17 +26,21 @@
],
"verbose": ["verbose"],
},
prefer_skip_nested_validation=False,
prefer_skip_nested_validation=True,
)
def minibatch(X, y, n_features_to_select=1, batch_size=1, verbose=1):
"""Feature selection using :class:`fastcan.FastCan` with mini batches.

It is suitable for selecting a very large number of features
even larger than the number of samples.

Similar to the correlation filter which selects each feature without considering
the redundancy, the function selects features in mini-batch and the
redundancy between the two mini-batches will be ignored.
The function splits `n_features_to_select` into `n_outputs` parts and selects
features for each part separately, ignoring the redundancy among outputs.
In each part, the function selects features batch-by-batch. The batch size is less
than or equal to `batch_size`.
Like correlation filters, which select features one-by-one without considering
the redundancy between two features, the function ignores the redundancy between
two mini-batches.

Parameters
----------
Expand Down Expand Up @@ -70,7 +74,7 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, verbose=1):
>>> print(f"Indices: {indices}")
Indices: [0 1 2]
"""
X, y = check_X_y(X, y, ensure_2d=True, multi_output=True)
X, y = check_X_y(X, y, ensure_2d=True, multi_output=True, order="F")
if y.ndim == 1:
y = y.reshape(-1, 1)

Expand All @@ -90,41 +94,35 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, verbose=1):
0, n_features_to_select, num=n_outputs + 1, endpoint=True, dtype=int
)
)
X_transformed_ = X - X.mean(0)
y_transformed_ = y - y.mean(0)
indices_include = np.zeros(0, dtype=int) # just an empty array
indices_select = np.zeros(0, dtype=int)

for i in range(n_outputs):
y_i = y[:, i]
batch_split_i = np.diff(
np.r_[
np.arange(n_to_select_split[i], step=batch_size, dtype=int),
n_to_select_split[i],
]
)
for j, batch_size_j in enumerate(batch_split_i):
if j == 0:
selector_j = FastCan(
batch_size_j, indices_exclude=indices_select, verbose=0
).fit(X, y_i)
X_transformed_ = deepcopy(selector_j.X_transformed_)
indices = selector_j.indices_
else:
indices, scores, mask = _prepare_search(
n_features,
batch_size_j,
selector_j.indices_include_,
np.r_[selector_j.indices_exclude_, indices_select],
)
_forward_search(
X=X_transformed_,
V=selector_j.y_transformed_,
t=batch_size_j,
tol=selector_j.tol,
num_threads=n_threads,
verbose=0,
mask=mask,
indices=indices,
scores=scores,
)
y_i = y_transformed_[:, [i]]
n_selected_i = 0
while n_to_select_split[i] > n_selected_i:
batch_size_temp = min(batch_size, n_to_select_split[i] - n_selected_i)
indices, scores, mask = _prepare_search(
n_features,
batch_size_temp,
indices_include,
indices_select,
)
_forward_search(
X=deepcopy(X_transformed_),
V=y_i,
t=batch_size_temp,
tol=0.01,
num_threads=n_threads,
verbose=0,
mask=mask,
indices=indices,
scores=scores,
)
indices_select = np.r_[indices_select, indices]
n_selected_i += batch_size_temp
if verbose == 1:
print(
f"Progress: {indices_select.size}/{n_features_to_select}", end="\r"
Expand Down
2 changes: 1 addition & 1 deletion meson.build
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
project(
'fastcan',
'c', 'cython',
version: '0.3.1',
version: '0.3.2',
license: 'MIT',
meson_version: '>= 1.1.0',
default_options: [
Expand Down
Loading
Loading