Skip to content

Commit cc690bd

Browse files
committed
init inside solve method
1 parent 5efeb8f commit cc690bd

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

skglm/solvers/base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from abc import abstractmethod, ABC
33

44
import numpy as np
5+
from scipy.sparse import issparse
56

67
from skglm.utils.validation import check_attrs
78
from skglm.utils.jit_compilation import compiled_clone
@@ -40,6 +41,8 @@ class BaseSolver(ABC):
4041
def _solve(self, X, y, datafit, penalty, w_init, Xw_init):
4142
"""Solve an optimization problem.
4243
44+
This method assumes that datafit was already initialized.
45+
4346
Parameters
4447
----------
4548
X : array, shape (n_samples, n_features)
@@ -95,7 +98,8 @@ def custom_checks(self, X, y, datafit, penalty):
9598
pass
9699

97100
def solve(
98-
self, X, y, datafit, penalty, w_init=None, Xw_init=None, *, run_checks=True
101+
self, X, y, datafit, penalty, w_init=None, Xw_init=None, *,
102+
run_checks=True, initialize_datafit=True
99103
):
100104
"""Solve the optimization problem after validating its compatibility.
101105
@@ -133,6 +137,13 @@ def solve(
133137
if run_checks:
134138
self._validate(X, y, datafit, penalty)
135139

140+
# check for None as GramCD solver doesn't take None as datafit
141+
if datafit is not None and initialize_datafit:
142+
if issparse(X):
143+
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
144+
else:
145+
datafit.initialize(X, y)
146+
136147
return self._solve(X, y, datafit, penalty, w_init, Xw_init)
137148

138149
def _validate(self, X, y, datafit, penalty):

0 commit comments

Comments
 (0)