|
2 | 2 | from abc import abstractmethod, ABC
|
3 | 3 |
|
4 | 4 | import numpy as np
|
| 5 | +from scipy.sparse import issparse |
5 | 6 |
|
6 | 7 | from skglm.utils.validation import check_attrs
|
7 | 8 | from skglm.utils.jit_compilation import compiled_clone
|
@@ -40,6 +41,8 @@ class BaseSolver(ABC):
|
40 | 41 | def _solve(self, X, y, datafit, penalty, w_init, Xw_init):
|
41 | 42 | """Solve an optimization problem.
|
42 | 43 |
|
| 44 | + This method assumes that datafit was already initialized. |
| 45 | +
|
43 | 46 | Parameters
|
44 | 47 | ----------
|
45 | 48 | X : array, shape (n_samples, n_features)
|
@@ -95,7 +98,8 @@ def custom_checks(self, X, y, datafit, penalty):
|
95 | 98 | pass
|
96 | 99 |
|
97 | 100 | def solve(
|
98 |
| - self, X, y, datafit, penalty, w_init=None, Xw_init=None, *, run_checks=True |
| 101 | + self, X, y, datafit, penalty, w_init=None, Xw_init=None, *, |
| 102 | + run_checks=True, initialize_datafit=True |
99 | 103 | ):
|
100 | 104 | """Solve the optimization problem after validating its compatibility.
|
101 | 105 |
|
@@ -133,6 +137,13 @@ def solve(
|
133 | 137 | if run_checks:
|
134 | 138 | self._validate(X, y, datafit, penalty)
|
135 | 139 |
|
| 140 | + # check for None as GramCD solver doesn't take None as datafit |
| 141 | + if datafit is not None and initialize_datafit: |
| 142 | + if issparse(X): |
| 143 | + datafit.initialize_sparse(X.data, X.indptr, X.indices, y) |
| 144 | + else: |
| 145 | + datafit.initialize(X, y) |
| 146 | + |
136 | 147 | return self._solve(X, y, datafit, penalty, w_init, Xw_init)
|
137 | 148 |
|
138 | 149 | def _validate(self, X, y, datafit, penalty):
|
|
0 commit comments