Commit f59c9cb

fix stepsizes & add comment

1 parent d25d9fd

2 files changed (+76, -42 lines)

debug.py

Lines changed: 41 additions & 20 deletions

@@ -1,3 +1,4 @@
+# %%
 import numpy as np
 from skglm import GeneralizedLinearEstimator
 from skglm.experimental.pdcd_ws import PDCD_WS
@@ -6,44 +7,64 @@
 from sklearn.datasets import make_regression
 from sklearn.preprocessing import StandardScaler
 from skglm.utils.jit_compilation import compiled_clone
+from sklearn.linear_model import QuantileRegressor
 
 
 def generate_dummy_data(n_samples=1000, n_features=10, noise=0.1):
     X, y = make_regression(n_samples=n_samples, n_features=n_features, noise=noise)
+    # y -= y.mean()
+    # y += 0.1
+    y /= 10
     return X, y
 
 
 np.random.seed(42)
 
-datafit = Pinball(0.5)
-penalty = L1(alpha=0.1)
+quantile_level = 0.5
+alpha = 0.1
+
+X, y = generate_dummy_data(
+    n_samples=1000,  # if this is reduced to 100 samples, it converges
+    n_features=11,
+)
+
 solver = PDCD_WS(
+    p0=11,
     max_iter=50,
     max_epochs=500,
-    tol=1e-2,
+    tol=1e-5,
     warm_start=False,
-    verbose=1,
+    verbose=2,
 )
 
-# estimator = GeneralizedLinearEstimator(
-#     datafit=datafit,
-#     penalty=penalty,
-#     solver=solver,
-# )
-
-X, y = generate_dummy_data(
-    n_samples=1000,  # if this is reduced to 100 samples, it converges
-    n_features=11,
-)
-# y -= y.mean()
-# y += 0.1
-y /= 10
-scaler = StandardScaler()
-X_scaled = scaler.fit_transform(X)
+datafit = Pinball(quantile_level)
+penalty = L1(alpha=alpha)
 
 df = compiled_clone(datafit)
 pen = compiled_clone(penalty)
 
 res = solver.solve(X, y, df, pen)
 
-# estimator.fit(X, y)
+# %%
+
+clf = QuantileRegressor(
+    quantile=quantile_level,
+    alpha=alpha/len(y),
+    fit_intercept=False,
+    solver='highs',
+).fit(X, y)
+
+# %%
+print("diff solution:", np.linalg.norm(clf.coef_ - res[0]))
+
+# %%
+
+
+def obj_val(w):
+    return df.value(y, w, X @ w) + pen.value(w)
+
+
+for label, w in zip(("skglm", "sklearn"), (res[0], clf.coef_)):
+    print(f"{label:10} {obj_val(w)=}")
+
+# %%
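Why alpha=alpha/len(y): sklearn's QuantileRegressor minimizes the mean pinball loss plus alpha times the L1 norm, while the obj_val comparison above suggests that skglm's Pinball datafit sums the pinball losses instead of averaging them. Under that assumption the two objectives differ only by the constant factor n_samples, so the minimizers coincide once sklearn's alpha is divided by len(y). A minimal standalone sketch of the equivalence (the pinball helper below is illustrative, not a skglm function):

import numpy as np

def pinball(residuals, q):
    # quantile (pinball) loss, elementwise:
    # q * r for r >= 0, (q - 1) * r for r < 0
    return np.where(residuals >= 0, q * residuals, (q - 1) * residuals)

rng = np.random.default_rng(0)
n, p, q, alpha = 100, 11, 0.5, 0.1
X = rng.normal(size=(n, p))
y = rng.normal(size=n)
w = rng.normal(size=p)

# assumed skglm-style objective: summed pinball loss + alpha * ||w||_1
obj_sum = pinball(y - X @ w, q).sum() + alpha * np.abs(w).sum()
# sklearn QuantileRegressor objective: mean pinball loss + (alpha / n) * ||w||_1
obj_mean = pinball(y - X @ w, q).mean() + (alpha / n) * np.abs(w).sum()

# equal up to the constant factor n, so both problems share the same minimizer
assert np.isclose(obj_sum / n, obj_mean)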

skglm/experimental/pdcd_ws.py

Lines changed: 35 additions & 22 deletions

@@ -99,11 +99,22 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         n_samples, n_features = X.shape
 
         # init steps
-        # Despite violating the conditions mentioned in [1]
-        # this choice of steps yield in practice a convergent algorithm
-        # with better speed of convergence
-        dual_step = 1 / norm(X, ord=2)
-        primal_steps = 1 / norm(X, axis=0, ord=2)
+        # choose steps to satisfy the step-size condition (Assumption 2.1 e)
+        scale = np.sqrt(2 * n_features)
+        dual_steps = 1 / (norm(X, ord=2, axis=1) * scale)
+        primal_steps = 1 / ((dual_steps[:, None] * (X ** 2)).sum(axis=0) * scale)
+
+        # NOTE: the steps still satisfy the condition when multiplied/divided
+        # by an arbitrary positive constant
+        # HACK: balance primal and dual variables: take bigger steps
+        # in the space with the larger number of variables
+        ratio = n_samples / n_features
+        if n_samples > n_features:
+            dual_steps *= ratio
+            primal_steps /= ratio
+        else:
+            dual_steps /= ratio
+            primal_steps *= ratio
 
         # primal vars
         w = np.zeros(n_features) if w_init is None else w_init
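The new steps are built so that the per-feature products primal_steps[j] * sum_i(dual_steps[i] * X[i, j]**2) all equal the constant 1 / sqrt(2 * n_features), and the ratio rescaling multiplies the dual steps and divides the primal steps by the same factor, which leaves those products, and hence the step-size condition cited in the comment, unchanged. A standalone numerical check of that invariance (a sketch mirroring the formulas above, not skglm code):

import numpy as np
from numpy.linalg import norm

rng = np.random.default_rng(0)
n_samples, n_features = 30, 7
X = rng.normal(size=(n_samples, n_features))

# same construction as in _solve
scale = np.sqrt(2 * n_features)
dual_steps = 1 / (norm(X, ord=2, axis=1) * scale)
primal_steps = 1 / ((dual_steps[:, None] * (X ** 2)).sum(axis=0) * scale)

# the per-feature products are all the constant 1 / scale
products = primal_steps * (dual_steps[:, None] * (X ** 2)).sum(axis=0)
assert np.allclose(products, 1 / scale)

# rescaling dual steps by c and primal steps by 1/c leaves the products intact
ratio = n_samples / n_features
rescaled = (primal_steps / ratio) * ((dual_steps * ratio)[:, None] * (X ** 2)).sum(axis=0)
assert np.allclose(rescaled, products)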
@@ -125,7 +136,7 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
 
             # check convergence using fixed-point criteria on both dual and primal
             opts_primal = _scores_primal(X, w, z, penalty, primal_steps, all_features)
-            opt_dual = _score_dual(y, z, Xw, datafit, dual_step)
+            opt_dual = _score_dual(y, z, Xw, datafit, dual_steps)
 
             stop_crit = max(max(opts_primal), opt_dual)
 
@@ -148,13 +159,9 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
 
             # solve sub problem
             # inplace update of w, Xw, z, z_bar
-            if iteration == 0:
-                ep = 500
-            else:
-                ep = self.max_epochs
             PDCD_WS._solve_subproblem(
                 y, X, w, Xw, z, z_bar, datafit, penalty,
-                primal_steps, dual_step, ws, ep, tol_in=0.3*stop_crit, verbose=self.verbose-1)
+                primal_steps, dual_steps, ws, self.max_epochs, tol_in=0.3*stop_crit, verbose=self.verbose-1)
 
             current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
             p_objs.append(current_p_obj)
@@ -172,7 +179,7 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
     @njit
     def _solve_subproblem(
             y, X, w, Xw, z, z_bar, datafit, penalty, primal_steps,
-            dual_step, ws, max_epochs, tol_in, verbose):
+            dual_steps, ws, max_epochs, tol_in, verbose):
         n_features = X.shape[1]
 
         for epoch in range(max_epochs):
@@ -191,20 +198,26 @@ def _solve_subproblem(
                     Xw += delta_w_j * X[:, j]
 
                 # update dual
-                z_bar[:] = datafit.prox_conjugate(z + dual_step * Xw,
-                                                  dual_step, y)
+                z_bar[:] = datafit.prox_conjugate(z + dual_steps * Xw,
+                                                  dual_steps, y)
                 z += (z_bar - z) / n_features
 
             # check convergence using fixed-point criteria on both dual and primal
             if epoch % 1 == 0:
                 opts_primal_in = _scores_primal(X, w, z, penalty, primal_steps, ws)
-                opt_dual_in = _score_dual(y, z, Xw, datafit, dual_step)
+                opt_dual_in = _score_dual(y, z, Xw, datafit, dual_steps)
 
                 stop_crit_in = max(max(opts_primal_in), opt_dual_in)
-                if verbose:
-                    print(f'    epoch {epoch}, inner stopping crit: ', stop_crit_in)
-                    print(opt_dual_in)
-                    print(opts_primal_in)
+                # if verbose:
+                #     current_p_obj = datafit.value(y, w, X @ w) + penalty.value(w)
+                #     print(
+                #         f"|----- epoch {epoch+1}: {current_p_obj:.10f}, "
+                #         f"opt primal: {max(opts_primal_in):.2e}, opt dual: {opt_dual_in:.2e}")
+
+                # print(f'    epoch {epoch}, inner stopping crit: ', stop_crit_in)
+                # # print(opt_dual_in)
+                # # print(opts_primal_in)
+
                 if stop_crit_in <= tol_in:
                     break
 
@@ -228,7 +241,7 @@ def _scores_primal(X, w, z, penalty, primal_steps, ws):
 
 
 @njit
-def _score_dual(y, z, Xw, datafit, dual_step):
-    next_z = datafit.prox_conjugate(z + dual_step * Xw,
-                                    dual_step, y)
+def _score_dual(y, z, Xw, datafit, dual_steps):
+    next_z = datafit.prox_conjugate(z + dual_steps * Xw,
+                                    dual_steps, y)
     return norm(z - next_z, ord=np.inf)
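_score_dual returns a fixed-point residual: at a primal-dual saddle point, z equals the prox of (z + dual_steps * Xw) under the conjugate datafit, so the infinity-norm gap vanishes. For the pinball datafit, the conjugate is finite only on the box [quantile - 1, quantile], which should make prox_conjugate a clip onto that box after shifting by steps * y. A self-contained sketch of the score under that assumed closed form (not the actual skglm implementation):

import numpy as np

q = 0.5  # quantile level

def pinball_prox_conjugate(v, steps, y):
    # assumed closed form: the pinball conjugate is linear (y @ z) on the box
    # [q - 1, q] and +inf outside, so its prox is a shifted clip onto that box
    return np.clip(v - steps * y, q - 1, q)

def score_dual(y, z, Xw, steps):
    # zero exactly when z is a fixed point of the dual prox step
    next_z = pinball_prox_conjugate(z + steps * Xw, steps, y)
    return np.abs(z - next_z).max()

rng = np.random.default_rng(0)
n = 20
y = rng.normal(size=n)
Xw = rng.normal(size=n)
steps = np.full(n, 0.1)

print(score_dual(y, np.zeros(n), Xw, steps))  # generally > 0 away from a solution
# when Xw == y, every z inside the box is a fixed point, so the score is 0
z_box = rng.uniform(q - 1, q, size=n)
print(score_dual(y, z_box, y, steps))  # 0.0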
