Skip to content

Commit d0536dc

Browse files
authored
MNT - remove normalization by sqrt(n_samples) in Square root Lasso (#130)
1 parent 9fe4bae commit d0536dc

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

skglm/experimental/sqrt_lasso.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212

1313

1414
class SqrtQuadratic(BaseDatafit):
15-
"""Square root quadratic datafit.
15+
"""Unnormalized square root quadratic datafit.
1616
1717
The datafit reads::
18-
||y - Xw||_2 / sqrt(n_samples)
18+
19+
||y - Xw||_2
1920
"""
2021

2122
def __init__(self):
@@ -29,7 +30,7 @@ def params_to_dict(self):
2930
return dict()
3031

3132
def value(self, y, w, Xw):
32-
return np.linalg.norm(y - Xw) / np.sqrt(len(y))
33+
return np.linalg.norm(y - Xw)
3334

3435
def raw_grad(self, y, Xw):
3536
"""Compute gradient of datafit w.r.t ``Xw``.
@@ -45,12 +46,12 @@ def raw_grad(self, y, Xw):
4546
if norm_residuals < 1e-2 * norm(y):
4647
raise ValueError("SmallResidualException")
4748

48-
return minus_residual / (norm_residuals * np.sqrt(len(y)))
49+
return minus_residual / norm_residuals
4950

5051
def raw_hessian(self, y, Xw):
5152
"""Diagonal matrix upper bounding the Hessian."""
5253
n_samples = len(y)
53-
fill_value = 1 / (np.sqrt(n_samples) * norm(y - Xw))
54+
fill_value = 1 / norm(y - Xw)
5455
return np.full(n_samples, fill_value)
5556

5657

@@ -59,7 +60,7 @@ class SqrtLasso(LinearModel, RegressorMixin):
5960
6061
The optimization objective for square root Lasso is::
6162
62-
|y - X w||_2 / sqrt(n_samples) + alpha * ||w||_1
63+
|y - X w||_2 + alpha * ||w||_1
6364
6465
Parameters
6566
----------
@@ -205,7 +206,8 @@ def _chambolle_pock_sqrt(X, y, alpha, max_iter=1000, obj_freq=10, verbose=False)
205206
"""Apply Chambolle-Pock algorithm to solve square-root Lasso.
206207
207208
The objective function is:
208-
min_w ||Xw - y||_2/sqrt(n_samples) + alpha * ||w||_1.
209+
210+
min_w ||Xw - y||_2 + alpha * ||w||_1.
209211
"""
210212
n_samples, n_features = X.shape
211213
# dual variable is z, primal is w
@@ -221,12 +223,12 @@ def _chambolle_pock_sqrt(X, y, alpha, max_iter=1000, obj_freq=10, verbose=False)
221223
sigma = 0.99 / L
222224

223225
for t in range(max_iter):
224-
w = ST_vec(w - tau * X.T @ (2 * z - z_old), alpha * np.sqrt(n_samples) * tau)
226+
w = ST_vec(w - tau * X.T @ (2 * z - z_old), alpha * tau)
225227
z_old = z.copy()
226228
z[:] = proj_L2ball(z + sigma * (X @ w - y))
227229

228230
if t % obj_freq == 0:
229-
objs.append(norm(X @ w - y) / np.sqrt(n_samples) + alpha * norm(w, ord=1))
231+
objs.append(norm(X @ w - y) + alpha * norm(w, ord=1))
230232
if verbose:
231233
print(f"Iter {t}, obj {objs[-1]: .10f}")
232234

skglm/experimental/tests/test_sqrt_lasso.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
def test_alpha_max():
1010
n_samples, n_features = 50, 10
1111
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
12-
alpha_max = norm(X.T @ y, ord=np.inf) / (np.sqrt(n_samples) * norm(y))
12+
alpha_max = norm(X.T @ y, ord=np.inf) / norm(y)
1313

1414
sqrt_lasso = SqrtLasso(alpha=alpha_max).fit(X, y)
1515

@@ -24,7 +24,7 @@ def test_vs_statsmodels():
2424
n_samples, n_features = 50, 10
2525
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
2626

27-
alpha_max = norm(X.T @ y, ord=np.inf) / (np.sqrt(n_samples) * norm(y))
27+
alpha_max = norm(X.T @ y, ord=np.inf) / norm(y)
2828
n_alphas = 3
2929
alphas = alpha_max * np.geomspace(1, 1e-2, n_alphas+1)[1:]
3030

@@ -36,9 +36,10 @@ def test_vs_statsmodels():
3636
# fit statsmodels on path
3737
for i in range(n_alphas):
3838
alpha = alphas[i]
39+
# statsmodels solves: ||y - Xw||_2 + alpha * ||w||_1 / sqrt(n_samples)
3940
model = linear_model.OLS(y, X)
4041
model = model.fit_regularized(method='sqrt_lasso', L1_wt=1.,
41-
alpha=n_samples * alpha)
42+
alpha=np.sqrt(n_samples) * alpha)
4243
coefs_statsmodels[i] = model.params
4344

4445
np.testing.assert_almost_equal(coefs_skglm, coefs_statsmodels, decimal=4)
@@ -48,7 +49,7 @@ def test_prox_newton_cp():
4849
n_samples, n_features = 50, 10
4950
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
5051

51-
alpha_max = norm(X.T @ y, ord=np.inf) / (np.sqrt(n_samples) * norm(y))
52+
alpha_max = norm(X.T @ y, ord=np.inf) / norm(y)
5253
alpha = alpha_max / 10
5354
clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
5455
w, _, _ = _chambolle_pock_sqrt(X, y, alpha, max_iter=1000)

0 commit comments

Comments
 (0)