Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit 35c2263

Browse files
samuelefiorinimblondel
authored andcommitted
Stepsize update rule changed in fista.py (#102)
* Stepsize update rule changed in fista.py According to the original FISTA paper [Beck09]. [Beck09] Beck, Amir, and Marc Teboulle. "A fast iterative shrinkage-thresholding algorithm for linear inverse problems." SIAM journal on imaging sciences 2.1 (2009): 183-202. * test_fista.py edit desired values changed after correcting fista stepsize update rule * test_fista.py corrected desired values of the asserts
1 parent 9bc01cb commit 35c2263

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

lightning/impl/fista.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _fit(self, X, y, n_vectors):
116116
coefx = coef - G / L
117117
coefx = penalty.projection(coefx, self.alpha, L)
118118

119-
t = (1 + np.sqrt(1 + 4 * t_old * t_old) / 2)
119+
t = (1 + np.sqrt(1 + 4 * t_old * t_old)) / 2
120120
coef = coefx + (t_old - 1) / t * (coefx - coefx_old)
121121
df = safe_sparse_dot(X, coef.T)
122122

lightning/impl/tests/test_fista.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ def test_fista_multiclass_l1l2_log_margin():
4646
clf = FistaClassifier(max_iter=200, penalty="l1/l2", loss="log_margin",
4747
multiclass=True)
4848
clf.fit(data, mult_target)
49-
assert_almost_equal(clf.score(data, mult_target), 0.95, 2)
49+
assert_almost_equal(clf.score(data, mult_target), 0.93, 2)
5050

5151

5252
def test_fista_multiclass_l1():
5353
for data in (mult_dense, mult_csr):
5454
clf = FistaClassifier(max_iter=200, penalty="l1", multiclass=True)
5555
clf.fit(data, mult_target)
56-
assert_almost_equal(clf.score(data, mult_target), 0.98)
56+
assert_almost_equal(clf.score(data, mult_target), 0.98, 2)
5757

5858

5959

@@ -76,15 +76,15 @@ def test_fista_multiclass_l1l2_no_line_search():
7676
clf = FistaClassifier(max_iter=500, penalty="l1/l2", multiclass=True,
7777
max_steps=0)
7878
clf.fit(data, mult_target)
79-
assert_almost_equal(clf.score(data, mult_target), 0.96, 2)
79+
assert_almost_equal(clf.score(data, mult_target), 0.94, 2)
8080

8181

8282
def test_fista_multiclass_l1_no_line_search():
8383
for data in (mult_dense, mult_csr):
8484
clf = FistaClassifier(max_iter=500, penalty="l1", multiclass=True,
8585
max_steps=0)
8686
clf.fit(data, mult_target)
87-
assert_almost_equal(clf.score(data, mult_target), 0.95, 2)
87+
assert_almost_equal(clf.score(data, mult_target), 0.94, 2)
8888

8989

9090
def test_fista_bin_l1():
@@ -105,7 +105,7 @@ def test_fista_multiclass_trace():
105105
for data in (mult_dense, mult_csr):
106106
clf = FistaClassifier(max_iter=100, penalty="trace", multiclass=True)
107107
clf.fit(data, mult_target)
108-
assert_almost_equal(clf.score(data, mult_target), 0.98, 2)
108+
assert_almost_equal(clf.score(data, mult_target), 0.96, 2)
109109

110110

111111
def test_fista_bin_classes():
@@ -175,7 +175,7 @@ def _make_data(n_samples, n_features, n_tasks, n_components):
175175
Y_pred = reg.predict(X)
176176
error = (Y_pred - Y).ravel()
177177
error = np.dot(error, error)
178-
assert_almost_equal(error, 77.45, 2)
178+
assert_almost_equal(error, 77.44, 2)
179179

180180

181181
def test_fista_custom_prox():

0 commit comments

Comments
 (0)