Skip to content

Commit 4030550

Browse files
committed
tests path sol
1 parent a2cbadc commit 4030550

File tree

3 files changed

+100
-20
lines changed

3 files changed

+100
-20
lines changed

rehline/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# Import from internal C++ module
22
from ._base import (ReHLine_solver, _BaseReHLine,
33
_make_constraint_rehline_param, _make_loss_rehline_param)
4-
from ._class import ReHLine, plqERM_Ridge, CQR_Ridge
4+
from ._class import CQR_Ridge, ReHLine, plqERM_Ridge
55
from ._data import make_fair_classification
66
from ._internal import rehline_internal, rehline_result
7+
from ._path_sol import plqERM_Ridge_path_sol
78

89
__all__ = ("ReHLine_solver",
910
"_BaseReHLine",
1011
"ReHLine",
1112
"plqERM_Ridge",
1213
"CQR_Ridge",
14+
"plqERM_Ridge_path_sol",
1315
"_make_loss_rehline_param",
1416
"_make_constraint_rehline_param"
1517
"make_fair_classification")

rehline/_path_sol.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
import numpy as np
21
import time
2+
33
import matplotlib.pyplot as plt
4-
from rehline import plqERM_Ridge
5-
from rehline import _make_loss_rehline_param
4+
import numpy as np
5+
6+
from ._base import _make_loss_rehline_param
7+
from ._class import plqERM_Ridge
68
from ._loss import ReHLoss
79

810

@@ -11,7 +13,7 @@ def plqERM_Ridge_path_sol(
1113
y,
1214
*,
1315
loss,
14-
constraint=[ ],
16+
constraint=[],
1517
eps=1e-3,
1618
n_Cs=100,
1719
Cs=None,
@@ -139,35 +141,47 @@ def plqERM_Ridge_path_sol(
139141
U, V, Tau, S, T = _make_loss_rehline_param(loss, X, y)
140142
loss_obj = ReHLoss(U, V, S, T, Tau)
141143

144+
# Lambda_ws = np.empty(shape=(0, 0))
145+
# Gamma_ws = np.empty(shape=(0, 0))
146+
# xi_ws = np.empty(shape=(0, 0))
147+
148+
clf = plqERM_Ridge(
149+
loss=loss, constraint=constraint, C=Cs[0],
150+
max_iter=max_iter, tol=tol, shrink=shrink, verbose=verbose,
151+
warm_start=warm_start
152+
)
153+
142154
for i, C in enumerate(Cs):
143155
if return_time:
144156
start_time = time.time()
145157

146-
clf = plqERM_Ridge(
147-
loss=loss, constraint=constraint, C=C,
148-
max_iter=max_iter, tol=tol, shrink=shrink, verbose=verbose,
149-
warm_start=warm_start
150-
)
158+
clf.C = C
159+
160+
# clf = plqERM_Ridge(
161+
# loss=loss, constraint=constraint, C=C,
162+
# max_iter=max_iter, tol=tol, shrink=shrink, verbose=verbose,
163+
# warm_start=warm_start
164+
# )
151165

152-
if warm_start and (i>0):
153-
clf.Lambda = Lambda
154-
clf.Gamma = Gamma
155-
clf.xi = xi
166+
# if (warm_start and (i>0)):
167+
# clf.Lambda = Lambda_ws
168+
# clf.Gamma = Gamma_ws
169+
# clf.xi = xi_ws
156170

157171
clf.fit(X, y)
158172
coefs[:, i] = clf.coef_
159173

160174
# Compute loss function parameters for ReHLoss
161-
l2_norm = 0.5 * np.linalg.norm(clf.coef_) ** 2
175+
l2_norm = np.linalg.norm(clf.coef_) ** 2
162176
score = clf.decision_function(X)
163-
total_loss = loss_obj(score) + l2_norm
177+
total_loss = loss_obj(score) + 0.5*l2_norm
164178
loss_values.append(round(total_loss, 4))
165179
L2_norms.append(round(np.linalg.norm(clf.coef_), 4))
166180

167-
if warm_start:
168-
Lambda = clf.Lambda
169-
Gamma = clf.Gamma
170-
xi = clf.xi
181+
# if warm_start:
182+
# Lambda_ws = clf.Lambda
183+
# Gamma_ws = clf.Gamma
184+
# xi_ws = clf.xi
171185

172186
if return_time:
173187
elapsed_time = time.time() - start_time

tests/_test_path_sol.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import time
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
## load datasets
6+
from sklearn.datasets import make_hastie_10_2
7+
8+
from rehline import plqERM_Ridge_path_sol
9+
10+
X, y = make_hastie_10_2()
11+
12+
# define loss function
13+
loss = {'name': 'svm'}
14+
Cs = np.logspace(-20, 10, 50, base=2) # Define a range of C values for the path
15+
16+
# calculate with warm_start=False
17+
# Cs_values_cold, times_cold, n_iters_cold, loss_values_cold, L2_norms_cold, coefs_cold = plqERM_Ridge_path_sol(
18+
# X, y,
19+
# loss=loss,
20+
# Cs=Cs,
21+
# max_iter=5000000,
22+
# tol=1e-4,
23+
# verbose=0,
24+
# warm_start=False,
25+
# constraint=[],
26+
# return_time=True,
27+
# plot_path=True
28+
# )
29+
30+
# calculate with warm_start=True
31+
Cs_values_warm, times_warm, n_iters_warm, loss_values_warm, L2_norms_warm, coefs_warm = plqERM_Ridge_path_sol(
32+
X, y,
33+
loss=loss,
34+
Cs=Cs,
35+
max_iter=1000000,
36+
tol=1e-4,
37+
verbose=1,
38+
warm_start=True,
39+
constraint=[],
40+
return_time=True,
41+
plot_path=True
42+
)
43+
44+
45+
# # Plot Cs vs times comparison
46+
# plt.figure(figsize=(10, 6))
47+
# plt.plot(Cs, times_warm, 'o-', label='Warm Start')
48+
# plt.plot(Cs, times_cold, 's-', label='Cold Start')
49+
# plt.xscale('log', base=2)
50+
# plt.xlabel('C values')
51+
# plt.ylabel('Time (seconds)')
52+
# plt.title('Computation Time vs. C Parameter')
53+
# plt.legend()
54+
# plt.grid(True)
55+
# plt.show()
56+
57+
58+
# # Print table comparing number of iterations
59+
# print("\nComparison of Number of Iterations:")
60+
# print("-" * 50)
61+
# print(f"{'C Value':^15} | {'Cold Start (iterations)':^20} | {'Warm Start (iterations)':^20}")
62+
# print("-" * 50)
63+
# for i, C in enumerate(Cs):
64+
# print(f"{C:^15.4f} | {n_iters_cold[i]:^20d} | {n_iters_warm[i]:^20d}")

0 commit comments

Comments
 (0)