Commit 48df491

Merge pull request #40 from MatthewSZhang/batch-doc
DOC minibatch data pruning docs
2 parents 293f6be + cf270c9

24 files changed: +1410 -1183 lines changed

.github/workflows/ci.yml (1 addition & 1 deletion)

@@ -13,7 +13,7 @@ jobs:
 
     steps:
       - uses: actions/checkout@v4
-      - uses: prefix-dev/setup-pixi@v0.8.1
+      - uses: prefix-dev/setup-pixi@v0.8.2
        with:
          environments: default
          cache: true

LICENSE (1 addition & 1 deletion)

@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2024 SIKAI ZHANG
+Copyright (c) 2024-2025 The fastcan developers.
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

doc/conf.py (2 additions & 2 deletions)

@@ -14,7 +14,7 @@
 # sys.path.insert(0, os.path.abspath(".."))
 # General information about the project.
 project = "fastcan"
-copyright = f"{datetime.now().year}, fastcan developers (MIT License)"
+copyright = f"2024 - {datetime.now().year}, fastcan developers (MIT License)"
 
 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
@@ -25,7 +25,7 @@
 release = importlib.metadata.version(project)
 
 # The short X.Y version.
-version = '.'.join(release.split('.')[:2])
+version = ".".join(release.split(".")[:2])
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

doc/narx.rst (1 addition & 1 deletion)

@@ -77,7 +77,7 @@ It should also be noted the different types of predictions in model training.
    and the model can be trained by the simple ordinary least-squares (OLS) method
 #. If assume the NARX model is a multiple-step-ahead prediction structure, the input data, like :math:`\hat{y}(k-1)` is
    unknown in advance. Therefore, the training data must first be generated by the multiple-step-ahead prediction with
-   the initial model coefficients, and then the coefficients can be updated recursively.
+   the initial model coefficients, and then the coefficients can be updated recursively
 
 ARX and OE model
 ----------------
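The doc/narx.rst hunk above distinguishes one-step-ahead training, where measured outputs fill the regressors and ordinary least squares applies directly, from multiple-step-ahead training, where the model's own past predictions fill the regressors, so the training data must be regenerated from the current coefficients before each update. A minimal numpy sketch of the two prediction modes for a toy first-order model y(k) = a*y(k-1) + b*u(k); the model and helper names are illustrative, not part of fastcan's NARX API:

import numpy as np


def one_step_ahead(y_measured, u, a, b):
    # Regressors are filled with measured outputs, so each prediction is a
    # linear function of known data: the coefficients can be fitted by OLS.
    y_hat = np.empty_like(y_measured)
    y_hat[0] = y_measured[0]
    y_hat[1:] = a * y_measured[:-1] + b * u[1:]
    return y_hat


def multi_step_ahead(y0, u, a, b):
    # Regressors are filled with the model's own previous predictions
    # (free-run simulation), so y_hat depends recursively on a and b.
    y_hat = np.empty_like(u)
    y_hat[0] = y0
    for k in range(1, len(u)):
        y_hat[k] = a * y_hat[k - 1] + b * u[k]
    return y_hat


rng = np.random.default_rng(0)
u = rng.uniform(0, 1, 50)
y = multi_step_ahead(0.0, u, 0.5, 1.0) + rng.normal(0, 0.01, 50)
print(one_step_ahead(y, u, 0.5, 1.0)[:3])
print(multi_step_ahead(y[0], u, 0.5, 1.0)[:3])

In the multiple-step-ahead case the prediction error is a recursive function of the coefficients, which is why the passage describes updating them recursively rather than in one OLS pass.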

doc/pruning.rst (48 additions & 0 deletions)

@@ -0,0 +1,48 @@
+.. currentmodule:: fastcan
+
+.. _pruning:
+
+===================================================
+Dictionary learning based unsupervised data pruning
+===================================================
+
+Unlike feature selection, which reduces the size of the dataset column-wise,
+data pruning reduces the size of the dataset row-wise.
+To use :class:`FastCan` for unsupervised data pruning, the target matrix :math:`Y` is
+first obtained with `dictionary learning <https://scikit-learn.org/stable/modules/decomposition.html#dictionary-learning>`_.
+Dictionary learning learns a ``dictionary`` composed of atoms.
+The atoms should be highly representative, so that each sample of the dataset can be
+represented (with some error) by a sparse linear combination of the atoms.
+We use these atoms as the target :math:`Y` and select samples based on their correlation with :math:`Y`.
+
+One challenge in using :class:`FastCan` for data pruning is that the number of samples
+to select is much larger than in feature selection.
+Normally, this number is higher than the number of features, which makes the pruned data matrix singular.
+In other words, :class:`FastCan` will quickly decide that the pruned data are redundant and that no
+additional sample should be selected, as any additional sample can be represented by a linear
+combination of the already selected samples.
+Therefore, the number of samples to select has to be kept small.
+
+To solve this problem, we use :func:`minibatch` to loosen the redundancy check of :class:`FastCan`.
+The original :class:`FastCan` checks the redundancy within :math:`X_s \in \mathbb{R}^{n\times t}`,
+which contains :math:`t` selected samples and :math:`n` features,
+and the redundancy within :math:`Y \in \mathbb{R}^{n\times m}`, which contains :math:`m` atoms :math:`y_i`.
+:func:`minibatch` ranks samples with the multiple correlation coefficient between :math:`X_b \in \mathbb{R}^{n\times b}` and :math:`y_i`,
+where :math:`b` is the batch size and :math:`b \leq t`, instead of the canonical correlation coefficients
+between :math:`X_s` and :math:`Y`, which are used in :class:`FastCan`.
+Therefore, :func:`minibatch` loosens the redundancy check in two ways:
+
+#. it uses :math:`y_i` instead of :math:`Y`, so no redundancy check is performed within :math:`Y`;
+#. it uses :math:`X_b` instead of :math:`X_s`, so :func:`minibatch` only checks the redundancy within
+   a batch :math:`X_b`, but does not check the redundancy between batches.
+
+
+.. rubric:: References
+
+* `"Dictionary-learning-based data pruning for system identification"
+  <https://doi.org/10.48550/arXiv.2502.11484>`_
+  Wang, T., Zhang, S., & Sun, L.
+  arXiv (2025).
+
+
+.. rubric:: Examples
+
+* See :ref:`sphx_glr_auto_examples_plot_pruning.py` for an example of dictionary learning based data pruning.
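The new page defers runnable code to the referenced plot_pruning.py example. As a rough sketch of the workflow it describes, pairing scikit-learn's DictionaryLearning with fastcan's minibatch; the minibatch argument order, the batch size, and the samples-as-columns transposition are my reading of the prose above and should be checked against the fastcan API reference:

from sklearn.datasets import load_diabetes
from sklearn.decomposition import DictionaryLearning

from fastcan import minibatch

X, _ = load_diabetes(return_X_y=True)  # (442 samples, 10 features)

# Learn a small dictionary whose atoms sparsely represent the samples.
dict_learner = DictionaryLearning(
    n_components=5, transform_algorithm="lasso_lars", random_state=0
)
atoms = dict_learner.fit(X).components_  # (5 atoms, 10 features)

# Prune row-wise: transpose so samples become columns (the X_s convention
# above); the transposed atoms play the role of the target Y.
ids = minibatch(X.T, atoms.T, 100, batch_size=10)  # keep 100 of 442 samples
X_pruned = X[ids]
print(X_pruned.shape)

With batch_size=10, redundancy is checked only within each 10-sample batch, which is what lets the selection run far past the point (here, 10 samples) at which the full canonical-correlation check would declare the pruned matrix singular.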

doc/user_guide.rst (1 addition & 0 deletions)

@@ -13,3 +13,4 @@ User Guide
    redundancy.rst
    ols_and_omp.rst
    narx.rst
+   pruning.rst

examples/plot_affinity.py (2 additions & 4 deletions)

@@ -9,7 +9,7 @@
 selection methods on affine transformed features.
 """
 
-# Authors: Sikai Zhang
+# Authors: The fastcan developers
 # SPDX-License-Identifier: MIT
 
 # %%
@@ -41,7 +41,6 @@
 print("FastCan: ", np.sort(ids_fastcan))
 
 
-
 # %%
 # Affine transformation
 # ---------------------
@@ -50,7 +49,6 @@
 # three features from the polluted features. The more stable the result, the better.
 
 
-
 n_features = X.shape[1]
 rng = np.random.default_rng()
 
@@ -75,7 +73,7 @@
 
 import matplotlib.pyplot as plt
 
-bin_lims = np.arange(n_features+1)
+bin_lims = np.arange(n_features + 1)
 counts_omp, _ = np.histogram(ids_omp_all, bins=bin_lims)
 counts_ols, _ = np.histogram(ids_ols_all, bins=bin_lims)
 counts_fastcan, _ = np.histogram(ids_fastcan_all, bins=bin_lims)

examples/plot_fisher.py (3 additions & 3 deletions)

@@ -10,7 +10,7 @@
 relationship with Fisher's criterion in LDA (Linear Discriminant Analysis).
 """
 
-# Authors: Sikai Zhang
+# Authors: The fastcan developers
 # SPDX-License-Identifier: MIT
 
 # %%
@@ -49,11 +49,11 @@
 fishers_criterion, _ = linalg.eigh(Sb, Sw)
 
 fishers_criterion = np.sort(fishers_criterion)[::-1]
-n_nonzero = min(X.shape[1], clf.classes_.shape[0]-1)
+n_nonzero = min(X.shape[1], clf.classes_.shape[0] - 1)
 # remove the eigenvalues which are close to zero
 fishers_criterion = fishers_criterion[:n_nonzero]
 # get canonical correlation coefficients from convert Fisher's criteria
-r2 = fishers_criterion/(1+fishers_criterion)
+r2 = fishers_criterion / (1 + fishers_criterion)
 
 # %%
 # Compute SSC

examples/plot_intuitive.py (18 additions & 19 deletions)

@@ -9,7 +9,7 @@
 in :class:`FastCan`.
 """
 
-# Authors: Sikai Zhang
+# Authors: The fastcan developers
 # SPDX-License-Identifier: MIT
 
 # %%
@@ -28,7 +28,6 @@
 # property, so that the usefullness of each feature can be added together without
 # redundancy.
 
-
 import matplotlib.pyplot as plt
 import numpy as np
 from matplotlib.patches import Patch
@@ -37,8 +36,9 @@
 
 from fastcan import FastCan
 
-plt.rcParams['axes.spines.right'] = False
-plt.rcParams['axes.spines.top'] = False
+plt.rcParams["axes.spines.right"] = False
+plt.rcParams["axes.spines.top"] = False
+
 
 def get_r2(feats, target, feats_selected=None):
     """Get R-squared between [feats_selected, feat_i] and target."""
@@ -54,36 +54,38 @@ def get_r2(feats, target, feats_selected=None):
         r2[i] = lr.fit(feats_i, target).score(feats_i, target)
     return r2
 
+
 def plot_bars(ids, r2_left, r2_selected):
     """Plot the relative R-squared with a bar plot."""
-    legend_selected = Patch(color='tab:green', label='X_selected')
-    legend_cand = Patch(color='tab:blue', label='x_i: candidates')
-    legend_best = Patch(color='tab:orange', label='Best candidate')
+    legend_selected = Patch(color="tab:green", label="X_selected")
+    legend_cand = Patch(color="tab:blue", label="x_i: candidates")
+    legend_best = Patch(color="tab:orange", label="Best candidate")
     n_features = len(ids)
     n_selected = len(r2_selected)
 
-    left = np.zeros(n_features)+sum(r2_selected)
+    left = np.zeros(n_features) + sum(r2_selected)
     left_selected = np.cumsum(r2_selected)
     left_selected = np.r_[0, left_selected]
    left_selected = left_selected[:-1]
     left[:n_selected] = left_selected
 
-    label = [""]*n_features
-    label[np.argmax(r2_left)+n_selected] = f"{max(r2_left):.5f}"
+    label = [""] * n_features
+    label[np.argmax(r2_left) + n_selected] = f"{max(r2_left):.5f}"
 
-    colors = ["tab:blue"]*(n_features - n_selected)
+    colors = ["tab:blue"] * (n_features - n_selected)
     colors[np.argmax(r2_left)] = "tab:orange"
-    colors = ["tab:green"]*n_selected + colors
+    colors = ["tab:green"] * n_selected + colors
 
     hbars = plt.barh(ids, width=np.r_[score_selected, r2_left], color=colors, left=left)
-    plt.axvline(x = sum(r2_selected), color = 'tab:orange', linestyle="--")
+    plt.axvline(x=sum(r2_selected), color="tab:orange", linestyle="--")
     plt.bar_label(hbars, label)
     plt.yticks(np.arange(n_features))
     plt.xlabel("R-squared between [X_selected, x_i] and y")
     plt.ylabel("Feature index")
     plt.legend(handles=[legend_selected, legend_cand, legend_best])
     plt.show()
 
+
 X, y = load_diabetes(return_X_y=True)
 
 
@@ -92,7 +94,6 @@ def plot_bars(ids, r2_left, r2_selected):
 score_selected = []
 
 
-
 score_0 = get_r2(X, y)
 
 plot_bars(id_left, score_0, score_selected)
@@ -114,13 +115,12 @@ def plot_bars(ids, r2_left, r2_selected):
 id_selected += [id_left[index]]
 score_selected += [score_0[index]]
 id_left = np.delete(id_left, index)
-score_1 = get_r2(X[:, id_left], y, X[:, id_selected])-sum(score_selected)
+score_1 = get_r2(X[:, id_left], y, X[:, id_selected]) - sum(score_selected)
 
 
 plot_bars(np.r_[id_selected, id_left], score_1, score_selected)
 
 
-
 # %%
 # Select the third feature
 # ------------------------
@@ -133,12 +133,11 @@ def plot_bars(ids, r2_left, r2_selected):
 id_selected += [id_left[index]]
 score_selected += [score_1[index]]
 id_left = np.delete(id_left, index)
-score_2 = get_r2(X[:, id_left], y, X[:, id_selected])-sum(score_selected)
+score_2 = get_r2(X[:, id_left], y, X[:, id_selected]) - sum(score_selected)
 
 plot_bars(np.r_[id_selected, id_left], score_2, score_selected)
 
 
-
 # %%
 # h-correlation and eta-cosine
 # ----------------------------
@@ -180,7 +179,7 @@ def plot_bars(ids, r2_left, r2_selected):
 score_selected = [score_0[index]]
 id_left = np.arange(X.shape[1])
 id_left = np.delete(id_left, index)
-score_1_7 = get_r2(X[:, id_left], y, X[:, id_selected])-sum(score_selected)
+score_1_7 = get_r2(X[:, id_left], y, X[:, id_selected]) - sum(score_selected)
 
 plot_bars(np.r_[id_selected, id_left], score_1_7, score_selected)

examples/plot_narx.py (21 additions & 16 deletions)

@@ -9,7 +9,7 @@
 NARX model for time series prediction.
 """
 
-# Authors: Sikai Zhang
+# Authors: The fastcan developers
 # SPDX-License-Identifier: MIT
 
 # %%
@@ -26,19 +26,24 @@
 # :math:`u_0` and :math:`u_1` are input signals,
 # and :math:`y` is the output signal.
 
-
 import numpy as np
 
 rng = np.random.default_rng(12345)
 n_samples = 1000
 max_delay = 3
 e = rng.normal(0, 0.1, n_samples)
-u0 = rng.uniform(0, 1, n_samples+max_delay)
-u1 = rng.normal(0, 0.1, n_samples+max_delay)
-y = np.zeros(n_samples+max_delay)
-for i in range(max_delay, n_samples+max_delay):
-    y[i] = 0.5*y[i-1]+0.3*u0[i]**2+2*u0[i-1]*u0[i-3]+1.5*u0[i-2]*u1[i-3]+1
-y = y[max_delay:]+e
+u0 = rng.uniform(0, 1, n_samples + max_delay)
+u1 = rng.normal(0, 0.1, n_samples + max_delay)
+y = np.zeros(n_samples + max_delay)
+for i in range(max_delay, n_samples + max_delay):
+    y[i] = (
+        0.5 * y[i - 1]
+        + 0.3 * u0[i] ** 2
+        + 2 * u0[i - 1] * u0[i - 3]
+        + 1.5 * u0[i - 2] * u1[i - 3]
+        + 1
+    )
+y = y[max_delay:] + e
 X = np.c_[u0[max_delay:], u1[max_delay:]]
 
 # %%
@@ -75,9 +80,9 @@
 from fastcan.narx import make_time_shift_features, make_time_shift_ids
 
 time_shift_ids = make_time_shift_ids(
-    n_features=3, # Number of inputs (2) and output (1) signals
-    max_delay=3, # Maximum time delays
-    include_zero_delay = [True, True, False], # Whether to include zero delay
+    n_features=3,  # Number of inputs (2) and output (1) signals
+    max_delay=3,  # Maximum time delays
+    include_zero_delay=[True, True, False],  # Whether to include zero delay
     # for each signal. The output signal should not have zero delay.
 )
 
@@ -90,8 +95,8 @@
 from fastcan.narx import make_poly_features, make_poly_ids
 
 poly_ids = make_poly_ids(
-    n_features=time_shift_vars.shape[1], # Number of time-shifted variables
-    degree=2, # Maximum polynomial degree
+    n_features=time_shift_vars.shape[1],  # Number of time-shifted variables
+    degree=2,  # Maximum polynomial degree
 )
 
 poly_terms = make_poly_features(time_shift_vars, poly_ids)
@@ -105,7 +110,7 @@
 from fastcan import FastCan
 
 selector = FastCan(
-    n_features_to_select=4, # 4 terms should be selected
+    n_features_to_select=4,  # 4 terms should be selected
 ).fit(poly_terms, y)
 
 support = selector.get_support()
@@ -124,7 +129,7 @@
 
 narx_model = NARX(
     time_shift_ids=time_shift_ids,
-    poly_ids = selected_poly_ids,
+    poly_ids=selected_poly_ids,
 )
 
 narx_model.fit(X, y)
@@ -158,7 +163,7 @@
 
 y_pred = narx_model.predict(
     X[:100],
-    y_init=y[:narx_model.max_delay_]  # Set the initial values of the prediction to
+    y_init=y[: narx_model.max_delay_],  # Set the initial values of the prediction to
     # the true values
 )