# ------------------------------
# We use the ``iris`` dataset and a logistic regression model to demonstrate data pruning.
# The baseline model is a logistic regression model trained on the entire dataset.
+ # Here, 110 samples are used as the training data, which is intentionally made
+ # imbalanced to test the data pruning methods.
# The coefficients of the model trained on the pruned dataset will be compared to
# those of the baseline model with the R-squared score.
# The higher the R-squared score, the better the pruning.

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

- data, labels = load_iris(return_X_y=True)
- baseline_lr = LogisticRegression(max_iter=1000).fit(data, labels)
+ iris = load_iris(as_frame=True)
+ baseline_lr = LogisticRegression(max_iter=1000).fit(iris["data"], iris["target"])
+ X_train = iris["data"].values[10:120]
+ y_train = iris["target"].values[10:120]

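# The comparison metric described in the introduction can be illustrated directly.
# This snippet is an editorial sketch, not part of the committed example: a model
# refitted on the truncated training data alone is scored against the baseline by
# feeding the stacked coefficients and intercepts of both models to ``r2_score``,
# mirroring the call used inside ``plot_box`` further down.

import numpy as np
from sklearn.metrics import r2_score

# ``train_lr`` exists only for this illustration.
train_lr = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print(
    r2_score(
        np.c_[train_lr.coef_, train_lr.intercept_],
        np.c_[baseline_lr.coef_, baseline_lr.intercept_],
    )
)
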
# %%
# Random data pruning
# -------------------
- # There are 150 samples in the dataset. The pruned dataset for
- # random pruning method is selected randomly.
+ # There are 110 samples in the training dataset.
+ # The random pruning method selects samples assuming a uniform distribution
+ # over all data.

import numpy as np

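# The ``_random_pruning`` helper itself lies outside this diff. As a rough guide only,
# here is a minimal sketch of such a helper (a hypothetical ``_random_pruning_sketch``,
# written to mirror the return signature of ``_fastcan_pruning`` visible below):

def _random_pruning_sketch(X, y, n_samples_to_select, random_state):
    # Draw a uniform subsample without replacement, refit the model on it, and
    # return the coefficients to be compared against the baseline.
    rng = np.random.default_rng(random_state)
    ids = rng.choice(X.shape[0], n_samples_to_select, replace=False)
    pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids], y[ids])
    return pruned_lr.coef_, pruned_lr.intercept_
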
@@ -78,14 +83,65 @@ def _fastcan_pruning(
    return pruned_lr.coef_, pruned_lr.intercept_


+ # %%
+ # Visualize selected samples
+ # --------------------------
+ # Use principal component analysis (PCA) to visualize the distribution of the samples
+ # and to compare the selections made by ``Random`` pruning and ``FastCan`` pruning.
+ # For a clearer view of the selection, only 10 samples are selected from the training
+ # data by each pruning method.
+ # The results show that ``FastCan`` selects 3 setosa, 4 versicolor, and 3 virginica,
+ # while ``Random`` selects 6, 2, and 2, respectively.
+ # The imbalanced selection of ``Random`` is caused by the imbalanced training data,
+ # while ``FastCan``, benefiting from dictionary learning (k-means), can overcome
+ # the imbalance.
+
+ import matplotlib.pyplot as plt
+ from sklearn.decomposition import PCA
+
+
+ def plot_pca(X, y, target_names, n_samples_to_select, random_state):
+     pca = PCA(2).fit(X)
+     pcs_all = pca.transform(X)
+
+     kmeans = KMeans(
+         n_clusters=10,
+         random_state=random_state,
+     ).fit(X)
+     atoms = kmeans.cluster_centers_
+     pcs_atoms = pca.transform(atoms)
+
+     ids_fastcan = minibatch(X.T, atoms.T, n_samples_to_select, batch_size=1, verbose=0)
+     pcs_fastcan = pca.transform(X[ids_fastcan])
+
+     rng = np.random.default_rng(random_state)
+     ids_random = rng.choice(X.shape[0], n_samples_to_select, replace=False)
+     pcs_random = pca.transform(X[ids_random])
+
+     plt.scatter(pcs_fastcan[:, 0], pcs_fastcan[:, 1], s=50, marker="o", label="FastCan")
+     plt.scatter(pcs_random[:, 0], pcs_random[:, 1], s=50, marker="*", label="Random")
+     plt.scatter(pcs_atoms[:, 0], pcs_atoms[:, 1], s=100, marker="+", label="Atoms")
+     cmap = plt.get_cmap("Dark2")
+     for i, label in enumerate(target_names):
+         mask = y == i
+         plt.scatter(
+             pcs_all[mask, 0], pcs_all[mask, 1], s=5, label=label, color=cmap(i + 2)
+         )
+     plt.xlabel("The First Principal Component")
+     plt.ylabel("The Second Principal Component")
+     plt.legend(ncol=2)
+
+
+ plot_pca(X_train, y_train, iris.target_names, 10, 123)
+
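# The ``_fastcan_pruning`` helper compared in the next cell applies the same
# mechanism: k-means atoms as a learned dictionary, then ``minibatch`` selection.
# Its body lies mostly outside this diff, so the following is only a plausible sketch
# pieced together from its call site, its return line, and the ``minibatch`` usage
# above; ``minibatch`` is assumed to be imported from ``fastcan`` earlier in the full
# script, and ``_fastcan_pruning_sketch`` is a hypothetical name for illustration.

from sklearn.cluster import KMeans


def _fastcan_pruning_sketch(
    X, y, n_samples_to_select, random_state, n_atoms, batch_size
):
    # Learn a small dictionary of atoms with k-means, let minibatch() pick the
    # samples that best represent those atoms, then refit on the selected subset.
    kmeans = KMeans(n_clusters=n_atoms, random_state=random_state).fit(X)
    atoms = kmeans.cluster_centers_
    ids = minibatch(
        X.T, atoms.T, n_samples_to_select, batch_size=batch_size, verbose=0
    )
    pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids], y[ids])
    return pruned_lr.coef_, pruned_lr.intercept_
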
# %%
# Compare pruning methods
# -----------------------
- # 100 samples are selected from 150 original data with ``Random`` pruning and
+ # 80 samples are selected from the 110 training samples with ``Random`` pruning and
# ``FastCan`` pruning. The results show that ``FastCan`` pruning gives a higher
- # mean value of R-squared and a lower standard deviation.
+ # median value of R-squared and a lower standard deviation.

- import matplotlib.pyplot as plt
from sklearn.metrics import r2_score


@@ -94,7 +150,7 @@ def plot_box(X, y, baseline, n_samples_to_select: int, n_random: int):
    r2_random = np.zeros(n_random)
    for i in range(n_random):
        coef, intercept = _fastcan_pruning(
-             X, y, n_samples_to_select, i, n_atoms=50, batch_size=2
+             X, y, n_samples_to_select, i, n_atoms=40, batch_size=2
        )
        r2_fastcan[i] = r2_score(
            np.c_[coef, intercept], np.c_[baseline.coef_, baseline.intercept_]
@@ -111,4 +167,4 @@ def plot_box(X, y, baseline, n_samples_to_select: int, n_random: int):
    plt.show()


- plot_box(data, labels, baseline_lr, n_samples_to_select=100, n_random=100)
+ plot_box(X_train, y_train, baseline_lr, n_samples_to_select=80, n_random=100)