Skip to content

Commit 1d4de0f

Browse files
authored
FIX gradient sign convention (#98)
1 parent 238ce8a commit 1d4de0f

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

skglm/penalties/block_separable.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ def prox_1group(self, value, stepsize, g):
286286
def subdiff_distance(self, w, grad_ws, ws):
287287
"""Compute distance to the subdifferential at ``w`` of negative gradient.
288288
289-
Note: ``grad_ws`` is a stacked array of ``-``gradients.
290-
([-grad_ws_1, -grad_ws_2, ...])
289+
Note: ``grad_ws`` is a stacked array of gradients.
290+
([grad_ws_1, grad_ws_2, ...])
291291
"""
292292
alpha, weights = self.alpha, self.weights
293293
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
@@ -307,7 +307,7 @@ def subdiff_distance(self, w, grad_ws, ws):
307307
scores[idx] = max(0, norm(grad_g) - alpha * weights[g])
308308
else:
309309
subdiff = alpha * weights[g] * w_g / norm_w_g
310-
scores[idx] = norm(grad_g - subdiff)
310+
scores[idx] = norm(grad_g + subdiff)
311311

312312
return scores
313313

skglm/solvers/group_bcd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,6 @@ def _construct_grad(X, y, w, Xw, datafit, ws):
170170
grad_ptr = 0
171171
for g in ws:
172172
grad_g = datafit.gradient_g(X, y, w, Xw, g)
173-
grads[grad_ptr: grad_ptr+len(grad_g)] = -grad_g
173+
grads[grad_ptr: grad_ptr+len(grad_g)] = grad_g
174174
grad_ptr += len(grad_g)
175175
return grads

0 commit comments

Comments
 (0)