Skip to content

Commit 7537dd0

Browse files
Merge pull request #22 from MatthewSZhang/extend
FEAT add extend by mini-batch
2 parents f72d11d + 57d5c6b commit 7537dd0

File tree

8 files changed

+221
-2
lines changed

8 files changed

+221
-2
lines changed

doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ API Reference
1919

2020
FastCan
2121
refine
22+
extend
2223
ssc
2324
ols
2425

fastcan/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
The :mod:`fastcan` module implements algorithms, including
33
"""
44

5+
from ._extend import extend
56
from ._fastcan import FastCan
67
from ._refine import refine
78
from ._utils import ols, ssc
@@ -11,4 +12,5 @@
1112
"ssc",
1213
"ols",
1314
"refine",
15+
"extend",
1416
]

fastcan/_cancorr_fast.pyx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,10 @@ cpdef int _forward_search(
194194

195195
# Find max scores and update indices, X, mask, and scores
196196
index = _iamax(n_features, &r2[0], 1)
197+
if r2[index] == 0:
198+
raise RuntimeError(
199+
f"No improvement can be found when selecting the {i}th feature."
200+
)
197201
indices[i] = index
198202
scores[i] = r2[index]
199203

fastcan/_extend.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""
2+
Extend feature selection
3+
"""
4+
5+
import math
6+
from copy import deepcopy
7+
from numbers import Integral
8+
9+
import numpy as np
10+
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
11+
from sklearn.utils._param_validation import Interval, validate_params
12+
from sklearn.utils.validation import check_is_fitted
13+
14+
from ._cancorr_fast import _forward_search # type: ignore
15+
from ._fastcan import FastCan, _prepare_search
16+
17+
18+
@validate_params(
    {
        "selector": [FastCan],
        "n_features_to_select": [
            Interval(Integral, 1, None, closed="left"),
        ],
        "batch_size": [
            Interval(Integral, 1, None, closed="left"),
        ],
    },
    prefer_skip_nested_validation=False,
)
def extend(selector, n_features_to_select=1, batch_size=1):
    """Extend FastCan with mini batches.

    It is suitable for selecting a very large number of features
    even larger than the number of samples.

    Similar to the correlation filter which selects each feature without considering
    the redundancy, the function selects features in mini-batch and the
    redundancy between the two mini-batches will be ignored.

    Parameters
    ----------
    selector : FastCan
        Fitted FastCan selector whose selection is to be extended.

    n_features_to_select : int, default=1
        The absolute (total) number of features in the returned selection,
        i.e. it counts the features already selected by `selector`. It must
        therefore be greater than `selector.n_features_to_select` and not
        greater than the number of features seen during fit.

    batch_size : int, default=1
        The number of features in a mini-batch. The included indices of
        `selector` are part of every mini-batch, so `batch_size` must be
        greater than the number of included indices.

    Returns
    -------
    indices : ndarray of shape (n_features_to_select,), dtype=int
        The indices of the selected features.

    Raises
    ------
    ValueError
        If `n_features_to_select` is greater than the number of features, if
        it does not exceed the number of features `selector` has already
        selected, or if `batch_size` does not exceed the number of included
        indices.

    Examples
    --------
    >>> from fastcan import FastCan, extend
    >>> X = [[1, 1, 0], [0.01, 0, 0], [-1, 0, 1], [0, 0, 0]]
    >>> y = [1, 0, -1, 0]
    >>> selector = FastCan(1, verbose=0).fit(X, y)
    >>> print(f"Indices: {selector.indices_}")
    Indices: [0]
    >>> indices = extend(selector, 3, batch_size=2)
    >>> print(f"Indices: {indices}")
    Indices: [0 2 1]
    """
    check_is_fitted(selector)
    n_inclusions = selector.indices_include_.size
    n_features = selector.n_features_in_
    # Number of features that still have to be found on top of the fitted ones.
    n_to_select = n_features_to_select - selector.n_features_to_select
    # Number of genuinely new features that one mini-batch can contribute.
    batch_size_to_select = batch_size - n_inclusions

    if n_features_to_select > n_features:
        raise ValueError(
            f"n_features_to_select {n_features_to_select} "
            f"must be <= n_features {n_features}."
        )
    if n_to_select <= 0:
        # The message parts are concatenated (not comma separated) so that the
        # exception carries one readable string instead of a tuple of args.
        raise ValueError(
            f"The number of features to select ({n_to_select}) "
            "must be greater than 0."
        )
    if batch_size_to_select <= 0:
        raise ValueError(
            "The size of mini batch without included indices "
            f"({batch_size_to_select}) must be greater than 0."
        )

    # Work on a copy so the arrays stored on the fitted selector stay untouched.
    # NOTE(review): the same copy is reused by every mini-batch; presumably
    # _forward_search updates it in place -- confirm this carry-over is intended.
    X_transformed_ = deepcopy(selector.X_transformed_)

    indices_include = selector.indices_include_
    indices_exclude = selector.indices_exclude_
    # Features already chosen by the selector, without the included ones.
    indices_select = selector.indices_[n_inclusions:]

    n_threads = _openmp_effective_n_threads()

    for i in range(math.ceil(n_to_select / batch_size_to_select)):
        if i == 0:
            # The first batch takes the remainder (between 1 and
            # batch_size_to_select new features) so that all later batches
            # are full-sized.
            batch_size_i = (n_to_select - 1) % batch_size_to_select + 1 + n_inclusions
        else:
            batch_size_i = batch_size
        # Features selected so far are excluded from the new search, which
        # keeps the returned indices free of duplicates.
        indices, scores, mask = _prepare_search(
            n_features,
            batch_size_i,
            indices_include,
            np.r_[indices_exclude, indices_select],
        )
        _forward_search(
            X=X_transformed_,
            V=selector.y_transformed_,
            t=batch_size_i,
            tol=selector.tol,
            num_threads=n_threads,
            verbose=0,
            mask=mask,
            indices=indices,
            scores=scores,
        )
        # Drop the leading included indices; they are prepended once on return.
        indices_select = np.r_[indices_select, indices[n_inclusions:]]
    return np.r_[indices_include, indices_select]

