We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2665d5d commit 2e408bcCopy full SHA for 2e408bc
skglm/solvers/fista.py
@@ -35,6 +35,7 @@ def __init__(self, max_iter=100, tol=1e-4, verbose=0):
35
self.verbose = verbose
36
37
def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
38
+ p_objs_out = []
39
n_samples, n_features = X.shape
40
all_features = np.arange(n_features)
41
t_new = 1
@@ -66,8 +67,9 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
66
67
opt = penalty.subdiff_distance(w, grad, all_features)
68
stop_crit = np.max(opt)
69
70
+ p_obj = datafit.value(y, w, Xw) + penalty.value(w)
71
+ p_objs_out.append(p_obj)
72
if self.verbose:
- p_obj = datafit.value(y, w, Xw) + penalty.value(w)
73
print(
74
f"Iteration {n_iter+1}: {p_obj:.10f}, "
75
f"stopping crit: {stop_crit:.2e}"
@@ -77,4 +79,4 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
77
79
78
80
print(f"Stopping criterion max violation: {stop_crit:.2e}")
81
break
- return w
82
+ return w, np.array(p_objs_out), stop_crit
skglm/tests/test_fista.py
@@ -39,13 +39,13 @@ def test_fista_solver(X, Datafit, Penalty):
datafit.initialize(_init, _y)
penalty = compiled_clone(Penalty(alpha))
42
- solver = FISTA(max_iter=1000, tol=tol, opt_freq=1)
43
- w = solver.solve(X, _y, datafit, penalty)
+ solver = FISTA(max_iter=1000, tol=tol)
+ res_fista = solver.solve(X, _y, datafit, penalty)
44
45
solver_cd = AndersonCD(tol=tol, fit_intercept=False)
46
- w_cd = solver_cd.solve(X, _y, datafit, penalty)[0]
+ res_cd = solver_cd.solve(X, _y, datafit, penalty)
47
48
- np.testing.assert_allclose(w, w_cd, rtol=1e-3)
+ np.testing.assert_allclose(res_fista[0], res_cd[0], rtol=1e-3)
49
50
51
if __name__ == '__main__':
0 commit comments