Skip to content

Commit 52c8319

Browse files
MNT - migrate sparse matrix operation to utils.sparse_ops (#176)
Co-authored-by: jasper <[email protected]>
1 parent 18e6456 commit 52c8319

File tree

3 files changed

+11
-12
lines changed

3 files changed

+11
-12
lines changed

skglm/datafits/single_task.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from numba import float64, int64, bool_
55

66
from skglm.datafits.base import BaseDatafit
7-
from skglm.utils.sparse_ops import spectral_norm
8-
from skglm.solvers.prox_newton import _sparse_xj_dot
7+
from skglm.utils.sparse_ops import spectral_norm, _sparse_xj_dot
98

109

1110
class Quadratic(BaseDatafit):

skglm/solvers/prox_newton.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from skglm.solvers.base import BaseSolver
77

88
from sklearn.exceptions import ConvergenceWarning
9-
9+
from skglm.utils.sparse_ops import _sparse_xj_dot
1010

1111
EPS_TOL = 0.3
1212
MAX_CD_ITER = 20
@@ -413,15 +413,6 @@ def _construct_grad_sparse(X_data, X_indptr, X_indices, y, w, Xw, datafit, ws):
413413
return grad
414414

415415

416-
@njit(fastmath=True)
417-
def _sparse_xj_dot(X_data, X_indptr, X_indices, j, other):
418-
# Compute X[:, j] @ other in case X sparse
419-
res = 0.
420-
for i in range(X_indptr[j], X_indptr[j+1]):
421-
res += X_data[i] * other[X_indices[i]]
422-
return res
423-
424-
425416
@njit(fastmath=True)
426417
def _sparse_weighted_dot(X_data, X_indptr, X_indices, j, other, weights):
427418
# Compute X[:, j] @ (weights * other) in case X sparse

skglm/utils/sparse_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,12 @@ def _XT_dot_vec(X_data, X_indptr, X_indices, vec):
9292
result[j] += X_data[idx] * vec[X_indices[idx]]
9393

9494
return result
95+
96+
97+
@njit(fastmath=True)
98+
def _sparse_xj_dot(X_data, X_indptr, X_indices, j, other):
99+
# Compute X[:, j] @ other in case X sparse
100+
res = 0.
101+
for i in range(X_indptr[j], X_indptr[j+1]):
102+
res += X_data[i] * other[X_indices[i]]
103+
return res

0 commit comments

Comments
 (0)