Commit 77f2b1a

add warmstart
1 parent 48ae8b4 commit 77f2b1a

4 files changed: 213 additions and 57 deletions


rehline/_base.py

Lines changed: 66 additions & 0 deletions
@@ -73,6 +73,62 @@ def auto_shape(self):
         self.H = self.S.shape[0]
         self.K = self.A.shape[0]
 
+    def cast_sample_weight(self, sample_weight=None):
+        """
+        Cast the sample weight to the ReHLine parameters.
+
+        Parameters
+        ----------
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights. If None, then samples are equally weighted.
+
+        Returns
+        -------
+        U_weight : array-like of shape (L, n_samples)
+            Weighted ReLU coefficient matrix.
+
+        V_weight : array-like of shape (L, n_samples)
+            Weighted ReLU intercept vector.
+
+        Tau_weight : array-like of shape (H, n_samples)
+            Weighted ReHU cutpoint matrix.
+
+        S_weight : array-like of shape (H, n_samples)
+            Weighted ReHU coefficient vector.
+
+        T_weight : array-like of shape (H, n_samples)
+            Weighted ReHU intercept vector.
+
+        Notes
+        -----
+        This method casts the sample weight to the ReHLine parameters by multiplying
+        the sample weight with the ReLU and ReHU parameters. If sample_weight is None,
+        then the sample weight is set to the weight parameter C.
+        """
+
+        self.auto_shape()
+
+        sample_weight = self.C*sample_weight
+
+        if self.L > 0:
+            U_weight = self.U * sample_weight
+            V_weight = self.V * sample_weight
+        else:
+            U_weight = self.U
+            V_weight = self.V
+
+        if self.H > 0:
+            sqrt_sample_weight = np.sqrt(sample_weight)
+            Tau_weight = self.Tau * sqrt_sample_weight
+            S_weight = self.S * sqrt_sample_weight
+            T_weight = self.T * sqrt_sample_weight
+        else:
+            Tau_weight = self.Tau
+            S_weight = self.S
+            T_weight = self.T
+
+        return U_weight, V_weight, Tau_weight, S_weight, T_weight
+
     def call_ReLHLoss(self, score):
         """
         Return the value of the ReHLine loss of the `score`.
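
For orientation, a minimal sketch of what the new helper computes. It assumes the ReHLine estimator from rehline/_class.py below (which inherits cast_sample_weight from _BaseReHLine), a hinge-style ReLU parameterization (U = -y, V = 1), and toy data; the assertions simply restate the scaling rule in the code above: ReLU parameters are multiplied by C * sample_weight, ReHU parameters by its square root.

import numpy as np
from rehline import ReHLine  # inherits cast_sample_weight from _BaseReHLine

rng = np.random.RandomState(0)
n, d, C = 6, 3, 0.5
X, y = rng.randn(n, d), np.sign(rng.randn(n))

# hinge loss written in ReLU form: relu(-y_i * x_i^T beta + 1)
clf = ReHLine(C=C, U=-y.reshape(1, -1), V=np.ones((1, n)))

w = np.linspace(1.0, 2.0, n)  # illustrative per-sample weights
U_w, V_w, Tau_w, S_w, T_w = clf.cast_sample_weight(sample_weight=w)

# ReLU parts are scaled by C * w; the (here empty) ReHU parts would use sqrt(C * w)
assert np.allclose(U_w, clf.U * (C * w))
assert np.allclose(V_w, clf.V * (C * w))
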
@@ -172,12 +228,22 @@ def _check_rehu(rehu_coef, rehu_intercept, rehu_cut):
     if len(rehu_coef) > 0:
         assert (rehu_cut >= 0.0).all(), "`rehu_cut` must be non-negative!"
 
+
 def ReHLine_solver(X, U, V,
                    Tau=np.empty(shape=(0, 0)),
                    S=np.empty(shape=(0, 0)), T=np.empty(shape=(0, 0)),
                    A=np.empty(shape=(0, 0)), b=np.empty(shape=(0)),
+                   Lambda=np.empty(shape=(0, 0)),
+                   Gamma=np.empty(shape=(0, 0)),
+                   xi=np.empty(shape=(0, 0)),
                    max_iter=1000, tol=1e-4, shrink=1, verbose=1, trace_freq=100):
     result = rehline_result()
+    if len(Lambda)>0:
+        result.Lambda = Lambda
+    if len(Gamma)>0:
+        result.Gamma = Gamma
+    if len(xi)>0:
+        result.xi = xi
     rehline_internal(result, X, A, b, U, V, S, T, Tau, max_iter, tol, shrink, verbose, trace_freq)
     return result
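
The new Lambda/Gamma/xi arguments let a caller seed the solver with dual variables from an earlier run. A hedged sketch of that workflow, assuming ReHLine_solver is importable from rehline._base and that the result object exposes Lambda, Gamma and xi after a solve (as the estimator code in rehline/_class.py below relies on):

import numpy as np
from rehline._base import ReHLine_solver  # assumed import path for the low-level solver

rng = np.random.RandomState(0)
n, d, C = 200, 5, 1.0
X, y = rng.randn(n, d), np.sign(rng.randn(n))
U, V = -(C * y).reshape(1, -1), C * np.ones((1, n))   # hinge loss in ReLU form

res0 = ReHLine_solver(X, U, V, max_iter=50, verbose=0)       # cold run, small budget
res1 = ReHLine_solver(X, U, V,
                      Lambda=res0.Lambda, Gamma=res0.Gamma, xi=res0.xi,
                      max_iter=1000, verbose=0)               # warm-started run
print(res0.niter, res1.niter)
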

rehline/_class.py

Lines changed: 73 additions & 52 deletions
@@ -48,13 +48,24 @@ class ReHLine(_BaseReHLine, BaseEstimator):
         The intercept vector in the linear constraint.
 
     verbose : int, default=0
-        Enable verbose output. Note that this setting takes advantage of a
-        per-process runtime setting in liblinear that, if enabled, may not work
-        properly in a multithreaded context.
+        Enable verbose output.
 
     max_iter : int, default=1000
         The maximum number of iterations to be run.
 
+    tol : float, default=1e-4
+        The tolerance for the stopping criterion.
+
+    shrink : float, default=1
+        The shrinkage of dual variables for the ReHLine algorithm.
+
+    warm_start : bool, default=False
+        Whether to use the given dual params as an initial guess for the
+        optimization algorithm.
+
+    trace_freq : int, default=100
+        The frequency at which to print the optimization trace.
+
     Attributes
     ----------
     coef\_ : array-like
@@ -72,6 +83,15 @@ class ReHLine(_BaseReHLine, BaseEstimator):
     primal_obj\_ : array-like
         The primal objective function values.
 
+    Lambda: array-like
+        The optimized dual variables for ReLU parts.
+
+    Gamma: array-like
+        The optimized dual variables for ReHU parts.
+
+    xi: array-like
+        The optimized dual variables for linear constraints.
+
     Examples
     --------
 
@@ -109,7 +129,8 @@ def __init__(self, C=1.,
                  Tau=np.empty(shape=(0,0)),
                  S=np.empty(shape=(0,0)), T=np.empty(shape=(0,0)),
                  A=np.empty(shape=(0,0)), b=np.empty(shape=(0)),
-                 max_iter=1000, tol=1e-4, shrink=1, verbose=0, trace_freq=100):
+                 max_iter=1000, tol=1e-4, shrink=1, warm_start=0,
+                 verbose=0, trace_freq=100):
         self.C = C
         self.U = U
         self.V = V
@@ -124,8 +145,13 @@ def __init__(self, C=1.,
         self.max_iter = max_iter
         self.tol = tol
         self.shrink = shrink
+        self.warm_start = warm_start
         self.verbose = verbose
         self.trace_freq = trace_freq
+        self.Lambda = np.empty(shape=(0, 0))
+        self.Gamma = np.empty(shape=(0, 0))
+        self.xi = np.empty(shape=(0, 0))
+        self.coef_ = None
 
     def fit(self, X, sample_weight=None):
         """Fit the model based on the given training data.
@@ -147,41 +173,34 @@ def fit(self, X, sample_weight=None):
             An instance of the estimator.
         """
         # X = check_array(X)
+        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
 
-
-        if sample_weight is None:
-            sample_weight = self.C
-        else:
-            sample_weight = self.C*_check_sample_weight(sample_weight, X, dtype=X.dtype)
-
-        if self.L > 0:
-            U_weight = self.U * sample_weight
-            V_weight = self.V * sample_weight
-        else:
-            U_weight = self.U
-            V_weight = self.V
-
-        if self.H > 0:
-            sqrt_sample_weight = np.sqrt(sample_weight)
-            Tau_weight = self.Tau * sqrt_sample_weight
-            S_weight = self.S * sqrt_sample_weight
-            T_weight = self.T * sqrt_sample_weight
-        else:
-            Tau_weight = self.Tau
-            S_weight = self.S
-            T_weight = self.T
+        U_weight, V_weight, Tau_weight, S_weight, T_weight = self.cast_sample_weight(sample_weight=sample_weight)
+
+        if not self.warm_start:
+            ## remove warm_start params
+            self.Lambda = np.empty(shape=(0, 0))
+            self.Gamma = np.empty(shape=(0, 0))
+            self.xi = np.empty(shape=(0, 0))
 
         result = ReHLine_solver(X=X,
                                 U=U_weight, V=V_weight,
                                 Tau=Tau_weight,
                                 S=S_weight, T=T_weight,
                                 A=self.A, b=self.b,
+                                Lambda=self.Lambda, Gamma=self.Gamma, xi=self.xi,
                                 max_iter=self.max_iter, tol=self.tol,
                                 shrink=self.shrink, verbose=self.verbose,
                                 trace_freq=self.trace_freq)
 
-        self.coef_ = result.beta
         self.opt_result_ = result
+        # primal solution
+        self.coef_ = result.beta
+        # dual solution
+        self.Lambda = result.Lambda
+        self.Gamma = result.Gamma
+        self.xi = result.xi
+        # algo convergence
         self.n_iter_ = result.niter
         self.dual_obj_ = result.dual_objfns
         self.primal_obj_ = result.primal_objfns
@@ -191,6 +210,7 @@ def fit(self, X, sample_weight=None):
                 "ReHLine failed to converge, increase the number of iterations: `max_iter`.",
                 ConvergenceWarning,
             )
+        return self
 
     def decision_function(self, X):
         """The decision function evaluated on the given dataset
@@ -304,7 +324,8 @@ def __init__(self, loss,
                  Tau=np.empty(shape=(0,0)),
                  S=np.empty(shape=(0,0)), T=np.empty(shape=(0,0)),
                  A=np.empty(shape=(0,0)), b=np.empty(shape=(0)),
-                 max_iter=1000, tol=1e-4, shrink=1, verbose=0, trace_freq=100):
+                 max_iter=1000, tol=1e-4, shrink=1, warm_start=0,
+                 verbose=0, trace_freq=100):
         self.loss = loss
         self.constraint = constraint
         self.C = C
@@ -321,9 +342,13 @@ def __init__(self, loss,
         self.max_iter = max_iter
         self.tol = tol
         self.shrink = shrink
+        self.warm_start = warm_start
         self.verbose = verbose
         self.trace_freq = trace_freq
-        self.dummy_n = 0
+        self.Lambda = np.empty(shape=(0, 0))
+        self.Gamma = np.empty(shape=(0, 0))
+        self.xi = np.empty(shape=(0, 0))
+        self.coef_ = None
 
     def fit(self, X, y, sample_weight=None):
         """Fit the model based on the given training data.
@@ -358,40 +383,34 @@ def fit(self, X, y, sample_weight=None):
         self.A, self.b = _make_constraint_rehline_param(constraint=self.constraint, X=X, y=y)
         self.auto_shape()
 
-        ## sample weight -> rehline params
-        if sample_weight is None:
-            sample_weight = self.C
-        else:
-            sample_weight = self.C*_check_sample_weight(sample_weight, X, dtype=X.dtype)
-
-        if self.L > 0:
-            U_weight = self.U * sample_weight
-            V_weight = self.V * sample_weight
-        else:
-            U_weight = self.U
-            V_weight = self.V
-
-        if self.H > 0:
-            sqrt_sample_weight = np.sqrt(sample_weight)
-            Tau_weight = self.Tau * sqrt_sample_weight
-            S_weight = self.S * sqrt_sample_weight
-            T_weight = self.T * sqrt_sample_weight
-        else:
-            Tau_weight = self.Tau
-            S_weight = self.S
-            T_weight = self.T
+        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
+
+        U_weight, V_weight, Tau_weight, S_weight, T_weight = self.cast_sample_weight(sample_weight=sample_weight)
+
+        if not self.warm_start:
+            ## remove warm_start params
+            self.Lambda = np.empty(shape=(0, 0))
+            self.Gamma = np.empty(shape=(0, 0))
+            self.xi = np.empty(shape=(0, 0))
 
         result = ReHLine_solver(X=X,
                                 U=U_weight, V=V_weight,
                                 Tau=Tau_weight,
                                 S=S_weight, T=T_weight,
                                 A=self.A, b=self.b,
+                                Lambda=self.Lambda, Gamma=self.Gamma, xi=self.xi,
                                 max_iter=self.max_iter, tol=self.tol,
                                 shrink=self.shrink, verbose=self.verbose,
                                 trace_freq=self.trace_freq)
 
-        self.coef_ = result.beta
         self.opt_result_ = result
+        # primal solution
+        self.coef_ = result.beta
+        # dual solution
+        self.Lambda = result.Lambda
+        self.Gamma = result.Gamma
+        self.xi = result.xi
+        # algo convergence
         self.n_iter_ = result.niter
         self.dual_obj_ = result.dual_objfns
         self.primal_obj_ = result.primal_objfns
@@ -402,6 +421,8 @@ def fit(self, X, y, sample_weight=None):
                 ConvergenceWarning,
             )
 
+        return self
+
     def decision_function(self, X):
         """The decision function evaluated on the given dataset
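
At the estimator level the same mechanism is driven by the new warm_start flag: when it is set, fit() keeps the Lambda/Gamma/xi left over from the previous call and hands them to the solver instead of clearing them. A hedged usage sketch with a hinge-style ReLU parameterization and toy data (the setup below is illustrative, not part of this commit):

import numpy as np
from rehline import ReHLine

rng = np.random.RandomState(0)
n, d = 500, 10
X, y = rng.randn(n, d), np.sign(rng.randn(n))

clf = ReHLine(C=0.5, U=-y.reshape(1, -1), V=np.ones((1, n)),
              warm_start=True, max_iter=10000, verbose=0)

clf.fit(X)          # first call: cold start, stored duals are empty
print(clf.n_iter_)

clf.fit(X)          # refit: restarts from the stored Lambda/Gamma/xi
print(clf.n_iter_)  # typically far fewer iterations than the cold start

With warm_start=False (the default) the stored duals are cleared before every solve, reproducing the previous cold-start behaviour; and since fit() now returns self, calls such as ReHLine(...).fit(X).decision_function(X_new) chain naturally.
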

tests/_test_svm.py

Lines changed: 2 additions & 5 deletions
@@ -1,5 +1,6 @@
 ## Test SVM on simulated dataset
 import numpy as np
+
 from rehline import ReHLine
 
 np.random.seed(1024)
@@ -11,18 +12,14 @@
 
 ## solution provided by sklearn
 from sklearn.svm import LinearSVC
+
 clf = LinearSVC(C=C, loss='hinge', fit_intercept=False,
                 random_state=0, tol=1e-6, max_iter=1000000)
 clf.fit(X, y)
 sol = clf.coef_.flatten()
 
 print('solution privided by liblinear: %s' %sol)
 
-## solution provided by ReHLine
-# build-in loss
-clf = ReHLine(loss={'name': 'svm'}, C=C)
-clf.make_ReLHLoss(X=X, y=y, loss={'name': 'svm'})
-clf.fit(X=X)
 
 print('solution privided by rehline: %s' %clf.coef_)
 print(clf.decision_function([[.1,.2,.3]]))
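
For reference, a hedged sketch of the rehline side of this comparison using the U/V-based ReHLine constructor shown above; X, y and C are the quantities already defined in this script, and the hinge parameterization (U = -y, V = 1) is an assumption, not part of this commit.

# hinge/SVM loss written directly in ReLU form (assumed parameterization)
clf_rehline = ReHLine(C=C, U=-y.reshape(1, -1), V=np.ones((1, len(y))),
                      tol=1e-6, max_iter=100000)
clf_rehline.fit(X)
print('solution provided by rehline: %s' % clf_rehline.coef_)
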
