Commit f2ac94e

add UT + rename to QuadraticHessian + rm estimator class
1 parent: ec4be21

File tree

    doc/api.rst
    skglm/datafits/__init__.py
    skglm/datafits/single_task.py
    skglm/estimators.py
    skglm/tests/test_datafits.py

5 files changed: +40 -65 lines

doc/api.rst

Lines changed: 2 additions & 1 deletion
@@ -68,6 +68,7 @@ Datafits
     Poisson
     Quadratic
     QuadraticGroup
+    QuadraticHessian
     QuadraticSVC
     WeightedQuadratic

@@ -102,4 +103,4 @@ Experimental
     PDCD_WS
     Pinball
     SqrtQuadratic
-    SqrtLasso
+    SqrtLasso

skglm/datafits/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 from .base import BaseDatafit, BaseMultitaskDatafit
 from .single_task import (Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma,
-                          Cox, WeightedQuadratic, HessianQuadratic,)
+                          Cox, WeightedQuadratic, QuadraticHessian,)
 from .multi_task import QuadraticMultiTask
 from .group import QuadraticGroup, LogisticGroup

@@ -10,5 +10,5 @@
     Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox,
     QuadraticMultiTask,
     QuadraticGroup, LogisticGroup, WeightedQuadratic,
-    HessianQuadratic
+    QuadraticHessian
 ]

skglm/datafits/single_task.py

Lines changed: 9 additions & 10 deletions
@@ -240,15 +240,16 @@ def intercept_update_step(self, y, Xw):
         return np.sum(self.sample_weights * (Xw - y)) / self.sample_weights.sum()


-class HessianQuadratic(BaseDatafit):
-    r"""_summary_
+class QuadraticHessian(BaseDatafit):
+    r"""Quadratic datafit where we pass the Hessian A directly.

     The datafit reads:

-    .. math:: 1 / 2 x^\\top A x + \\langle b, x \\rangle
-
-    for A symmetric
+    .. math:: 1 / 2 x^\top A x + \langle b, x \rangle

+    For a symmetric A. Up to a constant, it is the same as a Quadratic, with
+    :math:`A = 1 / n_\text{samples} X^\top X` and :math:`b = - 1 / n_\text{samples} X^\top y`.
+    When the Hessian is available, this datafit is more efficient than using Quadratic.
     """

     def __init__(self):
@@ -264,7 +265,7 @@ def get_lipschitz(self, A, b):
         n_features = A.shape[0]
         lipschitz = np.zeros(n_features, dtype=A.dtype)
         for j in range(n_features):
-            lipschitz[j] = np.sqrt((A[:, j]**2).sum())
+            lipschitz[j] = A[j, j]
         return lipschitz

     def gradient_scalar(self, A, b, w, Ax, j):
@@ -887,8 +888,7 @@ def _A_dot_vec(self, vec):
         for idx in range(n_H):
             current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx+1]]
             size_current_H = current_H_idx.shape[0]
-            frac_range = np.arange(
-                size_current_H, dtype=vec.dtype) / size_current_H
+            frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H

             sum_vec_H = np.sum(vec[current_H_idx])
             out[current_H_idx] = sum_vec_H * frac_range
@@ -903,8 +903,7 @@ def _AT_dot_vec(self, vec):
         for idx in range(n_H):
             current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx+1]]
             size_current_H = current_H_idx.shape[0]
-            frac_range = np.arange(
-                size_current_H, dtype=vec.dtype) / size_current_H
+            frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H

             weighted_sum_vec_H = vec[current_H_idx] @ frac_range
             out[current_H_idx] = weighted_sum_vec_H * np.ones(size_current_H)
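A note on the get_lipschitz change above: for f(x) = 1/2 x^T A x + <b, x>, the j-th gradient coordinate is (A x)_j + b_j, which is linear in x_j with slope A[j, j], so the coordinate-wise Lipschitz constant is the diagonal entry; the previous column norm sqrt(sum(A[:, j]^2)) was only a looser upper bound. A minimal standalone NumPy check of this fact (not part of the commit; all names here are illustrative):

import numpy as np

rng = np.random.default_rng(0)
n_features = 5
M = rng.standard_normal((n_features, n_features))
A = M.T @ M / n_features   # symmetric PSD matrix, standing in for the Hessian
b = rng.standard_normal(n_features)

def grad_j(x, j):
    # j-th coordinate of the gradient of 0.5 * x @ A @ x + b @ x
    return A[j] @ x + b[j]

x = rng.standard_normal(n_features)
j, step = 2, 0.1
e = np.zeros(n_features)
e[j] = step
# moving x_j by `step` moves the j-th gradient coordinate by A[j, j] * step
np.testing.assert_allclose((grad_j(x + e, j) - grad_j(x, j)) / step, A[j, j])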

skglm/estimators.py

Lines changed: 8 additions & 51 deletions
@@ -21,7 +21,7 @@
 from skglm.utils.jit_compilation import compiled_clone
 from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD
 from skglm.datafits import (Cox, Quadratic, Logistic, QuadraticSVC,
-                            QuadraticMultiTask, QuadraticGroup, HessianQuadratic)
+                            QuadraticMultiTask, QuadraticGroup, QuadraticHessian)
 from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2,
                              MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
 from skglm.utils.data import grp_converter
@@ -126,8 +126,7 @@ def _glm_fit(X, y, model, datafit, penalty, solver):
         w = np.zeros(n_features + fit_intercept, dtype=X_.dtype)
         Xw = np.zeros(n_samples, dtype=X_.dtype)
     else:  # multitask
-        w = np.zeros((n_features + fit_intercept,
-                      y.shape[1]), dtype=X_.dtype)
+        w = np.zeros((n_features + fit_intercept, y.shape[1]), dtype=X_.dtype)
         Xw = np.zeros(y.shape, dtype=X_.dtype)

     # check consistency of weights for WeightedL1
@@ -450,42 +449,6 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
         return solver.path(X, y, datafit, penalty, alphas, coef_init, return_n_iter)


-class L1PenalizedQP(BaseEstimator):
-    def __init__(self, alpha=1., max_iter=50, max_epochs=50_000, p0=10, verbose=0,
-                 tol=1e-4, positive=False, fit_intercept=True, warm_start=False,
-                 ws_strategy="subdiff"):
-        super().__init__()
-        self.alpha = alpha
-        self.tol = tol
-        self.max_iter = max_iter
-        self.max_epochs = max_epochs
-        self.p0 = p0
-        self.ws_strategy = ws_strategy
-        self.positive = positive
-        self.fit_intercept = fit_intercept
-        self.warm_start = warm_start
-        self.verbose = verbose
-
-    def fit(self, A, b):
-        """Fit the model according to the given training data.
-
-        Parameters
-        ----------
-        A : array-like, shape (n_features, n_features)
-        b : array-like, shape (n_samples,)
-
-        Returns
-        -------
-        self :
-            Fitted estimator.
-        """
-        solver = AndersonCD(
-            self.max_iter, self.max_epochs, self.p0, tol=self.tol,
-            ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
-            warm_start=self.warm_start, verbose=self.verbose)
-        return _glm_fit(A, b, self, HessianQuadratic(), L1(self.alpha, self.positive), solver)
-
-
 class WeightedLasso(LinearModel, RegressorMixin):
     r"""WeightedLasso estimator based on Celer solver and primal extrapolation.
@@ -613,8 +576,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             raise ValueError("The number of weights must match the number of \
                 features. Got %s, expected %s." % (
                 len(weights), X.shape[1]))
-        penalty = compiled_clone(WeightedL1(
-            self.alpha, weights, self.positive))
+        penalty = compiled_clone(WeightedL1(self.alpha, weights, self.positive))
         datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
@@ -952,8 +914,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
                 f"Got {len(self.weights)}, expected {X.shape[1]}."
             )
         penalty = compiled_clone(
-            WeightedMCPenalty(self.alpha, self.gamma,
-                              self.weights, self.positive)
+            WeightedMCPenalty(self.alpha, self.gamma, self.weights, self.positive)
         )
         datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
         solver = AndersonCD(
@@ -1348,8 +1309,7 @@ def fit(self, X, y):
            # copy/paste from https://github.com/scikit-learn/scikit-learn/blob/ \
            # 23ff51c07ebc03c866984e93c921a8993e96d1f9/sklearn/utils/ \
            # estimator_checks.py#L3886
-            raise ValueError(
-                "requires y to be passed, but the target y is None")
+            raise ValueError("requires y to be passed, but the target y is None")
         y = check_array(
             y,
             accept_sparse=False,
@@ -1390,8 +1350,7 @@ def fit(self, X, y):

         # init solver
         if self.l1_ratio == 0.:
-            solver = LBFGS(max_iter=self.max_iter,
-                           tol=self.tol, verbose=self.verbose)
+            solver = LBFGS(max_iter=self.max_iter, tol=self.tol, verbose=self.verbose)
         else:
             solver = ProxNewton(
                 max_iter=self.max_iter, tol=self.tol, verbose=self.verbose,
@@ -1529,8 +1488,7 @@ def fit(self, X, Y):
         if not self.warm_start or not hasattr(self, "coef_"):
             self.coef_ = None

-        datafit_jit = compiled_clone(
-            QuadraticMultiTask(), X.dtype == np.float32)
+        datafit_jit = compiled_clone(QuadraticMultiTask(), X.dtype == np.float32)
         penalty_jit = compiled_clone(L2_1(self.alpha), X.dtype == np.float32)

         solver = MultiTaskBCD(
@@ -1710,8 +1668,7 @@ def fit(self, X, y):
                 "The total number of group members must equal the number of features. "
                 f"Got {n_features}, expected {X.shape[1]}.")

-        weights = np.ones(
-            len(group_sizes)) if self.weights is None else self.weights
+        weights = np.ones(len(group_sizes)) if self.weights is None else self.weights
         group_penalty = WeightedGroupL2(alpha=self.alpha, grp_ptr=grp_ptr,
                                         grp_indices=grp_indices, weights=weights,
                                         positive=self.positive)
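With L1PenalizedQP removed, the same model is obtained by composing GeneralizedLinearEstimator with the new QuadraticHessian datafit and an L1 penalty, which is exactly what the new unit test below does. A minimal sketch of that replacement pattern (the data here is random and purely illustrative; the API usage mirrors the test):

import numpy as np
from skglm import GeneralizedLinearEstimator
from skglm.datafits import QuadraticHessian
from skglm.penalties import L1
from skglm.solvers import AndersonCD

rng = np.random.default_rng(0)
n_samples, n_features = 20, 10
X = rng.standard_normal((n_samples, n_features))
y = rng.standard_normal(n_samples)
A = X.T @ X / n_samples    # Hessian of the least-squares datafit
b = -X.T @ y / n_samples   # linear term

# what L1PenalizedQP(alpha).fit(A, b) used to do, spelled out explicitly
alpha = np.max(np.abs(b)) / 10
solver = AndersonCD(fit_intercept=False, warm_start=False)
estimator = GeneralizedLinearEstimator(QuadraticHessian(), L1(alpha), solver)
estimator.fit(A, b)
print(estimator.coef_)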

skglm/tests/test_datafits.py

Lines changed: 19 additions & 1 deletion
@@ -6,7 +6,7 @@
 from numpy.testing import assert_allclose, assert_array_less

 from skglm.datafits import (Huber, Logistic, Poisson, Gamma, Cox, WeightedQuadratic,
-                            Quadratic,)
+                            Quadratic, QuadraticHessian)
 from skglm.penalties import L1, WeightedL1
 from skglm.solvers import AndersonCD, ProxNewton
 from skglm import GeneralizedLinearEstimator
@@ -219,5 +219,23 @@ def test_sample_weights(fit_intercept):
     # np.testing.assert_equal(n_iter, n_iter_overs)


+def test_QuadraticHessian():
+    n_samples = 20
+    n_features = 10
+    X, y, _ = make_correlated_data(
+        n_samples=n_samples, n_features=n_features, random_state=0)
+    A = X.T @ X / n_samples
+    b = -X.T @ y / n_samples
+    alpha = np.max(np.abs(b)) / 10
+
+    pen = L1(alpha)
+    solv = AndersonCD(warm_start=False, verbose=2, fit_intercept=False)
+    lasso = GeneralizedLinearEstimator(Quadratic(), pen, solv).fit(X, y)
+    qpl1 = GeneralizedLinearEstimator(QuadraticHessian(), pen, solv).fit(A, b)
+
+    np.testing.assert_allclose(lasso.coef_, qpl1.coef_)
+    # check that it's not just because we got alpha too high and thus 0 coef
+    np.testing.assert_array_less(0.1, np.max(np.abs(qpl1.coef_)))
+
+
 if __name__ == '__main__':
     pass
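A remark on the test's choice alpha = max|b| / 10: since b = -X^T y / n_samples, the quantity max_j |b_j| equals the Lasso critical level

    alpha_max = ||X^T y||_inf / n_samples,

the smallest penalty for which the solution is identically zero. Taking a tenth of alpha_max therefore guarantees nonzero coefficients, so the assert_allclose comparison cannot pass trivially with both coefficient vectors at zero; that is what the final assert_array_less check guards.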
