
Commit 3fddec5

Merge pull request #86 from MatthewSZhang/prune-pca
DOC add pca plot for data pruning example
2 parents 67a7af8 + 20c3780 commit 3fddec5

File tree: 2 files changed, +167 -117 lines changed


examples/plot_pruning.py

Lines changed: 65 additions & 9 deletions
@@ -18,21 +18,26 @@
 # ------------------------------
 # We use ``iris`` dataset and logistic regression model to demonstrate data pruning.
 # The baseline model is a logistic regression model trained on the entire dataset.
+# Here, 110 samples are used as the training data, which is intentionally made
+# imbalanced, to test data pruning methods.
 # The coefficients of the model trained on the pruned dataset will be compared to
 # the baseline model with R-squared score.
 # The higher R-squared score, the better the pruning.
 
 from sklearn.datasets import load_iris
 from sklearn.linear_model import LogisticRegression
 
-data, labels = load_iris(return_X_y=True)
-baseline_lr = LogisticRegression(max_iter=1000).fit(data, labels)
+iris = load_iris(as_frame=True)
+baseline_lr = LogisticRegression(max_iter=1000).fit(iris["data"], iris["target"])
+X_train = iris["data"].values[10:120]
+y_train = iris["target"].values[10:120]
 
 # %%
 # Random data pruning
 # -------------------
-# There are 150 samples in the dataset. The pruned dataset for
-# random pruning method is selected randomly.
+# There are 110 samples in the training dataset.
+# The random pruning method selects samples assuming a uniform distribution
+# over all data.
 
 import numpy as np
 
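For readers of this diff, the ``Random`` baseline described above amounts to drawing row indices uniformly without replacement and refitting the logistic regression on the selected rows. The helper name _random_pruning and its body below are an illustrative sketch, not part of the committed file:

import numpy as np
from sklearn.linear_model import LogisticRegression


def _random_pruning(X, y, n_samples_to_select: int, random_state: int):
    # Hypothetical helper: uniform random pruning baseline.
    rng = np.random.default_rng(random_state)
    ids_random = rng.choice(X.shape[0], n_samples_to_select, replace=False)
    # Refit the model on the randomly pruned subset and return its coefficients.
    pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids_random], y[ids_random])
    return pruned_lr.coef_, pruned_lr.intercept_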
@@ -78,14 +83,65 @@ def _fastcan_pruning(
     return pruned_lr.coef_, pruned_lr.intercept_
 
 
+# %%
+# Visualize selected samples
+# --------------------------------------------------
+# Use principal component analysis (PCA) to visualize the distribution of the samples,
+# and to compare the difference between the selection of ``Random`` pruning and
+# ``FastCan`` pruning.
+# For clearer viewing of the selection, only 10 samples are selected from the training
+# data by the pruning methods.
+# The results show that ``FastCan`` selects 3 setosa, 4 versicolor,
+# and 3 virginica, while ``Random`` selects 6, 2, and 2, respectively.
+# The imbalanced selection of ``Random`` is caused by the imbalanced training data,
+# while ``FastCan``, benefiting from dictionary learning (k-means), can overcome
+# the imbalance issue.
+
+import matplotlib.pyplot as plt
+from sklearn.decomposition import PCA
+
+
+def plot_pca(X, y, target_names, n_samples_to_select, random_state):
+    pca = PCA(2).fit(X)
+    pcs_all = pca.transform(X)
+
+    kmeans = KMeans(
+        n_clusters=10,
+        random_state=random_state,
+    ).fit(X)
+    atoms = kmeans.cluster_centers_
+    pcs_atoms = pca.transform(atoms)
+
+    ids_fastcan = minibatch(X.T, atoms.T, n_samples_to_select, batch_size=1, verbose=0)
+    pcs_fastcan = pca.transform(X[ids_fastcan])
+
+    rng = np.random.default_rng(random_state)
+    ids_random = rng.choice(X.shape[0], n_samples_to_select, replace=False)
+    pcs_random = pca.transform(X[ids_random])
+
+    plt.scatter(pcs_fastcan[:, 0], pcs_fastcan[:, 1], s=50, marker="o", label="FastCan")
+    plt.scatter(pcs_random[:, 0], pcs_random[:, 1], s=50, marker="*", label="Random")
+    plt.scatter(pcs_atoms[:, 0], pcs_atoms[:, 1], s=100, marker="+", label="Atoms")
+    cmap = plt.get_cmap("Dark2")
+    for i, label in enumerate(target_names):
+        mask = y == i
+        plt.scatter(
+            pcs_all[mask, 0], pcs_all[mask, 1], s=5, label=label, color=cmap(i + 2)
+        )
+    plt.xlabel("The First Principal Component")
+    plt.ylabel("The Second Principal Component")
+    plt.legend(ncol=2)
+
+
+plot_pca(X_train, y_train, iris.target_names, 10, 123)
+
 # %%
 # Compare pruning methods
 # -----------------------
-# 100 samples are selected from 150 original data with ``Random`` pruning and
+# 80 samples are selected from 110 training data with ``Random`` pruning and
 # ``FastCan`` pruning. The results show that ``FastCan`` pruning gives a higher
-# mean value of R-squared and a lower standard deviation.
+# median value of R-squared and a lower standard deviation.
 
-import matplotlib.pyplot as plt
 from sklearn.metrics import r2_score
 
 
@@ -94,7 +150,7 @@ def plot_box(X, y, baseline, n_samples_to_select: int, n_random: int):
     r2_random = np.zeros(n_random)
     for i in range(n_random):
         coef, intercept = _fastcan_pruning(
-            X, y, n_samples_to_select, i, n_atoms=50, batch_size=2
+            X, y, n_samples_to_select, i, n_atoms=40, batch_size=2
         )
         r2_fastcan[i] = r2_score(
             np.c_[coef, intercept], np.c_[baseline.coef_, baseline.intercept_]
@@ -111,4 +167,4 @@ def plot_box(X, y, baseline, n_samples_to_select: int, n_random: int):
     plt.show()
 
 
-plot_box(data, labels, baseline_lr, n_samples_to_select=100, n_random=100)
+plot_box(X_train, y_train, baseline_lr, n_samples_to_select=80, n_random=100)
