@@ -38,6 +38,13 @@ def __init__(self, max_iter=50, tol=1e-4, verbose=False):
3838
3939 def _solve (self , X , y , datafit , penalty , w_init = None , Xw_init = None ):
4040
41+ # TODO: to be isolated in a seperated method
42+ is_sparse = issparse (X )
43+ if is_sparse :
44+ datafit .initialize_sparse (X .data , X .indptr , X .indices , y )
45+ else :
46+ datafit .initialize (X , y )
47+
4148 def objective (w ):
4249 Xw = X @ w
4350 datafit_value = datafit .value (y , w , Xw )
@@ -70,8 +77,7 @@ def callback_post_iter(w_k):
7077
7178 it = len (p_objs_out )
7279 print (
73- f"Iteration { it } : { p_obj :.10f} , "
74- f"stopping crit: { stop_crit :.2e} "
80+ f"Iteration { it } : { p_obj :.10f} , " f"stopping crit: { stop_crit :.2e} "
7581 )
7682
7783 n_features = X .shape [1 ]
@@ -87,7 +93,7 @@ def callback_post_iter(w_k):
8793 options = dict (
8894 maxiter = self .max_iter ,
8995 gtol = self .tol ,
90- ftol = 0. # set ftol=0. to control convergence using only gtol
96+ ftol = 0.0 , # set ftol=0. to control convergence using only gtol
9197 ),
9298 callback = callback_post_iter ,
9399 )
@@ -97,7 +103,7 @@ def callback_post_iter(w_k):
97103 f"`LBFGS` did not converge for tol={ self .tol :.3e} "
98104 f"and max_iter={ self .max_iter } .\n "
99105 "Consider increasing `max_iter` and/or `tol`." ,
100- category = ConvergenceWarning
106+ category = ConvergenceWarning ,
101107 )
102108
103109 w = result .x
@@ -110,7 +116,8 @@ def callback_post_iter(w_k):
110116 def custom_checks (self , X , y , datafit , penalty ):
111117 # check datafit support sparse data
112118 check_attrs (
113- datafit , solver = self ,
119+ datafit ,
120+ solver = self ,
114121 required_attr = self ._datafit_required_attr ,
115- support_sparse = issparse (X )
122+ support_sparse = issparse (X ),
116123 )
0 commit comments