diff --git a/examples/plot_pruning.py b/examples/plot_pruning.py
index b9d3fde..c329fdd 100644
--- a/examples/plot_pruning.py
+++ b/examples/plot_pruning.py
@@ -26,7 +26,7 @@ from sklearn.linear_model import LogisticRegression
 
 data, labels = load_iris(return_X_y=True)
 
-baseline_lr = LogisticRegression(max_iter=110).fit(data, labels)
+baseline_lr = LogisticRegression(max_iter=1000).fit(data, labels)
 
 # %%
 # Random data pruning
@@ -40,7 +40,7 @@ def _random_pruning(X, y, n_samples_to_select: int, random_state: int):
 
     rng = np.random.default_rng(random_state)
     ids_random = rng.choice(y.size, n_samples_to_select, replace=False)
-    pruned_lr = LogisticRegression(max_iter=110).fit(X[ids_random], y[ids_random])
+    pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids_random], y[ids_random])
     return pruned_lr.coef_, pruned_lr.intercept_
 
 
@@ -72,9 +72,9 @@ def _fastcan_pruning(
     ).fit(X)
     atoms = kmeans.cluster_centers_
     ids_fastcan = minibatch(
-        X.T, atoms.T, n_samples_to_select, batch_size=batch_size, tol=1e-9, verbose=0
+        X.T, atoms.T, n_samples_to_select, batch_size=batch_size, verbose=0
     )
-    pruned_lr = LogisticRegression(max_iter=110).fit(X[ids_fastcan], y[ids_fastcan])
+    pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids_fastcan], y[ids_fastcan])
     print(atoms[-1], ids_fastcan[-10:])
     return pruned_lr.coef_, pruned_lr.intercept_
 
@@ -112,4 +112,4 @@ def plot_box(X, y, baseline, n_samples_to_select: int, n_random: int):
     plt.show()
 
 
-plot_box(data, labels, baseline_lr, n_samples_to_select=100, n_random=10)
+plot_box(data, labels, baseline_lr, n_samples_to_select=100, n_random=100)