Skip to content

Commit 2e408bc

Browse files
committed
fix tests
1 parent 2665d5d commit 2e408bc

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

skglm/solvers/fista.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(self, max_iter=100, tol=1e-4, verbose=0):
3535
self.verbose = verbose
3636

3737
def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
38+
p_objs_out = []
3839
n_samples, n_features = X.shape
3940
all_features = np.arange(n_features)
4041
t_new = 1
@@ -66,8 +67,9 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
6667
opt = penalty.subdiff_distance(w, grad, all_features)
6768
stop_crit = np.max(opt)
6869

70+
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
71+
p_objs_out.append(p_obj)
6972
if self.verbose:
70-
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
7173
print(
7274
f"Iteration {n_iter+1}: {p_obj:.10f}, "
7375
f"stopping crit: {stop_crit:.2e}"
@@ -77,4 +79,4 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
7779
if self.verbose:
7880
print(f"Stopping criterion max violation: {stop_crit:.2e}")
7981
break
80-
return w
82+
return w, np.array(p_objs_out), stop_crit

skglm/tests/test_fista.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@ def test_fista_solver(X, Datafit, Penalty):
3939
datafit.initialize(_init, _y)
4040
penalty = compiled_clone(Penalty(alpha))
4141

42-
solver = FISTA(max_iter=1000, tol=tol, opt_freq=1)
43-
w = solver.solve(X, _y, datafit, penalty)
42+
solver = FISTA(max_iter=1000, tol=tol)
43+
res_fista = solver.solve(X, _y, datafit, penalty)
4444

4545
solver_cd = AndersonCD(tol=tol, fit_intercept=False)
46-
w_cd = solver_cd.solve(X, _y, datafit, penalty)[0]
46+
res_cd = solver_cd.solve(X, _y, datafit, penalty)
4747

48-
np.testing.assert_allclose(w, w_cd, rtol=1e-3)
48+
np.testing.assert_allclose(res_fista[0], res_cd[0], rtol=1e-3)
4949

5050

5151
if __name__ == '__main__':

0 commit comments

Comments
 (0)