Skip to content

Commit 1921e42

Browse files
committed
added init datafit in solver
1 parent 415c1e9 commit 1921e42

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

skglm/solvers/group_prox_newton.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
6868
all_groups = np.arange(n_groups)
6969
stop_crit = 0.
7070
p_objs_out = []
71+
72+
# TODO: to be isolated in a seperated method
73+
is_sparse = issparse(X)
74+
if is_sparse:
75+
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
76+
else:
77+
datafit.initialize(X, y)
7178

7279
for iter in range(self.max_iter):
7380
grad = _construct_grad(X, y, w, Xw, datafit, all_groups)

skglm/solvers/prox_newton.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
8484
is_sparse = issparse(X)
8585
if is_sparse:
8686
X_bundles = (X.data, X.indptr, X.indices)
87+
88+
# TODO: to be isolated in a seperated method
89+
if is_sparse:
90+
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
91+
else:
92+
datafit.initialize(X, y)
8793

8894
if self.ws_strategy == "fixpoint":
8995
X_square = X.multiply(X) if is_sparse else X ** 2

0 commit comments

Comments
 (0)