Skip to content

Commit b80d7fa

Browse files
committed
refactor: rename loss_values to obj_values and remove plotting functionality for large dimensions
1 parent b4b47aa commit b80d7fa

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

rehline/_path_sol.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def plqERM_Ridge_path_sol(
9090
n_iters : list of int
9191
Number of iterations used by the solver at each regularization value.
9292
93-
loss_values : list of float
94-
Final loss values (including regularization term) at each `C`.
93+
obj_values : list of float
94+
Final objective values (including loss and regularization terms) at each `C`.
9595
9696
L2_norms : list of float
9797
L2 norm of the coefficients (excluding bias) at each `C`.
@@ -134,7 +134,7 @@ def plqERM_Ridge_path_sol(
134134
coefs = np.zeros((n_features, n_Cs))
135135
n_iters = []
136136
times = []
137-
loss_values = []
137+
obj_values = []
138138
L2_norms = []
139139

140140

@@ -150,7 +150,8 @@ def plqERM_Ridge_path_sol(
150150

151151
clf = plqERM_Ridge(
152152
loss=loss, constraint=constraint, C=Cs[0],
153-
max_iter=max_iter, tol=tol, shrink=shrink, verbose=verbose,
153+
max_iter=max_iter, tol=tol, shrink=shrink,
154+
verbose=1*(verbose>=2), # ben: if verbose is 1, then the fit function will not show the progress
154155
warm_start=warm_start
155156
)
156157

@@ -177,8 +178,8 @@ def plqERM_Ridge_path_sol(
177178
# Compute loss function parameters for ReHLoss
178179
l2_norm = np.linalg.norm(clf.coef_) ** 2
179180
score = clf.decision_function(X)
180-
total_loss = loss_obj(score) + 0.5*l2_norm
181-
loss_values.append(round(total_loss, 4))
181+
total_obj = loss_obj(score) + 0.5*l2_norm
182+
obj_values.append(round(total_obj, 4))
182183
L2_norms.append(round(np.linalg.norm(clf.coef_), 4))
183184

184185
# if warm_start:
@@ -203,7 +204,7 @@ def plqERM_Ridge_path_sol(
203204
print(f"{'C Value':<15}{'Iterations':<15}{'Time (s)':<20}{'Loss':<20}{'L2 Norm':<20}")
204205
print("-" * 90)
205206

206-
for C, iters, t, loss_val, l2 in zip(Cs, n_iters, times, loss_values, L2_norms):
207+
for C, iters, t, loss_val, l2 in zip(Cs, n_iters, times, obj_values, L2_norms):
207208
if return_time:
208209
print(f"{C:<15.4g}{iters:<15}{t:<20.6f}{loss_val:<20.6f}{l2:<20.6f}")
209210
else:
@@ -214,20 +215,22 @@ def plqERM_Ridge_path_sol(
214215
print(f"{'Avg Time/Iter':<12}{avg_time_per_iter:.6f} sec")
215216
print("=" * 90)
216217

217-
if verbose >= 2:
218-
import matplotlib.pyplot as plt
219-
plt.figure(figsize=(10, 6))
220-
for i in range(n_features):
221-
plt.plot(Cs, coefs[i, :], label=f'Feature {i+1}')
222-
plt.xscale('log')
223-
plt.xlabel('C')
224-
plt.ylabel('Coefficient Value')
225-
plt.title('Regularization Path')
226-
plt.legend()
227-
plt.show()
218+
# ben: remove the plot part, when d is large, the figure will be too large to show
219+
# if verbose >= 2:
220+
# # it's better to load the matplotlib.pyplot before the function
221+
# import matplotlib.pyplot as plt
222+
# plt.figure(figsize=(10, 6))
223+
# for i in range(n_features):
224+
# plt.plot(Cs, coefs[i, :], label=f'Feature {i+1}')
225+
# plt.xscale('log')
226+
# plt.xlabel('C')
227+
# plt.ylabel('Coefficient Value')
228+
# plt.title('Regularization Path')
229+
# plt.legend()
230+
# plt.show()
228231

229232
if return_time:
230-
return Cs, times, n_iters, loss_values, L2_norms, coefs
233+
return Cs, times, n_iters, obj_values, L2_norms, coefs
231234
else:
232-
return Cs, n_iters, loss_values, L2_norms, coefs
235+
return Cs, n_iters, obj_values, L2_norms, coefs
233236

0 commit comments

Comments
 (0)