1- import numpy as np
21import time
2+
33import matplotlib .pyplot as plt
4- from rehline import plqERM_Ridge
5- from rehline import _make_loss_rehline_param
4+ import numpy as np
5+
6+ from ._base import _make_loss_rehline_param
7+ from ._class import plqERM_Ridge
68from ._loss import ReHLoss
79
810
@@ -11,7 +13,7 @@ def plqERM_Ridge_path_sol(
1113 y ,
1214 * ,
1315 loss ,
14- constraint = [ ],
16+ constraint = [],
1517 eps = 1e-3 ,
1618 n_Cs = 100 ,
1719 Cs = None ,
@@ -139,35 +141,47 @@ def plqERM_Ridge_path_sol(
139141 U , V , Tau , S , T = _make_loss_rehline_param (loss , X , y )
140142 loss_obj = ReHLoss (U , V , S , T , Tau )
141143
144+ # Lambda_ws = np.empty(shape=(0, 0))
145+ # Gamma_ws = np.empty(shape=(0, 0))
146+ # xi_ws = np.empty(shape=(0, 0))
147+
148+ clf = plqERM_Ridge (
149+ loss = loss , constraint = constraint , C = Cs [0 ],
150+ max_iter = max_iter , tol = tol , shrink = shrink , verbose = verbose ,
151+ warm_start = warm_start
152+ )
153+
142154 for i , C in enumerate (Cs ):
143155 if return_time :
144156 start_time = time .time ()
145157
146- clf = plqERM_Ridge (
147- loss = loss , constraint = constraint , C = C ,
148- max_iter = max_iter , tol = tol , shrink = shrink , verbose = verbose ,
149- warm_start = warm_start
150- )
158+ clf .C = C
159+
160+ # clf = plqERM_Ridge(
161+ # loss=loss, constraint=constraint, C=C,
162+ # max_iter=max_iter, tol=tol, shrink=shrink, verbose=verbose,
163+ # warm_start=warm_start
164+ # )
151165
152- if warm_start and (i > 0 ):
153- clf .Lambda = Lambda
154- clf .Gamma = Gamma
155- clf .xi = xi
166+ # if ( warm_start and (i>0) ):
167+ # clf.Lambda = Lambda_ws
168+ # clf.Gamma = Gamma_ws
169+ # clf.xi = xi_ws
156170
157171 clf .fit (X , y )
158172 coefs [:, i ] = clf .coef_
159173
160174 # Compute loss function parameters for ReHLoss
161- l2_norm = 0.5 * np .linalg .norm (clf .coef_ ) ** 2
175+ l2_norm = np .linalg .norm (clf .coef_ ) ** 2
162176 score = clf .decision_function (X )
163- total_loss = loss_obj (score ) + l2_norm
177+ total_loss = loss_obj (score ) + 0.5 * l2_norm
164178 loss_values .append (round (total_loss , 4 ))
165179 L2_norms .append (round (np .linalg .norm (clf .coef_ ), 4 ))
166180
167- if warm_start :
168- Lambda = clf .Lambda
169- Gamma = clf .Gamma
170- xi = clf .xi
181+ # if warm_start:
182+ # Lambda_ws = clf.Lambda
183+ # Gamma_ws = clf.Gamma
184+ # xi_ws = clf.xi
171185
172186 if return_time :
173187 elapsed_time = time .time () - start_time
0 commit comments