# ------------------------------
# We use the ``iris`` dataset and a logistic regression model to demonstrate
# data pruning.
# The baseline model is a logistic regression model trained on the entire dataset.
+ # Here, 110 samples are used as the training data, which is intentionally made
+ # imbalanced in order to test the data pruning methods.
# The coefficients of the model trained on the pruned dataset will be compared to
# those of the baseline model using the R-squared score.
# The higher the R-squared score, the better the pruning.
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

- data, labels = load_iris(return_X_y=True)
- baseline_lr = LogisticRegression(max_iter=1000).fit(data, labels)
+ iris = load_iris(as_frame=True)
+ baseline_lr = LogisticRegression(max_iter=1000).fit(iris["data"], iris["target"])
+ X_train = iris["data"].values[10:120]
+ y_train = iris["target"].values[10:120]
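# As a quick, illustrative check of the intentional imbalance (this check is an
# editorial sketch, not part of the diff): the ``[10:120]`` slice keeps
# 40 setosa, 50 versicolor, and 20 virginica samples.

import numpy as np

print(np.bincount(y_train))  # expected class counts: [40 50 20]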
# %%
# Random data pruning
# -------------------
- # There are 150 samples in the dataset. The pruned dataset for
- # the random pruning method is selected randomly.
+ # There are 110 samples in the training dataset.
+ # The random pruning method selects samples assuming a uniform distribution
+ # over all data.
import numpy as np
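# The pruning helpers themselves are elided from this hunk. As a rough
# illustration only, a random-pruning helper along the lines of the one this
# example relies on might look like the sketch below; the name
# ``_random_pruning``, its signature, and ``max_iter=1000`` are assumptions,
# not code taken from the diff.


def _random_pruning(X, y, n_samples_to_select, random_state):
    # Draw sample indices uniformly at random, without replacement (assumed).
    rng = np.random.default_rng(random_state)
    ids_random = rng.choice(X.shape[0], n_samples_to_select, replace=False)
    # Refit the logistic regression on the pruned data and return its parameters.
    pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids_random], y[ids_random])
    return pruned_lr.coef_, pruned_lr.intercept_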
@@ -78,14 +83,65 @@ def _fastcan_pruning(
    return pruned_lr.coef_, pruned_lr.intercept_
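# Only the tail of ``_fastcan_pruning`` is visible above. The sketch below is a
# reconstruction of what such a helper could do, inferred from the call sites and
# the ``minibatch`` usage later in this example; the import path of ``minibatch``,
# the helper's name, and the refitting details are assumptions, not the author's
# exact code.

from sklearn.cluster import KMeans

from fastcan import minibatch  # assumed import location of ``minibatch``


def _fastcan_pruning_sketch(
    X, y, n_samples_to_select, random_state, n_atoms, batch_size
):
    # Learn a small dictionary of atoms with k-means (the dictionary-learning step).
    kmeans = KMeans(n_clusters=n_atoms, random_state=random_state).fit(X)
    atoms = kmeans.cluster_centers_
    # Select the samples that best relate to the atoms, batch by batch.
    ids = minibatch(X.T, atoms.T, n_samples_to_select, batch_size=batch_size, verbose=0)
    # Refit the logistic regression on the pruned data and return its parameters.
    pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids], y[ids])
    return pruned_lr.coef_, pruned_lr.intercept_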
+ # %%
+ # Visualize selected samples
+ # --------------------------
+ # Use principal component analysis (PCA) to visualize the distribution of the samples
+ # and to compare the selections made by ``Random`` pruning and ``FastCan`` pruning.
+ # For a clearer view of the selection, only 10 samples are selected from the training
+ # data by each pruning method.
+ # The results show that ``FastCan`` selects 3 setosa, 4 versicolor,
+ # and 3 virginica, while ``Random`` selects 6, 2, and 2, respectively.
+ # The imbalanced selection of ``Random`` is caused by the imbalanced training data,
+ # while ``FastCan``, which benefits from dictionary learning (k-means), can overcome
+ # the imbalance issue.
+
+ import matplotlib.pyplot as plt
+ from sklearn.decomposition import PCA
+
+
+ def plot_pca(X, y, target_names, n_samples_to_select, random_state):
+     pca = PCA(2).fit(X)
+     pcs_all = pca.transform(X)
+
+     kmeans = KMeans(
+         n_clusters=10,
+         random_state=random_state,
+     ).fit(X)
+     atoms = kmeans.cluster_centers_
+     pcs_atoms = pca.transform(atoms)
+
+     ids_fastcan = minibatch(X.T, atoms.T, n_samples_to_select, batch_size=1, verbose=0)
+     pcs_fastcan = pca.transform(X[ids_fastcan])
+
+     rng = np.random.default_rng(random_state)
+     ids_random = rng.choice(X.shape[0], n_samples_to_select, replace=False)
+     pcs_random = pca.transform(X[ids_random])
+
+     plt.scatter(pcs_fastcan[:, 0], pcs_fastcan[:, 1], s=50, marker="o", label="FastCan")
+     plt.scatter(pcs_random[:, 0], pcs_random[:, 1], s=50, marker="*", label="Random")
+     plt.scatter(pcs_atoms[:, 0], pcs_atoms[:, 1], s=100, marker="+", label="Atoms")
+     cmap = plt.get_cmap("Dark2")
+     for i, label in enumerate(target_names):
+         mask = y == i
+         plt.scatter(
+             pcs_all[mask, 0], pcs_all[mask, 1], s=5, label=label, color=cmap(i + 2)
+         )
+     plt.xlabel("The First Principal Component")
+     plt.ylabel("The Second Principal Component")
+     plt.legend(ncol=2)
+
+
+ plot_pca(X_train, y_train, iris.target_names, 10, 123)
+
# %%
# Compare pruning methods
# -----------------------
- # 100 samples are selected from the 150 original samples with ``Random`` pruning and
+ # 80 samples are selected from the 110 training samples with ``Random`` pruning and
# ``FastCan`` pruning. The results show that ``FastCan`` pruning gives a higher
- # mean value of R-squared and a lower standard deviation.
+ # median value of R-squared and a lower standard deviation.
- import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
@@ -94,7 +150,7 @@ def plot_box(X, y, baseline, n_samples_to_select: int, n_random: int):
    r2_random = np.zeros(n_random)
    for i in range(n_random):
        coef, intercept = _fastcan_pruning(
-             X, y, n_samples_to_select, i, n_atoms=50, batch_size=2
+             X, y, n_samples_to_select, i, n_atoms=40, batch_size=2
        )
        r2_fastcan[i] = r2_score(
            np.c_[coef, intercept], np.c_[baseline.coef_, baseline.intercept_]
@@ -111,4 +167,4 @@ def plot_box(X, y, baseline, n_samples_to_select: int, n_random: int):
    plt.show()


- plot_box(data, labels, baseline_lr, n_samples_to_select=100, n_random=100)
+ plot_box(X_train, y_train, baseline_lr, n_samples_to_select=80, n_random=100)