fastcan/_fastcan.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ class FastCan(SelectorMixin, BaseEstimator):
7777
When h-correlation method is used, `n_samples_` = n_samples.
7878
When eta-cosine method is used, `n_samples_` = n_features+n_outputs.
7979
80+
indices_include_ : ndarray of shape (n_inclusions,), dtype=int
81+
The indices of the prerequisite features.
82+
83+
indices_exclude_ : array-like of shape (n_exclusions,), dtype=int
84+
The indices of the excluded features.
85+
8086
References
8187
----------
8288
* Zhang, S., & Lang, Z. Q. (2022).

fastcan/_refine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def refine(selector, drop=1, max_iter=None, verbose=1):
9393

9494
n_inclusions = indices_include.size
9595
n_selections = n_features_to_select - n_inclusions
96+
n_threads = _openmp_effective_n_threads()
9697

9798
if drop == "all":
9899
drop = np.arange(1, n_selections)
@@ -126,7 +127,6 @@ def refine(selector, drop=1, max_iter=None, verbose=1):
126127
rolled_indices[:-drop_n],
127128
indices_exclude,
128129
)
129-
n_threads = _openmp_effective_n_threads()
130130
_forward_search(
131131
X=X_transformed_,
132132
V=selector.y_transformed_,

tests/test_extend.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Test feature selection extend"""
2+
import numpy as np
3+
import pytest
4+
from numpy.testing import (
5+
assert_array_equal,
6+
)
7+
from sklearn.datasets import make_classification
8+
9+
from fastcan import FastCan, extend
10+
11+
12+
def test_select_extend_cls():
    # Check that extend grows a fitted selection to the requested size while
    # keeping the original selection as a prefix, for a plain selector, one
    # with an included feature, and one with included and excluded features.
    n_total = 18
    n_initial = 2

    X, y = make_classification(
        n_samples=200,
        n_features=30,
        n_informative=20,
        n_repeated=5,
        n_classes=8,
        n_clusters_per_class=1,
        flip_y=0.0,
        class_sep=10,
        shuffle=False,
        random_state=0,
    )

    scenarios = (
        {},
        {"indices_include": [10]},
        {"indices_include": [10], "indices_exclude": [0]},
    )
    fitted_selectors = []
    extended_indices = []
    for selector_kwargs in scenarios:
        fitted = FastCan(n_initial, **selector_kwargs).fit(X, y)
        fitted_selectors.append(fitted)
        extended_indices.append(extend(fitted, n_total, batch_size=3))

    for fitted, result in zip(fitted_selectors, extended_indices):
        # No duplicates, and the already fitted selection comes first.
        assert np.unique(result).size == n_total
        assert_array_equal(result[:n_initial], fitted.indices_)

    # The excluded feature must never show up in the extended selection.
    assert not np.isin(0, extended_indices[-1])
52+
53+
54+
def test_extend_error():
    # Check that extend rejects invalid requests with a ValueError.
    n_features = 20
    n_initial = 2

    X, y = make_classification(
        n_samples=200,
        n_features=n_features,
        n_informative=10,
        n_repeated=5,
        n_classes=8,
        n_clusters_per_class=1,
        flip_y=0.0,
        class_sep=10,
        shuffle=False,
        random_state=0,
    )

    fitted = FastCan(n_initial, indices_include=[0]).fit(X, y)

    # (expected message pattern, n_features_to_select, batch_size)
    failing_calls = (
        # More features requested than exist.
        (r"n_features_to_select .*", n_features + 1, 3),
        # Nothing left to add on top of the fitted selection.
        (r"The number of features to select .*", n_initial, 3),
        # The batch only has room for the included feature.
        (r"The size of mini batch without .*", n_features, 1),
    )
    for pattern, n_requested, batch in failing_calls:
        with pytest.raises(ValueError, match=pattern):
            _ = extend(fitted, n_requested, batch_size=batch)

tests/test_refine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from fastcan import FastCan, refine
66

77

8-
def test_select_refine_random_cls():
8+
def test_select_refine_cls():
99
# Test whether refine work correctly with random samples.
1010
n_samples = 200
1111
n_features = 20

0 commit comments

Comments
 (0)