examples/plot_pruning.py (74 changes: 65 additions & 9 deletions)
@@ -18,21 +18,26 @@
# ------------------------------
# We use ``iris`` dataset and logistic regression model to demonstrate data pruning.
# The baseline model is a logistic regression model trained on the entire dataset.
# Here, 110 samples are used as the training data; the class distribution is
# intentionally made imbalanced to test the data pruning methods.
# The coefficients of the model trained on the pruned dataset will be compared to
# the baseline model with R-squared score.
# The higher the R-squared score, the better the pruning.

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

-data, labels = load_iris(return_X_y=True)
-baseline_lr = LogisticRegression(max_iter=1000).fit(data, labels)
+iris = load_iris(as_frame=True)
+baseline_lr = LogisticRegression(max_iter=1000).fit(iris["data"], iris["target"])
+X_train = iris["data"].values[10:120]
+y_train = iris["target"].values[10:120]
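
To make the comparison metric concrete: the pruned model's coefficients and intercept are stacked column-wise and scored against the baseline's with ``r2_score``, in the same argument order used later in this example. A minimal sketch, where ``ids`` and ``subset_lr`` are hypothetical names and an arbitrary 30-sample refit stands in for a pruned model:

import numpy as np
from sklearn.metrics import r2_score

rng = np.random.default_rng(0)
ids = rng.choice(X_train.shape[0], 30, replace=False)  # arbitrary subset
subset_lr = LogisticRegression(max_iter=1000).fit(X_train[ids], y_train[ids])
print(r2_score(
    np.c_[subset_lr.coef_, subset_lr.intercept_],
    np.c_[baseline_lr.coef_, baseline_lr.intercept_],
))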

# %%
# Random data pruning
# -------------------
-# There are 150 samples in the dataset. The pruned dataset for
-# random pruning method is selected randomly.
+# There are 110 samples in the training dataset.
+# The random pruning method selects samples assuming a uniform
+# distribution over all data.

import numpy as np

@@ -78,14 +83,65 @@ def _fastcan_pruning(
    return pruned_lr.coef_, pruned_lr.intercept_
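
The bodies of the two pruning helpers are collapsed in this view. Below is a plausible reconstruction, inferred from the visible calls (k-means atoms plus fastcan's ``minibatch``, then a logistic-regression refit); the import location and exact signatures are assumptions, not the committed code:

from fastcan import minibatch  # assumed import location
from sklearn.cluster import KMeans


def _random_pruning_sketch(X, y, n_samples_to_select, random_state):
    # Uniformly sample a subset without replacement, then refit
    rng = np.random.default_rng(random_state)
    ids = rng.choice(X.shape[0], n_samples_to_select, replace=False)
    pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids], y[ids])
    return pruned_lr.coef_, pruned_lr.intercept_


def _fastcan_pruning_sketch(
    X, y, n_samples_to_select, random_state, n_atoms=40, batch_size=2
):
    # Learn a small dictionary of atoms with k-means, then let minibatch
    # pick the samples that best reconstruct those atoms
    kmeans = KMeans(n_clusters=n_atoms, random_state=random_state).fit(X)
    atoms = kmeans.cluster_centers_
    ids = minibatch(
        X.T, atoms.T, n_samples_to_select, batch_size=batch_size, verbose=0
    )
    pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids], y[ids])
    return pruned_lr.coef_, pruned_lr.intercept_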


# %%
# Visualize selected samples
# --------------------------------------------------
# Use principal component analysis (PCA) to visualize the distribution of the
# samples and to compare the selections made by ``Random`` pruning and
# ``FastCan`` pruning.
# To make the selection easier to see, only 10 samples are selected from the
# training data by each pruning method.
# The results show that ``FastCan`` selects 3 setosa, 4 versicolor,
# and 3 virginica, while ``Random`` selects 6, 2, and 2, respectively.
# The imbalanced selection of ``Random`` is caused by the imbalanced training
# data, while ``FastCan``, benefiting from dictionary learning (k-means), can
# overcome the imbalance issue.

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


def plot_pca(X, y, target_names, n_samples_to_select, random_state):
    pca = PCA(2).fit(X)
    pcs_all = pca.transform(X)

    kmeans = KMeans(
        n_clusters=10,
        random_state=random_state,
    ).fit(X)
    atoms = kmeans.cluster_centers_
    pcs_atoms = pca.transform(atoms)

    ids_fastcan = minibatch(X.T, atoms.T, n_samples_to_select, batch_size=1, verbose=0)
    pcs_fastcan = pca.transform(X[ids_fastcan])

    rng = np.random.default_rng(random_state)
    ids_random = rng.choice(X.shape[0], n_samples_to_select, replace=False)
    pcs_random = pca.transform(X[ids_random])

    plt.scatter(pcs_fastcan[:, 0], pcs_fastcan[:, 1], s=50, marker="o", label="FastCan")
    plt.scatter(pcs_random[:, 0], pcs_random[:, 1], s=50, marker="*", label="Random")
    plt.scatter(pcs_atoms[:, 0], pcs_atoms[:, 1], s=100, marker="+", label="Atoms")
    cmap = plt.get_cmap("Dark2")
    for i, label in enumerate(target_names):
        mask = y == i
        plt.scatter(
            pcs_all[mask, 0], pcs_all[mask, 1], s=5, label=label, color=cmap(i + 2)
        )
    plt.xlabel("The First Principal Component")
    plt.ylabel("The Second Principal Component")
    plt.legend(ncol=2)


plot_pca(X_train, y_train, iris.target_names, 10, 123)
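
The per-class counts quoted above can be checked by tallying the labels of the selected samples. For the ``Random`` selection this is easy to reproduce outside ``plot_pca``, since it depends only on the seed (the same ``default_rng(123)`` call order is repeated here):

rng = np.random.default_rng(123)
ids_random = rng.choice(X_train.shape[0], 10, replace=False)
print(np.bincount(y_train[ids_random], minlength=3))  # samples per class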

# %%
# Compare pruning methods
# -----------------------
-# 100 samples are selected from 150 original data with ``Random`` pruning and
+# 80 samples are selected from the 110 training samples with ``Random`` pruning and
# ``FastCan`` pruning. The results show that ``FastCan`` pruning gives a higher
-# mean value of R-squared and a lower standard deviation.
+# median value of R-squared and a lower standard deviation.

import matplotlib.pyplot as plt
from sklearn.metrics import r2_score


@@ -94,7 +150,7 @@ def plot_box(X, y, baseline, n_samples_to_select: int, n_random: int):
    r2_random = np.zeros(n_random)
    for i in range(n_random):
        coef, intercept = _fastcan_pruning(
-            X, y, n_samples_to_select, i, n_atoms=50, batch_size=2
+            X, y, n_samples_to_select, i, n_atoms=40, batch_size=2
        )
        r2_fastcan[i] = r2_score(
            np.c_[coef, intercept], np.c_[baseline.coef_, baseline.intercept_]
@@ -111,4 +167,4 @@ def plot_box(X, y, baseline, n_samples_to_select: int, n_random: int):
    plt.show()


-plot_box(data, labels, baseline_lr, n_samples_to_select=100, n_random=100)
+plot_box(X_train, y_train, baseline_lr, n_samples_to_select=80, n_random=100)
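
The middle of ``plot_box`` is collapsed above. For completeness, a hedged sketch of how the whole comparison could look, assuming a ``_random_pruning`` helper with the signature sketched earlier; the committed function may differ in detail:

def plot_box_sketch(X, y, baseline, n_samples_to_select, n_random):
    r2_fastcan = np.zeros(n_random)
    r2_random = np.zeros(n_random)
    base = np.c_[baseline.coef_, baseline.intercept_]
    for i in range(n_random):
        # Score both pruning methods against the baseline coefficients
        coef, intercept = _fastcan_pruning(
            X, y, n_samples_to_select, i, n_atoms=40, batch_size=2
        )
        r2_fastcan[i] = r2_score(np.c_[coef, intercept], base)
        coef, intercept = _random_pruning(X, y, n_samples_to_select, i)
        r2_random[i] = r2_score(np.c_[coef, intercept], base)
    plt.boxplot([r2_fastcan, r2_random])
    plt.xticks([1, 2], ["FastCan", "Random"])
    plt.ylabel("R-squared")
    plt.show()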