@@ -48,13 +48,24 @@ class ReHLine(_BaseReHLine, BaseEstimator):
4848 The intercept vector in the linear constraint.
4949
5050 verbose : int, default=0
51- Enable verbose output. Note that this setting takes advantage of a
52- per-process runtime setting in liblinear that, if enabled, may not work
53- properly in a multithreaded context.
51+ Enable verbose output.
5452
5553 max_iter : int, default=1000
5654 The maximum number of iterations to be run.
5755
56+ tol : float, default=1e-4
57+ The tolerance for the stopping criterion.
58+
59+ shrink : float, default=1
60+ The shrinkage of dual variables for the ReHLine algorithm.
61+
62+ warm_start : bool, default=False
63+ Whether to use the given dual params as an initial guess for the
64+ optimization algorithm.
65+
66+ trace_freq : int, default=100
67+ The frequency at which to print the optimization trace.
68+
5869 Attributes
5970 ----------
6071 coef\_ : array-like
@@ -72,6 +83,15 @@ class ReHLine(_BaseReHLine, BaseEstimator):
7283 primal_obj\_ : array-like
7384 The primal objective function values.
7485
86+ Lambda: array-like
87+ The optimized dual variables for ReLU parts.
88+
89+ Gamma: array-like
90+ The optimized dual variables for ReHU parts.
91+
92+ xi: array-like
93+ The optimized dual variables for linear constraints.
94+
7595 Examples
7696 --------
7797
@@ -109,7 +129,8 @@ def __init__(self, C=1.,
109129 Tau = np .empty (shape = (0 ,0 )),
110130 S = np .empty (shape = (0 ,0 )), T = np .empty (shape = (0 ,0 )),
111131 A = np .empty (shape = (0 ,0 )), b = np .empty (shape = (0 )),
112- max_iter = 1000 , tol = 1e-4 , shrink = 1 , verbose = 0 , trace_freq = 100 ):
132+ max_iter = 1000 , tol = 1e-4 , shrink = 1 , warm_start = 0 ,
133+ verbose = 0 , trace_freq = 100 ):
113134 self .C = C
114135 self .U = U
115136 self .V = V
@@ -124,8 +145,13 @@ def __init__(self, C=1.,
124145 self .max_iter = max_iter
125146 self .tol = tol
126147 self .shrink = shrink
148+ self .warm_start = warm_start
127149 self .verbose = verbose
128150 self .trace_freq = trace_freq
151+ self .Lambda = np .empty (shape = (0 , 0 ))
152+ self .Gamma = np .empty (shape = (0 , 0 ))
153+ self .xi = np .empty (shape = (0 , 0 ))
154+ self .coef_ = None
129155
130156 def fit (self , X , sample_weight = None ):
131157 """Fit the model based on the given training data.
@@ -147,41 +173,34 @@ def fit(self, X, sample_weight=None):
147173 An instance of the estimator.
148174 """
149175 # X = check_array(X)
176+ sample_weight = _check_sample_weight (sample_weight , X , dtype = X .dtype )
150177
151-
152- if sample_weight is None :
153- sample_weight = self .C
154- else :
155- sample_weight = self .C * _check_sample_weight (sample_weight , X , dtype = X .dtype )
156-
157- if self .L > 0 :
158- U_weight = self .U * sample_weight
159- V_weight = self .V * sample_weight
160- else :
161- U_weight = self .U
162- V_weight = self .V
163-
164- if self .H > 0 :
165- sqrt_sample_weight = np .sqrt (sample_weight )
166- Tau_weight = self .Tau * sqrt_sample_weight
167- S_weight = self .S * sqrt_sample_weight
168- T_weight = self .T * sqrt_sample_weight
169- else :
170- Tau_weight = self .Tau
171- S_weight = self .S
172- T_weight = self .T
178+ U_weight , V_weight , Tau_weight , S_weight , T_weight = self .cast_sample_weight (sample_weight = sample_weight )
179+
180+ if not self .warm_start :
181+ ## remove warm_start params
182+ self .Lambda = np .empty (shape = (0 , 0 ))
183+ self .Gamma = np .empty (shape = (0 , 0 ))
184+ self .xi = np .empty (shape = (0 , 0 ))
173185
174186 result = ReHLine_solver (X = X ,
175187 U = U_weight , V = V_weight ,
176188 Tau = Tau_weight ,
177189 S = S_weight , T = T_weight ,
178190 A = self .A , b = self .b ,
191+ Lambda = self .Lambda , Gamma = self .Gamma , xi = self .xi ,
179192 max_iter = self .max_iter , tol = self .tol ,
180193 shrink = self .shrink , verbose = self .verbose ,
181194 trace_freq = self .trace_freq )
182195
183- self .coef_ = result .beta
184196 self .opt_result_ = result
197+ # primal solution
198+ self .coef_ = result .beta
199+ # dual solution
200+ self .Lambda = result .Lambda
201+ self .Gamma = result .Gamma
202+ self .xi = result .xi
203+ # algo convergence
185204 self .n_iter_ = result .niter
186205 self .dual_obj_ = result .dual_objfns
187206 self .primal_obj_ = result .primal_objfns
@@ -191,6 +210,7 @@ def fit(self, X, sample_weight=None):
191210 "ReHLine failed to converge, increase the number of iterations: `max_iter`." ,
192211 ConvergenceWarning ,
193212 )
213+ return self
194214
195215 def decision_function (self , X ):
196216 """The decision function evaluated on the given dataset
@@ -304,7 +324,8 @@ def __init__(self, loss,
304324 Tau = np .empty (shape = (0 ,0 )),
305325 S = np .empty (shape = (0 ,0 )), T = np .empty (shape = (0 ,0 )),
306326 A = np .empty (shape = (0 ,0 )), b = np .empty (shape = (0 )),
307- max_iter = 1000 , tol = 1e-4 , shrink = 1 , verbose = 0 , trace_freq = 100 ):
327+ max_iter = 1000 , tol = 1e-4 , shrink = 1 , warm_start = 0 ,
328+ verbose = 0 , trace_freq = 100 ):
308329 self .loss = loss
309330 self .constraint = constraint
310331 self .C = C
@@ -321,9 +342,13 @@ def __init__(self, loss,
321342 self .max_iter = max_iter
322343 self .tol = tol
323344 self .shrink = shrink
345+ self .warm_start = warm_start
324346 self .verbose = verbose
325347 self .trace_freq = trace_freq
326- self .dummy_n = 0
348+ self .Lambda = np .empty (shape = (0 , 0 ))
349+ self .Gamma = np .empty (shape = (0 , 0 ))
350+ self .xi = np .empty (shape = (0 , 0 ))
351+ self .coef_ = None
327352
328353 def fit (self , X , y , sample_weight = None ):
329354 """Fit the model based on the given training data.
@@ -358,40 +383,34 @@ def fit(self, X, y, sample_weight=None):
358383 self .A , self .b = _make_constraint_rehline_param (constraint = self .constraint , X = X , y = y )
359384 self .auto_shape ()
360385
361- ## sample weight -> rehline params
362- if sample_weight is None :
363- sample_weight = self .C
364- else :
365- sample_weight = self .C * _check_sample_weight (sample_weight , X , dtype = X .dtype )
366-
367- if self .L > 0 :
368- U_weight = self .U * sample_weight
369- V_weight = self .V * sample_weight
370- else :
371- U_weight = self .U
372- V_weight = self .V
373-
374- if self .H > 0 :
375- sqrt_sample_weight = np .sqrt (sample_weight )
376- Tau_weight = self .Tau * sqrt_sample_weight
377- S_weight = self .S * sqrt_sample_weight
378- T_weight = self .T * sqrt_sample_weight
379- else :
380- Tau_weight = self .Tau
381- S_weight = self .S
382- T_weight = self .T
386+ sample_weight = _check_sample_weight (sample_weight , X , dtype = X .dtype )
387+
388+ U_weight , V_weight , Tau_weight , S_weight , T_weight = self .cast_sample_weight (sample_weight = sample_weight )
389+
390+ if not self .warm_start :
391+ ## remove warm_start params
392+ self .Lambda = np .empty (shape = (0 , 0 ))
393+ self .Gamma = np .empty (shape = (0 , 0 ))
394+ self .xi = np .empty (shape = (0 , 0 ))
383395
384396 result = ReHLine_solver (X = X ,
385397 U = U_weight , V = V_weight ,
386398 Tau = Tau_weight ,
387399 S = S_weight , T = T_weight ,
388400 A = self .A , b = self .b ,
401+ Lambda = self .Lambda , Gamma = self .Gamma , xi = self .xi ,
389402 max_iter = self .max_iter , tol = self .tol ,
390403 shrink = self .shrink , verbose = self .verbose ,
391404 trace_freq = self .trace_freq )
392405
393- self .coef_ = result .beta
394406 self .opt_result_ = result
407+ # primal solution
408+ self .coef_ = result .beta
409+ # dual solution
410+ self .Lambda = result .Lambda
411+ self .Gamma = result .Gamma
412+ self .xi = result .xi
413+ # algo convergence
395414 self .n_iter_ = result .niter
396415 self .dual_obj_ = result .dual_objfns
397416 self .primal_obj_ = result .primal_objfns
@@ -402,6 +421,8 @@ def fit(self, X, y, sample_weight=None):
402421 ConvergenceWarning ,
403422 )
404423
424+ return self
425+
405426 def decision_function (self , X ):
406427 """The decision function evaluated on the given dataset
407428
0 commit comments