Skip to content

Commit 9b7f3ef

Browse files
authored
FIX compute p_obj_acc only if acceleration is performed (#41)
1 parent f5e8154 commit 9b7f3ef

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

skglm/solvers/group_bcd_solver.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -91,13 +91,15 @@ def bcd_solver(X, y, datafit, penalty, w_init=None, p0=10,
9191
# inplace update of w and Xw
9292
_bcd_epoch(X, y, w, Xw, datafit, penalty, ws)
9393

94-
w_acc, Xw_acc = accelerator.extrapolate(w, Xw)
95-
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
96-
p_obj_acc = datafit.value(y, w_acc, Xw_acc) + penalty.value(w_acc)
94+
w_acc, Xw_acc, is_extrapolated = accelerator.extrapolate(w, Xw)
95+
96+
if is_extrapolated: # avoid computing p_obj for un-extrapolated w, Xw
97+
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
98+
p_obj_acc = datafit.value(y, w_acc, Xw_acc) + penalty.value(w_acc)
9799

98-
if p_obj_acc < p_obj:
99-
w, Xw = w_acc, Xw_acc
100-
p_obj = p_obj_acc
100+
if p_obj_acc < p_obj:
101+
w[:], Xw[:] = w_acc, Xw_acc
102+
p_obj = p_obj_acc
101103

102104
# check sub-optimality every 10 epochs
103105
if epoch % 10 == 0:
@@ -106,13 +108,15 @@ def bcd_solver(X, y, datafit, penalty, w_init=None, p0=10,
106108
stop_crit_in = np.max(opt_in)
107109

108110
if max(verbose - 1, 0):
111+
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
109112
print(
110113
f"Epoch {epoch + 1}, objective {p_obj:.10f}, "
111114
f"stopping crit {stop_crit_in:.2e}"
112115
)
113116

114117
if stop_crit_in <= 0.3 * stop_crit:
115118
break
119+
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
116120
p_objs_out[t] = p_obj
117121

118122
return w, p_objs_out, stop_crit

skglm/tests/test_group.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,7 @@ def test_anderson_acceleration():
137137
w = np.ones(n_features)
138138
Xw = X @ w
139139
for i in range(max_iter):
140-
w, Xw = acc.extrapolate(w, Xw)
140+
w, Xw, _ = acc.extrapolate(w, Xw)
141141
w = rho * w + 1
142142
Xw = X @ w
143143

skglm/utils.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -273,7 +273,7 @@ def __init__(self, K):
273273
self.arr_w_, self.arr_Xw_ = None, None
274274

275275
def extrapolate(self, w, Xw):
276-
"""Return ``w`` and ``Xw`` extrapolated."""
276+
"""Return w, Xw, and a bool indicating whether they were extrapolated."""
277277
if self.arr_w_ is None or self.arr_Xw_ is None:
278278
self.arr_w_ = np.zeros((w.shape[0], self.K+1))
279279
self.arr_Xw_ = np.zeros((Xw.shape[0], self.K+1))
@@ -282,19 +282,19 @@ def extrapolate(self, w, Xw):
282282
self.arr_w_[:, self.current_iter] = w
283283
self.arr_Xw_[:, self.current_iter] = Xw
284284
self.current_iter += 1
285-
return w, Xw
285+
return w, Xw, False
286286

287287
U = np.diff(self.arr_w_, axis=1) # compute residuals
288288

289289
# compute extrapolation coefs
290290
try:
291291
inv_UTU_ones = np.linalg.solve(U.T @ U, np.ones(self.K))
292292
except np.linalg.LinAlgError:
293-
return w, Xw
293+
return w, Xw, False
294294
finally:
295295
self.current_iter = 0
296296

297297
# extrapolate
298298
C = inv_UTU_ones / np.sum(inv_UTU_ones)
299299
# floating point errors may cause w and Xw to disagree
300-
return self.arr_w_[:, 1:] @ C, self.arr_Xw_[:, 1:] @ C
300+
return self.arr_w_[:, 1:] @ C, self.arr_Xw_[:, 1:] @ C, True

0 commit comments

Comments
 (0)