Skip to content

Commit 1f13430

Browse files
authored
CLN Consistent namings across solvers (#38)
1 parent d5a7494 commit 1f13430

File tree

4 files changed

+38
-38
lines changed

4 files changed

+38
-38
lines changed

skglm/solvers/cd_solver.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -323,14 +323,14 @@ def cd_solver(
323323
p_obj = datafit.value(y, w[ws], Xw) + penalty.value(w)
324324

325325
if is_sparse:
326-
grad = construct_grad_sparse(
326+
grad_ws = construct_grad_sparse(
327327
X.data, X.indptr, X.indices, y, w, Xw, datafit, ws)
328328
else:
329-
grad = construct_grad(X, y, w, Xw, datafit, ws)
329+
grad_ws = construct_grad(X, y, w, Xw, datafit, ws)
330330
if ws_strategy == "subdiff":
331-
opt_ws = penalty.subdiff_distance(w, grad, ws)
331+
opt_ws = penalty.subdiff_distance(w, grad_ws, ws)
332332
elif ws_strategy == "fixpoint":
333-
opt_ws = dist_fix_point(w, grad, datafit, penalty, ws)
333+
opt_ws = dist_fix_point(w, grad_ws, datafit, penalty, ws)
334334

335335
stop_crit_in = np.max(opt_ws)
336336
if max(verbose - 1, 0):
@@ -349,7 +349,7 @@ def cd_solver(
349349

350350

351351
@njit
352-
def _cd_epoch(X, y, w, Xw, datafit, penalty, feats):
352+
def _cd_epoch(X, y, w, Xw, datafit, penalty, ws):
353353
"""Run an epoch of coordinate descent in place.
354354
355355
Parameters
@@ -372,11 +372,11 @@ def _cd_epoch(X, y, w, Xw, datafit, penalty, feats):
372372
penalty : Penalty
373373
Penalty.
374374
375-
feats : array, shape (n_features,)
375+
ws : array, shape (ws_size,)
376376
The range of features.
377377
"""
378378
lc = datafit.lipschitz
379-
for j in feats:
379+
for j in ws:
380380
stepsize = 1/lc[j] if lc[j] != 0 else 1000
381381
Xj = X[:, j]
382382
old_w_j = w[j]
@@ -388,7 +388,7 @@ def _cd_epoch(X, y, w, Xw, datafit, penalty, feats):
388388

389389

390390
@njit
391-
def _cd_epoch_sparse(X_data, X_indptr, X_indices, y, w, Xw, datafit, penalty, feats):
391+
def _cd_epoch_sparse(X_data, X_indptr, X_indices, y, w, Xw, datafit, penalty, ws):
392392
"""Run an epoch of coordinate descent in place for a sparse CSC array.
393393
394394
Parameters
@@ -417,11 +417,11 @@ def _cd_epoch_sparse(X_data, X_indptr, X_indices, y, w, Xw, datafit, penalty, fe
417417
penalty : Penalty
418418
Penalty.
419419
420-
feats : array, shape (n_features,)
421-
The range of features.
420+
ws : array, shape (ws_size,)
421+
The working set.
422422
"""
423423
lc = datafit.lipschitz
424-
for j in feats:
424+
for j in ws:
425425
stepsize = 1/lc[j] if lc[j] != 0 else 1000
426426

427427
old_w_j = w[j]

skglm/solvers/common.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,24 @@
33

44

55
@njit
6-
def dist_fix_point(w, grad, datafit, penalty, ws):
6+
def dist_fix_point(w, grad_ws, datafit, penalty, ws):
77
"""Compute the violation of the fixed point iterate scheme.
88
99
Parameters
1010
----------
1111
w : array, shape (n_features,)
1212
Coefficient vector.
1313
14-
grad : array, shape (n_features,)
15-
Gradient.
14+
grad_ws : array, shape (ws_size,)
15+
Gradient restricted to the working set.
1616
1717
datafit: instance of BaseDatafit
1818
Datafit.
1919
2020
penalty: instance of BasePenalty
2121
Penalty.
2222
23-
ws : array, shape (n_features,)
23+
ws : array, shape (ws_size,)
2424
The working set.
2525
2626
Returns
@@ -33,7 +33,7 @@ def dist_fix_point(w, grad, datafit, penalty, ws):
3333
lcj = datafit.lipschitz[j]
3434
if lcj != 0:
3535
dist_fix_point[idx] = np.abs(
36-
w[j] - penalty.prox_1d(w[j] - grad[idx] / lcj, 1. / lcj, j))
36+
w[j] - penalty.prox_1d(w[j] - grad_ws[idx] / lcj, 1. / lcj, j))
3737
return dist_fix_point
3838

3939

@@ -58,7 +58,7 @@ def construct_grad(X, y, w, Xw, datafit, ws):
5858
datafit : Datafit
5959
Datafit.
6060
61-
ws : array, shape (n_features,)
61+
ws : array, shape (ws_size,)
6262
The working set.
6363
6464
Returns
@@ -99,7 +99,7 @@ def construct_grad_sparse(data, indptr, indices, y, w, Xw, datafit, ws):
9999
datafit : Datafit
100100
Datafit.
101101
102-
ws : array, shape (n_features,)
102+
ws : array, shape (ws_size,)
103103
The working set.
104104
105105
Returns

skglm/solvers/group_bcd_solver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def bcd_solver(X, y, datafit, penalty, w_init=None, p0=10,
107107

108108
if max(verbose - 1, 0):
109109
print(
110-
f"Epoch {epoch+1}: {p_obj:.10f} "
111-
f"obj. variation: {stop_crit_in:.2e}"
110+
f"Epoch {epoch + 1}, objective {p_obj:.10f}, "
111+
f"stopping crit {stop_crit_in:.2e}"
112112
)
113113

114114
if stop_crit_in <= 0.3 * stop_crit:

skglm/solvers/multitask_bcd_solver.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def bcd_solver_path(
8383
datafit.initialize_sparse(X.data, X.indptr, X.indices, Y)
8484
else:
8585
datafit.initialize(X, Y)
86-
n_samples, n_features = X.shape
86+
n_features = X.shape[1]
8787
n_tasks = Y.shape[1]
8888
if alphas is None:
8989
raise ValueError("alphas should be provided.")
@@ -285,15 +285,15 @@ def bcd_solver(
285285
p_obj = datafit.value(Y, W[ws, :], XW) + penalty.value(W)
286286

287287
if is_sparse:
288-
grad = construct_grad_sparse(
288+
grad_ws = construct_grad_sparse(
289289
X.data, X.indptr, X.indices, Y, XW, datafit, ws)
290290
else:
291-
grad = construct_grad(X, Y, W, XW, datafit, ws)
291+
grad_ws = construct_grad(X, Y, W, XW, datafit, ws)
292292

293293
if ws_strategy == "subdiff":
294-
opt_ws = penalty.subdiff_distance(W, grad, ws)
294+
opt_ws = penalty.subdiff_distance(W, grad_ws, ws)
295295
elif ws_strategy == "fixpoint":
296-
opt_ws = dist_fix_point(W, grad, datafit, penalty, ws)
296+
opt_ws = dist_fix_point(W, grad_ws, datafit, penalty, ws)
297297

298298
stop_crit_in = np.max(opt_ws)
299299
if max(verbose - 1, 0):
@@ -312,24 +312,24 @@ def bcd_solver(
312312

313313

314314
@njit
315-
def dist_fix_point(W, grad, datafit, penalty, ws):
315+
def dist_fix_point(W, grad_ws, datafit, penalty, ws):
316316
"""Compute the violation of the fixed point iterate schema.
317317
318318
Parameters
319319
----------
320320
W : array, shape (n_features, n_tasks)
321321
Coefficient matrix.
322322
323-
grad : array, shape (n_features, n_tasks)
324-
Gradient.
323+
grad_ws : array, shape (ws_size, n_tasks)
324+
Gradient restricted to the working set.
325325
326326
datafit: instance of BaseMultiTaskDatafit
327327
Datafit.
328328
329329
penalty: instance of BasePenalty
330330
Penalty.
331331
332-
ws : array, shape (n_features,)
332+
ws : array, shape (ws_size,)
333333
The working set.
334334
335335
Returns
@@ -342,7 +342,7 @@ def dist_fix_point(W, grad, datafit, penalty, ws):
342342
lcj = datafit.lipschitz[j]
343343
if lcj:
344344
dist_fix_point[idx] = norm(
345-
W[j] - penalty.prox_1feat(W[j] - grad[idx] / lcj, 1. / lcj, j))
345+
W[j] - penalty.prox_1feat(W[j] - grad_ws[idx] / lcj, 1. / lcj, j))
346346
return dist_fix_point
347347

348348

@@ -367,7 +367,7 @@ def construct_grad(X, Y, W, XW, datafit, ws):
367367
datafit : instance of BaseMultiTaskDatafit
368368
Datafit.
369369
370-
ws : array, shape (n_features,)
370+
ws : array, shape (ws_size,)
371371
The working set.
372372
373373
Returns
@@ -423,7 +423,7 @@ def construct_grad_sparse(data, indptr, indices, Y, XW, datafit, ws):
423423

424424

425425
@njit
426-
def _bcd_epoch(X, Y, W, XW, datafit, penalty, feats):
426+
def _bcd_epoch(X, Y, W, XW, datafit, penalty, ws):
427427
"""Run an epoch of block coordinate descent in place.
428428
429429
Parameters
@@ -446,12 +446,12 @@ def _bcd_epoch(X, Y, W, XW, datafit, penalty, feats):
446446
penalty : instance of BasePenalty
447447
Penalty.
448448
449-
feats : array, shape (ws_size,)
450-
Features to be updated.
449+
ws : array, shape (ws_size,)
450+
The working set.
451451
"""
452452
lc = datafit.lipschitz
453453
n_tasks = Y.shape[1]
454-
for j in feats:
454+
for j in ws:
455455
if lc[j] == 0.:
456456
continue
457457
Xj = X[:, j]
@@ -467,7 +467,7 @@ def _bcd_epoch(X, Y, W, XW, datafit, penalty, feats):
467467

468468

469469
@njit
470-
def _bcd_epoch_sparse(X_data, X_indptr, X_indices, Y, W, XW, datafit, penalty, feats):
470+
def _bcd_epoch_sparse(X_data, X_indptr, X_indices, Y, W, XW, datafit, penalty, ws):
471471
"""Run an epoch of block coordinate descent in place for a sparse CSC array.
472472
473473
Parameters
@@ -496,11 +496,11 @@ def _bcd_epoch_sparse(X_data, X_indptr, X_indices, Y, W, XW, datafit, penalty, f
496496
penalty : instance of BasePenalty
497497
Penalty.
498498
499-
feats : array, shape (ws_size,)
499+
ws : array, shape (ws_size,)
500500
Features to be updated.
501501
"""
502502
lc = datafit.lipschitz
503-
for j in feats:
503+
for j in ws:
504504
if lc[j] == 0.:
505505
continue
506506
old_W_j = W[j, :].copy()

0 commit comments

Comments
 (0)