Skip to content

Commit 7e11380

Browse files
PABannier, mathurinm, and Badr-MOUFAD
authored
FIX Remove Numba warning when fitting multi-task estimators (#93)
Co-authored-by: mathurinm <[email protected]> Co-authored-by: Badr MOUFAD <[email protected]>
1 parent 936a3a0 commit 7e11380

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

skglm/datafits/multi_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def value(self, Y, W, XW):
6868
def gradient_j(self, X, Y, W, XW, j):
6969
"""Gradient with respect to j-th coordinate of W."""
7070
n_samples = X.shape[0]
71-
return (X[:, j:j+1].T @ XW - self.XtY[j, :]) / n_samples
71+
return (X[:, j] @ XW - self.XtY[j, :]) / n_samples
7272

7373
def gradient_j_sparse(self, X_data, X_indptr, X_indices, Y, XW, j):
7474
"""Gradient with respect to j-th coordinate of W when X is sparse."""

skglm/solvers/multitask_bcd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,8 @@ def _bcd_epoch(X, Y, W, XW, datafit, penalty, ws):
369369
continue
370370
Xj = X[:, j]
371371
old_W_j = W[j, :].copy() # copy is very important here
372-
W[j:j+1, :] = penalty.prox_1feat(
373-
W[j:j+1, :] - datafit.gradient_j(X, Y, W, XW, j) / lc[j],
372+
W[j, :] = penalty.prox_1feat(
373+
W[j, :] - datafit.gradient_j(X, Y, W, XW, j) / lc[j],
374374
1 / lc[j], j)
375375
if not np.all(W[j, :] == old_W_j):
376376
for k in range(n_tasks):

0 commit comments

Comments (0)